@@ -66,10 +66,22 @@ def logpartition(self, arc_scores_in, lengths=None, force_grad=False):
66
66
]
67
67
for _ in range (2 )
68
68
]
69
- semiring .one_ (alpha [A ][C ][L ].data [:, :, :, 0 ].data )
70
- semiring .one_ (alpha [A ][C ][R ].data [:, :, :, 0 ].data )
71
- semiring .one_ (alpha [B ][C ][L ].data [:, :, :, - 1 ].data )
72
- semiring .one_ (alpha [B ][C ][R ].data [:, :, :, - 1 ].data )
69
+ mask = torch .zeros (alpha [A ][C ][L ].data .shape ).bool ()
70
+ mask [:, :, :, 0 ].fill_ (True )
71
+ alpha [A ][C ][L ].data [:] = semiring .fill (
72
+ alpha [A ][C ][L ].data [:], mask , semiring .one
73
+ )
74
+ alpha [A ][C ][R ].data [:] = semiring .fill (
75
+ alpha [A ][C ][R ].data [:], mask , semiring .one
76
+ )
77
+ mask = torch .zeros (alpha [B ][C ][L ].data [:].shape ).bool ()
78
+ mask [:, :, :, - 1 ].fill_ (True )
79
+ alpha [B ][C ][L ].data [:] = semiring .fill (
80
+ alpha [B ][C ][L ].data [:], mask , semiring .one
81
+ )
82
+ alpha [B ][C ][R ].data [:] = semiring .fill (
83
+ alpha [B ][C ][R ].data [:], mask , semiring .one
84
+ )
73
85
74
86
if multiroot :
75
87
start_idx = 0
@@ -119,10 +131,13 @@ def _check_potentials(self, arc_scores, lengths=None):
119
131
lengths = torch .LongTensor ([N - 1 ] * batch ).to (arc_scores .device )
120
132
assert max (lengths ) <= N , "Length longer than N"
121
133
arc_scores = semiring .convert (arc_scores )
122
- for b in range (batch ):
123
- semiring .zero_ (arc_scores [:, b , lengths [b ] + 1 :, :])
124
- semiring .zero_ (arc_scores [:, b , :, lengths [b ] + 1 :])
125
134
135
+ # Set the extra elements of the log-potentials to zero.
136
+ keep = torch .ones_like (arc_scores ).bool ()
137
+ for b in range (batch ):
138
+ keep [:, b , lengths [b ] + 1 :, :].fill_ (0.0 )
139
+ keep [:, b , :, lengths [b ] + 1 :].fill_ (0.0 )
140
+ arc_scores = semiring .fill (arc_scores , ~ keep , semiring .zero )
126
141
return arc_scores , batch , N , lengths
127
142
128
143
def _arrange_marginals (self , grads ):
0 commit comments