
Commit e51fecc

Remove imperative filling functions _ (#105)
1 parent 84ee7cd commit e51fecc

10 files changed: +98 -141 lines changed

tests/extensions.py

+18 -2
@@ -26,7 +26,17 @@ def enumerate(semiring, edge, lengths=None):
         semiring = semiring
         ssize = semiring.size()
         edge, batch, N, C, lengths = model._check_potentials(edge, lengths)
-        chains = [[([c], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)]]
+        chains = [
+            [
+                (
+                    [c],
+                    semiring.fill(
+                        torch.zeros(ssize, batch), torch.tensor(True), semiring.one
+                    ),
+                )
+                for c in range(C)
+            ]
+        ]
 
         enum_lengths = torch.LongTensor(lengths.shape)
         for n in range(1, N):
@@ -128,7 +138,13 @@ def enumerate(semiring, edge):
         edge = semiring.convert(edge)
         chains = {}
         chains[0] = [
-            ([(c, 0)], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)
+            (
+                [(c, 0)],
+                semiring.fill(
+                    torch.zeros(ssize, batch), torch.tensor(True), semiring.one
+                ),
+            )
+            for c in range(C)
         ]
 
         for n in range(1, N + 1):
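
Note: the functional `semiring.fill(tensor, mask, value)` used throughout this commit replaces the old in-place `one_`/`zero_` helpers. Its implementation is not part of this diff, so the following is only a rough sketch of the assumed behavior (a masked, out-of-place write), not the library's actual code:

import torch

def fill(tensor, mask, value):
    # Hypothetical stand-in for semiring.fill: return a copy of `tensor`
    # whose entries where `mask` is True are set to `value` (e.g. the
    # semiring's `one` or `zero`), leaving `tensor` itself unmodified.
    return torch.where(mask, value.type_as(tensor), tensor)

Passing `torch.tensor(True)` as the mask, as in the chain initialization above, fills the entire tensor with `semiring.one`.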

tests/test_algorithms.py

+1 -1
@@ -263,7 +263,7 @@ def test_generic_lengths(model_test, data):
     part = model().sum(vals, lengths=lengths)
 
     # Check that max is correct
-    assert (maxes <= part).all()
+    assert (maxes <= part + 1e-3).all()
     m_part = model(MaxSemiring).sum(vals, lengths=lengths)
     assert (torch.isclose(maxes, m_part)).all(), maxes - m_part
 

torch_struct/autoregressive.py

+4 -2
@@ -118,8 +118,10 @@ def log_prob(self, value, sparse=False):
         return wrap(scores, sample)
 
     def _beam_search(self, semiring, gumbel=False):
-        beam = semiring.one_(
-            torch.zeros((semiring.size(),) + self.batch_shape, device=self.device)
+        beam = semiring.fill(
+            torch.zeros((semiring.size(),) + self.batch_shape, device=self.device),
+            torch.tensor(True),
+            semiring.one,
         )
         ssize = semiring.size()
 

torch_struct/deptree.py

+22 -7
@@ -66,10 +66,22 @@ def logpartition(self, arc_scores_in, lengths=None, force_grad=False):
             ]
             for _ in range(2)
         ]
-        semiring.one_(alpha[A][C][L].data[:, :, :, 0].data)
-        semiring.one_(alpha[A][C][R].data[:, :, :, 0].data)
-        semiring.one_(alpha[B][C][L].data[:, :, :, -1].data)
-        semiring.one_(alpha[B][C][R].data[:, :, :, -1].data)
+        mask = torch.zeros(alpha[A][C][L].data.shape).bool()
+        mask[:, :, :, 0].fill_(True)
+        alpha[A][C][L].data[:] = semiring.fill(
+            alpha[A][C][L].data[:], mask, semiring.one
+        )
+        alpha[A][C][R].data[:] = semiring.fill(
+            alpha[A][C][R].data[:], mask, semiring.one
+        )
+        mask = torch.zeros(alpha[B][C][L].data[:].shape).bool()
+        mask[:, :, :, -1].fill_(True)
+        alpha[B][C][L].data[:] = semiring.fill(
+            alpha[B][C][L].data[:], mask, semiring.one
+        )
+        alpha[B][C][R].data[:] = semiring.fill(
+            alpha[B][C][R].data[:], mask, semiring.one
+        )
 
         if multiroot:
             start_idx = 0
@@ -119,10 +131,13 @@ def _check_potentials(self, arc_scores, lengths=None):
             lengths = torch.LongTensor([N - 1] * batch).to(arc_scores.device)
         assert max(lengths) <= N, "Length longer than N"
         arc_scores = semiring.convert(arc_scores)
-        for b in range(batch):
-            semiring.zero_(arc_scores[:, b, lengths[b] + 1 :, :])
-            semiring.zero_(arc_scores[:, b, :, lengths[b] + 1 :])
 
+        # Set the extra elements of the log-potentials to zero.
+        keep = torch.ones_like(arc_scores).bool()
+        for b in range(batch):
+            keep[:, b, lengths[b] + 1 :, :].fill_(0.0)
+            keep[:, b, :, lengths[b] + 1 :].fill_(0.0)
+        arc_scores = semiring.fill(arc_scores, ~keep, semiring.zero)
         return arc_scores, batch, N, lengths
 
     def _arrange_marginals(self, grads):
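
Note: the `_check_potentials` change shows the general masking pattern adopted in this commit: build a boolean "keep" mask over the padded scores, then apply the semiring's zero to everything outside each sentence's length in one functional call. A standalone toy version, assuming a log-space semiring whose zero is a large negative value and using `torch.where` in place of `semiring.fill`:

import torch

ssize, batch, N = 1, 2, 5
arc_scores = torch.randn(ssize, batch, N, N)
lengths = torch.tensor([4, 2])
log_zero = torch.tensor(-1e9)  # stand-in for semiring.zero in log space

keep = torch.ones_like(arc_scores).bool()
for b in range(batch):
    keep[:, b, lengths[b] + 1 :, :].fill_(False)
    keep[:, b, :, lengths[b] + 1 :].fill_(False)
# ~keep marks the padding; those arcs receive the semiring's zero
arc_scores = torch.where(~keep, log_zero, arc_scores)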

torch_struct/distributions.py

+1
@@ -36,6 +36,7 @@ class StructDistribution(Distribution):
         log_potentials (tensor, batch_shape x event_shape) : log-potentials :math:`\phi`
         lengths (long tensor, batch_shape) : integers for length masking
     """
+    validate_args = False
 
     def __init__(self, log_potentials, lengths=None, args={}):
         batch_shape = log_potentials.shape[:1]

torch_struct/helpers.py

+17 -17
@@ -5,13 +5,14 @@
 
 class Chart:
     def __init__(self, size, potentials, semiring):
-        self.data = semiring.zero_(
-            torch.zeros(
-                *((semiring.size(),) + size),
-                dtype=potentials.dtype,
-                device=potentials.device
-            )
+        c = torch.zeros(
+            *((semiring.size(),) + size),
+            dtype=potentials.dtype,
+            device=potentials.device
         )
+        c[:] = semiring.zero.view((semiring.size(),) + len(size) * (1,))
+
+        self.data = c
         self.grad = self.data.detach().clone().fill_(0.0)
 
     def __getitem__(self, ind):
@@ -50,18 +51,17 @@ def _chart(self, size, potentials, force_grad):
         return self._make_chart(1, size, potentials, force_grad)[0]
 
     def _make_chart(self, N, size, potentials, force_grad=False):
-        return [
-            (
-                self.semiring.zero_(
-                    torch.zeros(
-                        *((self.semiring.size(),) + size),
-                        dtype=potentials.dtype,
-                        device=potentials.device
-                    )
-                ).requires_grad_(force_grad and not potentials.requires_grad)
+        chart = []
+        for _ in range(N):
+            c = torch.zeros(
+                *((self.semiring.size(),) + size),
+                dtype=potentials.dtype,
+                device=potentials.device
             )
-            for _ in range(N)
-        ]
+            c[:] = self.semiring.zero.view((self.semiring.size(),) + len(size) * (1,))
+            c.requires_grad_(force_grad and not potentials.requires_grad)
+            chart.append(c)
+        return chart
 
     def sum(self, logpotentials, lengths=None, _raw=False):
         """

torch_struct/linearchain.py

+5 -3
@@ -53,7 +53,9 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         chart = self._chart((batch, bin_N, C, C), log_potentials, force_grad)
 
         # Init
-        semiring.one_(chart[:, :, :].diagonal(0, 3, 4))
+        init = torch.zeros(*chart.shape).bool()
+        init.diagonal(0, 3, 4).fill_(True)
+        chart = semiring.fill(chart, init, semiring.one)
 
         # Length mask
         big = torch.zeros(
@@ -71,8 +73,8 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N).type_as(c)
         mask = mask >= (lengths - 1).view(batch, 1)
         mask = mask.view(batch * bin_N, 1, 1).to(lp.device)
-        semiring.zero_mask_(lp.data, mask)
-        semiring.zero_mask_(c.data, (~mask))
+        lp.data[:] = semiring.fill(lp.data, mask, semiring.zero)
+        c.data[:] = semiring.fill(c.data, ~mask, semiring.zero)
 
         c[:] = semiring.sum(torch.stack([c.data, lp], dim=-1))
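
Note: the init change relies on `Tensor.diagonal` returning a view: calling `fill_(True)` on that view marks the diagonal of the boolean mask in place, and a single `semiring.fill` then writes `semiring.one` there. A toy version of the same idiom (using `torch.where` in place of `semiring.fill`, with 0.0 standing in for the log-semiring one):

import torch

chart = torch.full((2, 3, 4, 4), -1e9)   # stand-in log-semiring chart
init = torch.zeros(*chart.shape).bool()
init.diagonal(0, 2, 3).fill_(True)       # diagonal() is a view, so this edits the mask
chart = torch.where(init, torch.tensor(0.0), chart)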

torch_struct/semimarkov.py

+8 -6
@@ -34,7 +34,9 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         )
 
         # Init.
-        semiring.one_(init.data[:, :, :, 0, 0].diagonal(0, -2, -1))
+        mask = torch.zeros(*init.shape).bool()
+        mask[:, :, :, 0, 0].diagonal(0, -2, -1).fill_(True)
+        init = semiring.fill(init, mask, semiring.one)
 
         # Length mask
         big = torch.zeros(
@@ -54,16 +56,16 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         mask = mask.to(log_potentials.device)
         mask = mask >= (lengths - 1).view(batch, 1)
         mask = mask.view(batch * bin_N, 1, 1, 1).to(lp.device)
-        semiring.zero_mask_(lp.data, mask)
-        semiring.zero_mask_(c.data[:, :, :, 0], (~mask))
+        lp.data[:] = semiring.fill(lp.data, mask, semiring.zero)
+        c.data[:, :, :, 0] = semiring.fill(c.data[:, :, :, 0], (~mask), semiring.zero)
         c[:, :, : K - 1, 0] = semiring.sum(
             torch.stack([c.data[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1)
         )
         end = torch.min(lengths) - 1
+        mask = torch.zeros(*init.shape).bool()
         for k in range(1, K - 1):
-            semiring.one_(
-                init.data[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1)
-            )
+            mask[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True)
+        init = semiring.fill(init, mask, semiring.one)
 
         K_1 = K - 1
torch_struct/semirings/checkpoint.py

+1
@@ -4,6 +4,7 @@
 try:
     import genbmm
     from genbmm import BandedMatrix
+
     has_genbmm = True
 except ImportError:
     pass
