Commit d75f6ec

Major docs update (#32)
1 parent 034e8a2 commit d75f6ec

10 files changed: +1942 -226 lines

docs/source/index.rst

+1 -1

@@ -9,7 +9,7 @@ PyTorch-Struct
    README
    model
    networks
-   advanced
+   semiring
    refs

docs/source/model.ipynb

+745 -62
Large diffs are not rendered by default.

docs/source/semiring.ipynb

+733
Large diffs are not rendered by default.

setup.py

+1 -1

@@ -2,7 +2,7 @@

 setup(
     name="torch_struct",
-    version="0.2",
+    version="0.3",
     author="Alexander Rush",
     author_email="arush@cornell.edu",
     packages=["torch_struct", "torch_struct.data", "torch_struct.networks"],

torch_struct/__init__.py

+4 -1

@@ -8,8 +8,9 @@
     TreeCRF,
     SentCFG,
     AlignmentCRF,
+    HMM,
 )
-from .autoregressive import Autoregressive
+from .autoregressive import Autoregressive, AutoregressiveModel
 from .cky_crf import CKY_CRF
 from .deptree import DepTree
 from .linearchain import LinearChain
@@ -48,12 +49,14 @@
     SelfCritical,
     StructDistribution,
     Autoregressive,
+    AutoregressiveModel,
     LinearChainCRF,
     SemiMarkovCRF,
     DependencyCRF,
     NonProjectiveDependencyCRF,
     TreeCRF,
     SentCFG,
+    HMM,
     AlignmentCRF,
     Alignment,
 ]
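
With this change the new symbols are part of the public API and can be imported from the package root; a minimal sketch, assuming the 0.3 release built in this commit:

    from torch_struct import HMM, Autoregressive, AutoregressiveModel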

torch_struct/autoregressive.py

+145 -57

@@ -3,48 +3,33 @@
 from torch.distributions.distribution import Distribution


-class AutoregressiveModel:
+class AutoregressiveModel(torch.nn.Module):
     """
     User should implement as their favorite RNN / Transformer / etc.
     """

-    def sequence_logits(self, init, seq_inputs):
-        """
-        Compute the logits for all tokens in a batched sequence :math:`p(y_1, ... y_{T})`
+    def forward(self, inputs, state=None):
+        r"""
+        Compute the logits for all tokens in a batched sequence :math:`p(y_{t+1}, \ldots, y_{T} | y_1, \ldots, y_t)`

         Parameters:
-            init (batch_size x hidden_shape): everything needed for conditioning.
             inputs (batch_size x N x C): next tokens to update representation
+            state (tuple of batch_size x ...): everything needed for conditioning.

         Returns:
-            logits (batch_size x C): next set of logits.
-        """
-        pass
-
-    def local_logits(self, state):
-        """
-        Compute the local logits of :math:`p(y_t | y_{1:t-1})`
-
-        Parameters:
-            state (batch_size x hidden_shape): everything needed for conditioning.
+            logits (*batch_size x C*): next set of logits.

-        Returns:
-            logits (batch_size x C): next set of logits.
+            state (*tuple of batch_size x ...*): everything needed for next conditioning.
         """
         pass

-    def update_state(self, prev_state, inputs):
-        """
-        Update the model state based on previous state and inputs

-        Parameters:
-            prev_state (batch_size x hidden_shape): everything needed for conditioning.
-            inputs (batch_size x C): next tokens to update representation
+def wrap(state, ssize):
+    return state.contiguous().view(ssize, -1, *state.shape[1:])

-        Returns:
-            state (batch_size x hidden_shape): everything needed for next conditioning.
-        """
-        pass
+
+def unwrap(state):
+    return state.contiguous().view(-1, *state.shape[2:])


 class Autoregressive(Distribution):
@@ -56,60 +41,127 @@ class Autoregressive(Distribution):

     Parameters:
         model (AutoregressiveModel): A lazily computed autoregressive model.
-        init (tensor, batch_shape x hidden_shape): initial state of autoregressive model.
+        init (tuple of tensors, batch_shape x ...): initial state of autoregressive model.
         n_classes (int): number of classes in each time step
         n_length (int): max length of sequence
-
     """

-    def __init__(self, model, init, n_classes, n_length):
+    def __init__(
+        self,
+        model,
+        initial_state,
+        n_classes,
+        n_length,
+        normalize=True,
+        start_class=0,
+        end_class=None,
+    ):
         self.model = model
-        self.init = init
+        self.init = initial_state
         self.n_length = n_length
         self.n_classes = n_classes
+        self.start_class = start_class
+        self.normalize = normalize
         event_shape = (n_length, n_classes)
-        batch_shape = init.shape[:1]
+        batch_shape = initial_state[0].shape[:1]
         super().__init__(batch_shape=batch_shape, event_shape=event_shape)

-    def log_prob(self, value, normalize=True):
+    def log_prob(self, value, sparse=False):
         """
         Compute log probability over values :math:`p(z)`.

         Parameters:
-            value (tensor): One-hot events (*sample_shape x batch_shape x event_shape*)
+            value (tensor): One-hot events (*sample_shape x batch_shape x N*)

         Returns:
             log_probs (*sample_shape x batch_shape*)
         """
-        batch_shape, n_length, n_classes = value.shape
         value = value.long()
-        logits = self.model.sequence_logits(self.init, value)
-        if normalize:
+        if not sparse:
+            sample, batch_shape, n_length, n_classes = value.shape
+            value = (
+                (value * torch.arange(n_classes).view(1, 1, n_classes)).sum(-1).long()
+            )
+        else:
+            sample, batch_shape, n_length = value.shape
+
+        value = torch.cat(
+            [torch.zeros(sample, batch_shape, 1).fill_(self.start_class).long(), value],
+            dim=2,
+        )
+        value = unwrap(value)
+        state = tuple(
+            (unwrap(i.unsqueeze(0).expand((sample,) + i.shape)) for i in self.init)
+        )
+
+        logits, _ = self.model(value, state)
+        b2, n2, c2 = logits.shape
+        assert (
+            (b2 == sample * batch_shape)
+            and (n2 == n_length + 1)
+            and (c2 == self.n_classes)
+        ), "Model should return logits of shape `batch x N x C` "
+
+        if self.normalize:
             log_probs = logits.log_softmax(-1)
         else:
             log_probs = logits

-        # batch_shape x event_shape (N x C)
-        return log_probs.masked_fill_(value == 0, 0).sum(-1).sum(-1)
+        scores = log_probs[:, :-1].gather(2, value[:, 1:].unsqueeze(-1)).sum(-1).sum(-1)
+        return wrap(scores, sample)

-    def _beam_search(self, semiring, gumbel=True):
+    def _beam_search(self, semiring, gumbel=False):
         beam = semiring.one_(torch.zeros((semiring.size(),) + self.batch_shape))
-        state = self.init.unsqueeze(0).expand((semiring.size(),) + self.init.shape)
+        ssize = semiring.size()
+
+        def take(state, indices):
+            return tuple(
+                (
+                    s.contiguous()[
+                        (
+                            indices * self.batch_shape[0]
+                            + torch.arange(self.batch_shape[0]).unsqueeze(0)
+                        )
+                        .contiguous()
+                        .view(-1)
+                    ]
+                    for s in state
+                )
+            )
+
+        tokens = (
+            torch.zeros((ssize * self.batch_shape[0])).long().fill_(self.start_class)
+        )
+        state = tuple(
+            (unwrap(i.unsqueeze(0).expand((ssize,) + i.shape)) for i in self.init)
+        )

         # Beam Search
         all_beams = []
         for t in range(0, self.n_length):
-            logits = self.model.local_logits(state)
+            logits, state = self.model(unwrap(tokens).unsqueeze(1), state)
+            b2, n2, c2 = logits.shape
+            assert (
+                (b2 == ssize * self.batch_shape[0])
+                and (n2 == 1)
+                and (c2 == self.n_classes)
+            ), "Model should return logits of shape `batch x N x C` "
+            for s in state:
+                assert (
+                    s.shape[0] == ssize * self.batch_shape[0]
+                ), "Model should return state tuple with shapes `batch x ...` "
+            logits = wrap(logits.squeeze(1), ssize)
             if gumbel:
-                logits = logits + torch.distributions.Gumbel(0.0, 0.0).sample(
+                logits = logits + torch.distributions.Gumbel(0.0, 1.0).sample(
                     logits.shape
                 )
-
+            if self.normalize:
+                logits = logits.log_softmax(-1)
             ex_beam = beam.unsqueeze(-1) + logits
             ex_beam.requires_grad_(True)
             all_beams.append(ex_beam)
-            beam, tokens = semiring.sparse_sum(ex_beam)
-            state = self.model.update_state(state, tokens)
+            beam, (positions, tokens) = semiring.sparse_sum(ex_beam)
+            state = take(state, positions)

         # Back pointers
         v = beam
@@ -121,43 +173,79 @@ def _beam_search(self, semiring, gumbel=True):
             )
         marg = torch.stack(marg, dim=2)
         all_m.append(marg.sum(0))
-        return torch.stack(all_m, dim=0)
+        return torch.stack(all_m, dim=0), v

     def greedy_argmax(self):
         """
-        Compute "argmax" using greedy search
+        Compute "argmax" using greedy search.
+
+        Returns:
+            greedy_path (*batch x N x C*)
         """
-        return self._beam_search(MaxSemiring).squeeze(0)
+        return self._beam_search(MaxSemiring)[0].squeeze(0)
+
+    def _greedy_max(self):
+        return self._beam_search(MaxSemiring)[1].squeeze(0)

     def beam_topk(self, K):
         """
         Compute "top-k" using beam search
+
+        Returns:
+            paths (*K x batch x N x C*)
+
         """
-        return self._beam_search(KMaxSemiring(K))
+        return self._beam_search(KMaxSemiring(K))[0]
+
+    def _beam_max(self, K):
+        return self._beam_search(KMaxSemiring(K))[1]

     def sample_without_replacement(self, sample_shape=torch.Size()):
         """
         Compute sampling without replacement using Gumbel trick.
+
+        Based on:
+
+        * Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for
+          Sampling Sequences Without Replacement :cite:`DBLP:journals/corr/abs-1903-06059`
+
+        Parameters:
+            sample_shape (torch.Size): batch_size
+
+        Returns:
+            paths (*K x batch x N x C*)
+
         """
         K = sample_shape[0]
-        return self._beam_search(KMaxSemiring(K), gumbel=True)
+        return self._beam_search(KMaxSemiring(K), gumbel=True)[0]

     def sample(self, sample_shape=torch.Size()):
         r"""
         Compute structured samples from the distribution :math:`z \sim p(z)`.

         Parameters:
-            sample_shape (int): number of samples
+            sample_shape (torch.Size): number of samples

         Returns:
             samples (*sample_shape x batch_shape x event_shape*)
         """
         sample_shape = sample_shape[0]
-        state = self.init.unsqueeze(0).expand((sample_shape,) + self.init.shape)
+        state = tuple(
+            (
+                unwrap(i.unsqueeze(0).expand((sample_shape,) + i.shape))
+                for i in self.init
+            )
+        )
         all_tokens = []
+        tokens = (
+            torch.zeros((sample_shape * self.batch_shape[0]))
+            .long()
+            .fill_(self.start_class)
+        )
         for t in range(0, self.n_length):
-            logits = self.model.local_logits(state)
-            tokens = torch.distributions.OneHotCategorical(logits).sample((1,))[0]
-            state = self.model.update_state(state, tokens)
+            logits, state = self.model(tokens.unsqueeze(-1), state)
+            logits = logits.squeeze(1)
+            tokens = torch.distributions.Categorical(logits=logits).sample((1,))[0]
             all_tokens.append(tokens)
-        return torch.stack(all_tokens, dim=2)
+        v = wrap(torch.stack(all_tokens, dim=1), sample_shape)
+        return torch.nn.functional.one_hot(v, self.n_classes)
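
For reference, here is a minimal sketch of how the reworked interface fits together: a single-layer LSTM wrapper implementing the AutoregressiveModel contract (token indices and a state tuple in, logits and a state tuple out) and the Autoregressive distribution built on top of it. The CharLSTM class, its hidden size, and the toy shapes below are illustrative assumptions, not part of this commit.

    import torch
    from torch_struct import Autoregressive, AutoregressiveModel

    class CharLSTM(AutoregressiveModel):
        # Illustrative model: any RNN / Transformer with this signature should work.
        def __init__(self, n_classes, hidden=64):
            super().__init__()
            self.embed = torch.nn.Embedding(n_classes, hidden)
            self.rnn = torch.nn.LSTM(hidden, hidden, batch_first=True)
            self.proj = torch.nn.Linear(hidden, n_classes)

        def forward(self, inputs, state=None):
            # inputs: (batch x N) token indices; state: tuple of (batch x hidden) tensors.
            h, c = state
            out, (h, c) = self.rnn(self.embed(inputs), (h.unsqueeze(0), c.unsqueeze(0)))
            # Return (batch x N x C) logits and the updated state tuple.
            return self.proj(out), (h.squeeze(0), c.squeeze(0))

    batch, hidden, n_classes, n_length = 3, 64, 10, 8
    model = CharLSTM(n_classes, hidden)
    init = (torch.zeros(batch, hidden), torch.zeros(batch, hidden))
    dist = Autoregressive(model, init, n_classes, n_length)

    path = dist.greedy_argmax()                    # batch x N x C greedy one-hot path
    samples = dist.sample((5,))                    # 5 x batch x N x C one-hot samples
    scores = dist.log_prob(samples)                # 5 x batch log-probabilities
    topk = dist.sample_without_replacement((5,))   # Gumbel top-k paths

Note that the distribution handles expanding the initial state tuple across samples or beams itself, so the user model only ever sees a flat batch dimension.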
