Skip to content

Commit 9083dd4

Browse files
authored
multi-root non-projective dependency crf with lengths mask (#87)
1 parent 0339886 commit 9083dd4

File tree

2 files changed

+83
-26
lines changed

2 files changed

+83
-26
lines changed

Diff for: torch_struct/deptree.py

+77-22
Original file line numberDiff line numberDiff line change
@@ -170,17 +170,38 @@ def from_parts(arcs):
170170
return labels, None
171171

172172

def deptree_part(arc_scores, multi_root, lengths, eps=1e-5):
    """Compute the log-partition function of a non-projective dependency
    CRF via the matrix-tree theorem.

    Parameters:
        arc_scores : b x N x N tensor of arc log-potentials; entry (i, j)
            scores the arc from head i to dependent j, and the diagonal
            holds root-attachment scores.
        multi_root : if True, several words may attach to the root
            (spanning forests); if False, exactly one word is the root.
        lengths : optional b-length tensor (or sequence) of true sentence
            lengths; positions >= length are masked out.
        eps : small constant added to the Laplacian for numerical stability.

    Returns:
        b-length tensor of log-partition values (one per batch element).
    """
    det_offset = None
    if lengths is not None:
        batch, N = arc_scores.shape[0], arc_scores.shape[1]
        if not torch.is_tensor(lengths):
            lengths = torch.tensor(lengths, device=arc_scores.device)
        positions = torch.arange(N, device=arc_scores.device).expand(batch, N)
        # valid[b, i] marks real tokens; everything past the length is padding.
        valid = positions < lengths.unsqueeze(1)
        # Identity on the padded sub-block keeps the determinant unaffected
        # (padded rows/columns of the Laplacian would otherwise be ~0).
        det_offset = torch.diag_embed((~valid).float())
        # Keep an arc only when both endpoints are real tokens.
        pair_valid = valid.unsqueeze(2) & valid.unsqueeze(1)
        arc_scores = arc_scores.masked_fill(~pair_valid, float("-inf"))
    input = arc_scores
    eye = torch.eye(input.shape[1], device=input.device)
    laplacian = input.exp() + eps
    lap = laplacian.masked_fill(eye != 0, 0)
    lap = -lap + torch.diag_embed(lap.sum(1), offset=0, dim1=-2, dim2=-1)
    if det_offset is not None:
        lap = lap + det_offset
    if multi_root:
        # Every word may independently select the root: add root selection
        # scores (stored on the input diagonal) to the Laplacian diagonal.
        rss = torch.diagonal(input, 0, -2, -1).exp()
        lap = lap + torch.diag_embed(rss, offset=0, dim1=-2, dim2=-1)
    else:
        # Single-root variant: replace the first row with root scores.
        lap[:, 0] = torch.diagonal(input, 0, -2, -1).exp()
    return lap.logdet()
def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
    """
    Compute the marginals of a non-projective dependency tree using the
    matrix-tree theorem.

    Parameters:
        arc_scores : b x N x N arc log-potentials; entry (i, j) scores the
            arc from head i to dependent j, and the diagonal holds
            root-attachment scores.
        multi_root : if True, several words may attach to the root
            (spanning forests); if False, exactly one word is the root.
        lengths : optional b-length tensor (or sequence) of true sentence
            lengths; positions >= length are masked out.
        eps : small constant added to the Laplacian for numerical stability.

    Returns:
        arc_marginals : b x N x N, with root marginals on the diagonal.
    """
    det_offset = None
    if lengths is not None:
        batch, N = arc_scores.shape[0], arc_scores.shape[1]
        if not torch.is_tensor(lengths):
            lengths = torch.tensor(lengths, device=arc_scores.device)
        positions = torch.arange(N, device=arc_scores.device).expand(batch, N)
        # valid[b, i] marks real tokens; everything past the length is padding.
        valid = positions < lengths.unsqueeze(1)
        # Identity on the padded sub-block keeps the Laplacian invertible.
        det_offset = torch.diag_embed((~valid).float())
        # Keep an arc only when both endpoints are real tokens.
        pair_valid = valid.unsqueeze(2) & valid.unsqueeze(1)
        arc_scores = arc_scores.masked_fill(~pair_valid, float("-inf"))

    input = arc_scores
    eye = torch.eye(input.shape[1], device=input.device)
    laplacian = input.exp() + eps
    lap = laplacian.masked_fill(eye != 0, 0)
    lap = -lap + torch.diag_embed(lap.sum(1), offset=0, dim1=-2, dim2=-1)
    if det_offset is not None:
        lap = lap + det_offset

    root_scores = torch.diagonal(input, 0, -2, -1).exp()
    if multi_root:
        # Every word may independently select the root: add root selection
        # scores to the Laplacian diagonal.
        lap = lap + torch.diag_embed(root_scores, offset=0, dim1=-2, dim2=-1)
    else:
        # Single-root variant: replace the first row with root scores.
        lap[:, 0] = root_scores

    # Shared matrix-tree marginal computation (identical in both branches of
    # the original): marginals are exp(score) times a difference of
    # inverse-Laplacian terms.
    inv_laplacian = lap.inverse()
    factor = (
        torch.diagonal(inv_laplacian, 0, -2, -1)
        .unsqueeze(2)
        .expand_as(input)
        .transpose(1, 2)
    )
    term1 = input.exp().mul(factor).clone()
    term2 = input.exp().mul(inv_laplacian.transpose(1, 2)).clone()
    if multi_root:
        roots_output = root_scores.mul(
            torch.diagonal(inv_laplacian.transpose(1, 2), 0, -2, -1)
        )
    else:
        # Row/column 0 play the artificial-root role in the single-root
        # construction; remove them from the arc marginals.
        term1[:, :, 0] = 0
        term2[:, 0] = 0
        roots_output = root_scores.mul(inv_laplacian.transpose(1, 2)[:, 0])
    output = term1 - term2
    output = output + torch.diag_embed(roots_output, 0, -2, -1)
    return output

Diff for: torch_struct/distributions.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -466,8 +466,10 @@ class NonProjectiveDependencyCRF(StructDistribution):
466466
Note: Does not currently implement argmax (Chiu-Liu) or sampling.
467467
468468
"""
469-
470-
struct = DepTree
def __init__(self, log_potentials, lengths=None, args={}, multiroot=False):
    """
    Construct a non-projective dependency CRF.

    Parameters:
        log_potentials : arc log-potentials forwarded to the base
            distribution.
        lengths : optional sentence lengths for masking padded positions.
        args : extra arguments forwarded to the base class.
            NOTE(review): mutable default ``{}`` is shared across calls —
            safe only if never mutated; confirm against StructDistribution.
        multiroot : if True, allow multiple words to attach to the root
            (default False preserves the previous single-root behavior).
    """
    super(NonProjectiveDependencyCRF, self).__init__(log_potentials, lengths, args)
    self.multiroot = multiroot
471473

472474
@lazy_property
473475
def marginals(self):
@@ -479,7 +481,7 @@ def marginals(self):
479481
Returns:
480482
marginals (*batch_shape x event_shape*)
481483
"""
482-
return deptree_nonproj(self.log_potentials)
484+
return deptree_nonproj(self.log_potentials, self.multiroot, self.lengths)
483485

484486
def sample(self, sample_shape=torch.Size()):
    """Sampling is not implemented for this distribution."""
    raise NotImplementedError()
def partition(self):
    """
    Compute the partition function.

    Delegates to ``deptree_part`` with the stored log-potentials, the
    multi-root flag set at construction, and the optional lengths mask.
    """
    return deptree_part(self.log_potentials, self.multiroot, self.lengths)
493495

494496
@lazy_property
495497
def argmax(self):

0 commit comments

Comments
 (0)