@@ -170,17 +170,38 @@ def from_parts(arcs):
         return labels, None
 
 
-def deptree_part(arc_scores, eps=1e-5):
+def deptree_part(arc_scores, multi_root, lengths, eps=1e-5):
+    if lengths is not None:
+        batch, N, N = arc_scores.shape
+        x = torch.arange(N, device=arc_scores.device).expand(batch, N)
+        if not torch.is_tensor(lengths):
+            lengths = torch.tensor(lengths, device=arc_scores.device)
+        lengths = lengths.unsqueeze(1)
+        x = x < lengths
+        det_offset = torch.diag_embed((~x).float())
+        x = x.unsqueeze(2).expand(-1, -1, N)
+        mask = torch.transpose(x, 1, 2) * x
+        mask = mask.float()
+        mask[mask == 0] = float("-inf")
+        mask[mask == 1] = 0
+        arc_scores = arc_scores + mask
     input = arc_scores
     eye = torch.eye(input.shape[1], device=input.device)
     laplacian = input.exp() + eps
     lap = laplacian.masked_fill(eye != 0, 0)
     lap = -lap + torch.diag_embed(lap.sum(1), offset=0, dim1=-2, dim2=-1)
-    lap[:, 0] = torch.diagonal(input, 0, -2, -1).exp()
+    if lengths is not None:
+        lap += det_offset
+
+    if multi_root:
+        rss = torch.diagonal(input, 0, -2, -1).exp()  # root selection scores
+        lap = lap + torch.diag_embed(rss, offset=0, dim1=-2, dim2=-1)
+    else:
+        lap[:, 0] = torch.diagonal(input, 0, -2, -1).exp()
     return lap.logdet()
-
-
-def deptree_nonproj(arc_scores, eps=1e-5):
+
+
+def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
     """
     Compute the marginals of a non-projective dependency tree using the
     matrix-tree theorem.
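
Both hunks open with the same length-masking preamble, so it is worth seeing in isolation. The sketch below is not part of the diff (the batch size, N, and toy lengths are made up for illustration): arcs touching padded positions are driven to -inf, so exp() removes them from the Laplacian, while det_offset places ones on the padded diagonal entries so the determinant is unaffected by padding.

import torch

batch, N = 2, 4
arc_scores = torch.randn(batch, N, N)
lengths = torch.tensor([4, 2])  # second example uses only 2 of 4 positions

x = torch.arange(N).expand(batch, N) < lengths.unsqueeze(1)
det_offset = torch.diag_embed((~x).float())    # 1s at padded diagonal slots
x = x.unsqueeze(2).expand(-1, -1, N)
mask = (torch.transpose(x, 1, 2) * x).float()  # 1 iff both endpoints are valid
mask[mask == 0] = float("-inf")
mask[mask == 1] = 0
masked = arc_scores + mask

# Every arc into or out of a padded position is now -inf.
assert torch.isinf(masked[1, 2:, :]).all() and torch.isinf(masked[1, :, 2:]).all()
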
@@ -196,27 +217,61 @@ def deptree_nonproj(arc_scores, eps=1e-5):
     Returns:
         arc_marginals : b x N x N.
     """
-
+    if lengths is not None:
+        batch, N, N = arc_scores.shape
+        x = torch.arange(N, device=arc_scores.device).expand(batch, N)
+        if not torch.is_tensor(lengths):
+            lengths = torch.tensor(lengths, device=arc_scores.device)
+        lengths = lengths.unsqueeze(1)
+        x = x < lengths
+        det_offset = torch.diag_embed((~x).float())
+        x = x.unsqueeze(2).expand(-1, -1, N)
+        mask = torch.transpose(x, 1, 2) * x
+        mask = mask.float()
+        mask[mask == 0] = float("-inf")
+        mask[mask == 1] = 0
+        arc_scores = arc_scores + mask
+
     input = arc_scores
     eye = torch.eye(input.shape[1], device=input.device)
     laplacian = input.exp() + eps
     lap = laplacian.masked_fill(eye != 0, 0)
     lap = -lap + torch.diag_embed(lap.sum(1), offset=0, dim1=-2, dim2=-1)
-    lap[:, 0] = torch.diagonal(input, 0, -2, -1).exp()
-    inv_laplacian = lap.inverse()
-    factor = (
-        torch.diagonal(inv_laplacian, 0, -2, -1)
-        .unsqueeze(2)
-        .expand_as(input)
-        .transpose(1, 2)
-    )
-    term1 = input.exp().mul(factor).clone()
-    term2 = input.exp().mul(inv_laplacian.transpose(1, 2)).clone()
-    term1[:, :, 0] = 0
-    term2[:, 0] = 0
-    output = term1 - term2
-    roots_output = (
-        torch.diagonal(input, 0, -2, -1).exp().mul(inv_laplacian.transpose(1, 2)[:, 0])
-    )
+    if lengths is not None:
+        lap += det_offset
+
+    if multi_root:
+        rss = torch.diagonal(input, 0, -2, -1).exp()  # root selection scores
+        lap = lap + torch.diag_embed(rss, offset=0, dim1=-2, dim2=-1)
+        inv_laplacian = lap.inverse()
+        factor = (
+            torch.diagonal(inv_laplacian, 0, -2, -1)
+            .unsqueeze(2)
+            .expand_as(input)
+            .transpose(1, 2)
+        )
+        term1 = input.exp().mul(factor).clone()
+        term2 = input.exp().mul(inv_laplacian.transpose(1, 2)).clone()
+        output = term1 - term2
+        roots_output = (
+            torch.diagonal(input, 0, -2, -1).exp().mul(torch.diagonal(inv_laplacian.transpose(1, 2), 0, -2, -1))
+        )
+    else:
+        lap[:, 0] = torch.diagonal(input, 0, -2, -1).exp()
+        inv_laplacian = lap.inverse()
+        factor = (
+            torch.diagonal(inv_laplacian, 0, -2, -1)
+            .unsqueeze(2)
+            .expand_as(input)
+            .transpose(1, 2)
+        )
+        term1 = input.exp().mul(factor).clone()
+        term2 = input.exp().mul(inv_laplacian.transpose(1, 2)).clone()
+        term1[:, :, 0] = 0
+        term2[:, 0] = 0
+        output = term1 - term2
+        roots_output = (
+            torch.diagonal(input, 0, -2, -1).exp().mul(inv_laplacian.transpose(1, 2)[:, 0])
+        )
     output = output + torch.diag_embed(roots_output, 0, -2, -1)
     return output
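
With the new signature, both helpers take multi_root and lengths explicitly. One quick way to sanity-check the pair (a hypothetical smoke test, not part of this commit) is the standard exponential-family identity: the gradient of the log-partition from deptree_part should match the closed-form marginals from deptree_nonproj, up to the eps smoothing in the Laplacian.

import torch

batch, N = 2, 5
arc_scores = torch.randn(batch, N, N, requires_grad=True)
lengths = torch.tensor([5, 3])

logZ = deptree_part(arc_scores, multi_root=True, lengths=lengths)
(grad,) = torch.autograd.grad(logZ.sum(), arc_scores)

marginals = deptree_nonproj(arc_scores, multi_root=True, lengths=lengths)
# Compare only the unpadded block of the shorter example; padded entries
# receive -inf scores and carry no probability mass.
assert torch.allclose(grad[1, :3, :3], marginals[1, :3, :3], atol=1e-3)
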