3
3
from .helpers import _make_chart
4
4
5
5
6
- def semimarkov_forward (edge , semiring = LogSemiring , lengths = None ):
6
+ def semimarkov_forward (edge , semiring = LogSemiring , lengths = None , force_grad = False ):
7
7
"""
8
8
Compute the forward pass of a semimarkov CRF.
9
9
@@ -17,12 +17,17 @@ def semimarkov_forward(edge, semiring=LogSemiring, lengths=None):
17
17
spans: list of N, b x K x C x C table
18
18
19
19
"""
20
- batch , N , K , C , _ = edge .shape
20
+ batch , N , K , C , C2 = edge .shape
21
21
if lengths is None :
22
22
lengths = torch .LongTensor ([N ] * batch )
23
+ assert max (lengths ) <= N , "Length longer than edge scores"
24
+ assert C == C2 , "Transition shape doesn't match"
25
+
23
26
spans = [None for _ in range (N )]
24
- alpha = [_make_chart ((batch , K , C ), edge , semiring ) for n in range (N + 1 )]
25
- beta = [_make_chart ((batch , C ), edge , semiring ) for n in range (N + 1 )]
27
+ alpha = [
28
+ _make_chart ((batch , K , C ), edge , semiring , force_grad ) for n in range (N + 1 )
29
+ ]
30
+ beta = [_make_chart ((batch , C ), edge , semiring , force_grad ) for n in range (N + 1 )]
26
31
beta [0 ].data .fill_ (semiring .one ())
27
32
for n in range (1 , N + 1 ):
28
33
spans [n - 1 ] = semiring .times (
@@ -33,7 +38,7 @@ def semimarkov_forward(edge, semiring=LogSemiring, lengths=None):
33
38
f1 = torch .arange (n - 1 , t , - 1 )
34
39
f2 = torch .arange (1 , len (f1 ) + 1 )
35
40
print (n - 1 , f1 , f2 )
36
- beta [n ] = semiring .sum (
41
+ beta [n ][:] = semiring .sum (
37
42
torch .stack ([alpha [a ][:, b ] for a , b in zip (f1 , f2 )]), dim = 0
38
43
)
39
44
v = semiring .sum (torch .stack ([beta [l ][i ] for i , l in enumerate (lengths )]), dim = 1 )
@@ -52,7 +57,7 @@ def semimarkov(edge, semiring=LogSemiring, lengths=None):
52
57
marginals: b x N x K x C table
53
58
54
59
"""
55
- v , spans = semimarkov_forward (edge , semiring , lengths )
60
+ v , spans = semimarkov_forward (edge , semiring , lengths , force_grad = True )
56
61
marg = torch .autograd .grad (
57
62
v .sum (dim = 0 ), spans , create_graph = True , only_inputs = True , allow_unused = False
58
63
)
0 commit comments