from torch.distributions.distribution import Distribution


-class AutoregressiveModel:
+class AutoregressiveModel(torch.nn.Module):
    """
    User should implement as their favorite RNN / Transformer / etc.
    """

-    def sequence_logits(self, init, seq_inputs):
-        """
-        Compute the logits for all tokens in a batched sequence :math:`p(y_1, \ldots, y_{T})`
+    def forward(self, inputs, state=None):
+        r"""
+        Compute the logits for all tokens in a batched sequence :math:`p(y_{t+1}, \ldots, y_{T} | y_1, \ldots, y_t)`

        Parameters:
-            init (batch_size x hidden_shape): everything needed for conditioning.
            inputs (batch_size x N x C): next tokens to update representation
+            state (tuple of batch_size x ...): everything needed for conditioning.

        Returns:
-            logits (batch_size x C): next set of logits.
-        """
-        pass
-
-    def local_logits(self, state):
-        """
-        Compute the local logits of :math:`p(y_t | y_{1:t-1})`
-
-        Parameters:
-            state (batch_size x hidden_shape): everything needed for conditioning.
+            logits (*batch_size x C*): next set of logits.

-        Returns:
-            logits (batch_size x C): next set of logits.
+            state (*tuple of batch_size x ...*): next set of states.
        """
        pass

-    def update_state(self, prev_state, inputs):
-        """
-        Update the model state based on previous state and inputs
-
-        Parameters:
-            prev_state (batch_size x hidden_shape): everything needed for conditioning.
-            inputs (batch_size x C): next tokens to update representation
+def wrap(state, ssize):
+    return state.contiguous().view(ssize, -1, *state.shape[1:])

-        Returns:
-            state (batch_size x hidden_shape): everything needed for next conditioning.
-        """
-        pass
+
+def unwrap(state):
+    return state.contiguous().view(-1, *state.shape[2:])
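For context (not part of the commit), here is a minimal sketch of a user-supplied model satisfying the new forward(inputs, state) -> (logits, state) contract described above. The LSTM backbone, layer sizes, and class name are illustrative assumptions; the only firm requirements, per the docstring and the calling code in this diff, are integer token inputs of shape batch x N, logits of shape batch x N x C, and a state tuple whose tensors are batch-first.

# Illustrative sketch only; not part of this commit.
import torch

class ToyAutoregressiveModel(torch.nn.Module):
    def __init__(self, n_classes, hidden=64):
        super().__init__()
        self.embed = torch.nn.Embedding(n_classes, hidden)
        self.rnn = torch.nn.LSTM(hidden, hidden, batch_first=True)
        self.proj = torch.nn.Linear(hidden, n_classes)

    def forward(self, inputs, state=None):
        # inputs: (batch x N) integer token indices, as the driver code in this diff passes.
        if state is not None:
            # Stored batch-first per the contract; the LSTM wants (layers, batch, hidden).
            state = tuple(s.transpose(0, 1).contiguous() for s in state)
        out, (h, c) = self.rnn(self.embed(inputs), state)
        # Return logits (batch x N x C) and a batch-first state tuple.
        return self.proj(out), (h.transpose(0, 1), c.transpose(0, 1))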


class Autoregressive(Distribution):
@@ -56,60 +41,127 @@ class Autoregressive(Distribution):

    Parameters:
        model (AutoregressiveModel): A lazily computed autoregressive model.
-        init (tensor, batch_shape x hidden_shape): initial state of autoregressive model.
+        init (tuple of tensors, batch_shape x ...): initial state of autoregressive model.
        n_classes (int): number of classes in each time step
        n_length (int): max length of sequence
-
    """

-    def __init__(self, model, init, n_classes, n_length):
+    def __init__(
+        self,
+        model,
+        initial_state,
+        n_classes,
+        n_length,
+        normalize=True,
+        start_class=0,
+        end_class=None,
+    ):
        self.model = model
-        self.init = init
+        self.init = initial_state
        self.n_length = n_length
        self.n_classes = n_classes
+        self.start_class = start_class
+        self.normalize = normalize
        event_shape = (n_length, n_classes)
-        batch_shape = init.shape[:1]
+        batch_shape = initial_state[0].shape[:1]
        super().__init__(batch_shape=batch_shape, event_shape=event_shape)

-    def log_prob(self, value, normalize=True):
+    def log_prob(self, value, sparse=False):
        """
        Compute log probability over values :math:`p(z)`.

        Parameters:
-            value (tensor): One-hot events (*sample_shape x batch_shape x event_shape*)
+            value (tensor): One-hot events (*sample_shape x batch_shape x N*)

        Returns:
            log_probs (*sample_shape x batch_shape*)
        """
-        batch_shape, n_length, n_classes = value.shape
        value = value.long()
-        logits = self.model.sequence_logits(self.init, value)
-        if normalize:
+        if not sparse:
+            sample, batch_shape, n_length, n_classes = value.shape
+            value = (
+                (value * torch.arange(n_classes).view(1, 1, n_classes)).sum(-1).long()
+            )
+        else:
+            sample, batch_shape, n_length = value.shape
+
+        value = torch.cat(
+            [torch.zeros(sample, batch_shape, 1).fill_(self.start_class).long(), value],
+            dim=2,
+        )
+        value = unwrap(value)
+        state = tuple(
+            (unwrap(i.unsqueeze(0).expand((sample,) + i.shape)) for i in self.init)
+        )
+
+        logits, _ = self.model(value, state)
+        b2, n2, c2 = logits.shape
+        assert (
+            (b2 == sample * batch_shape)
+            and (n2 == n_length + 1)
+            and (c2 == self.n_classes)
+        ), "Model should return logits of shape `batch x N x C`"
+
+        if self.normalize:
            log_probs = logits.log_softmax(-1)
        else:
            log_probs = logits

-        # batch_shape x event_shape (N x C)
-        return log_probs.masked_fill_(value == 0, 0).sum(-1).sum(-1)
+        scores = log_probs[:, :-1].gather(2, value[:, 1:].unsqueeze(-1)).sum(-1).sum(-1)
+        return wrap(scores, sample)
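A tiny worked illustration (not part of the commit; shapes and values are made up) of the shift-and-gather scoring that log_prob now uses: the start token is prepended, position t's logits score the gold token at position t+1, and the per-step log-probabilities are summed over time.

# Illustrative toy example only; not part of this commit.
import torch

log_probs = torch.randn(2, 5, 4).log_softmax(-1)   # batch x (N+1) x C model outputs
gold = torch.randint(0, 4, (2, 4))                  # batch x N gold token indices
start = torch.zeros(2, 1, dtype=torch.long)         # assume start_class = 0
shifted = torch.cat([start, gold], dim=1)           # batch x (N+1), start token prepended
# Drop the last position's logits and gather the gold token's log-probability at
# every step, then sum over time to get one score per batch element.
score = log_probs[:, :-1].gather(2, shifted[:, 1:].unsqueeze(-1)).sum(-1).sum(-1)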

-    def _beam_search(self, semiring, gumbel=True):
+    def _beam_search(self, semiring, gumbel=False):
        beam = semiring.one_(torch.zeros((semiring.size(),) + self.batch_shape))
-        state = self.init.unsqueeze(0).expand((semiring.size(),) + self.init.shape)
+        ssize = semiring.size()
+
+        def take(state, indices):
+            return tuple(
+                (
+                    s.contiguous()[
+                        (
+                            indices * self.batch_shape[0]
+                            + torch.arange(self.batch_shape[0]).unsqueeze(0)
+                        )
+                        .contiguous()
+                        .view(-1)
+                    ]
+                    for s in state
+                )
+            )
+
+        tokens = (
+            torch.zeros((ssize * self.batch_shape[0])).long().fill_(self.start_class)
+        )
+        state = tuple(
+            (unwrap(i.unsqueeze(0).expand((ssize,) + i.shape)) for i in self.init)
+        )

        # Beam Search
        all_beams = []
        for t in range(0, self.n_length):
-            logits = self.model.local_logits(state)
+            logits, state = self.model(unwrap(tokens).unsqueeze(1), state)
+            b2, n2, c2 = logits.shape
+            assert (
+                (b2 == ssize * self.batch_shape[0])
+                and (n2 == 1)
+                and (c2 == self.n_classes)
+            ), "Model should return logits of shape `batch x N x C`"
+            for s in state:
+                assert (
+                    s.shape[0] == ssize * self.batch_shape[0]
+                ), "Model should return state tuple with shapes `batch x ...`"
+            logits = wrap(logits.squeeze(1), ssize)
            if gumbel:
-                logits = logits + torch.distributions.Gumbel(0.0, 0.0).sample(
+                logits = logits + torch.distributions.Gumbel(0.0, 1.0).sample(
                    logits.shape
                )
-
+            if self.normalize:
+                logits = logits.log_softmax(-1)
            ex_beam = beam.unsqueeze(-1) + logits
            ex_beam.requires_grad_(True)
            all_beams.append(ex_beam)
-            beam, tokens = semiring.sparse_sum(ex_beam)
-            state = self.model.update_state(state, tokens)
+            beam, (positions, tokens) = semiring.sparse_sum(ex_beam)
+            state = take(state, positions)

        # Back pointers
        v = beam
@@ -121,43 +173,79 @@ def _beam_search(self, semiring, gumbel=True):
            )
            marg = torch.stack(marg, dim=2)
            all_m.append(marg.sum(0))
-        return torch.stack(all_m, dim=0)
+        return torch.stack(all_m, dim=0), v

    def greedy_argmax(self):
        """
-        Compute "argmax" using greedy search
+        Compute "argmax" using greedy search.
+
+        Returns:
+            greedy_path (*batch x N x C*)
        """
-        return self._beam_search(MaxSemiring).squeeze(0)
+        return self._beam_search(MaxSemiring)[0].squeeze(0)
+
+    def _greedy_max(self):
+        return self._beam_search(MaxSemiring)[1].squeeze(0)

    def beam_topk(self, K):
        """
        Compute "top-k" using beam search
+
+        Returns:
+            paths (*K x batch x N x C*)
+
        """
-        return self._beam_search(KMaxSemiring(K))
+        return self._beam_search(KMaxSemiring(K))[0]
+
+    def _beam_max(self, K):
+        return self._beam_search(KMaxSemiring(K))[1]

    def sample_without_replacement(self, sample_shape=torch.Size()):
        """
        Compute sampling without replacement using Gumbel trick.
+
+        Based on:
+
+        * Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for
+          Sampling Sequences Without Replacement :cite:`DBLP:journals/corr/abs-1903-06059`
+
+        Parameters:
+            sample_shape (torch.Size): batch_size
+
+        Returns:
+            paths (*K x batch x N x C*)
+
        """
        K = sample_shape[0]
-        return self._beam_search(KMaxSemiring(K), gumbel=True)
+        return self._beam_search(KMaxSemiring(K), gumbel=True)[0]

    def sample(self, sample_shape=torch.Size()):
        r"""
        Compute structured samples from the distribution :math:`z \sim p(z)`.

        Parameters:
-            sample_shape (int): number of samples
+            sample_shape (torch.Size): number of samples

        Returns:
            samples (*sample_shape x batch_shape x event_shape*)
        """
        sample_shape = sample_shape[0]
-        state = self.init.unsqueeze(0).expand((sample_shape,) + self.init.shape)
+        state = tuple(
+            (
+                unwrap(i.unsqueeze(0).expand((sample_shape,) + i.shape))
+                for i in self.init
+            )
+        )
        all_tokens = []
+        tokens = (
+            torch.zeros((sample_shape * self.batch_shape[0]))
+            .long()
+            .fill_(self.start_class)
+        )
        for t in range(0, self.n_length):
-            logits = self.model.local_logits(state)
-            tokens = torch.distributions.OneHotCategorical(logits).sample((1,))[0]
-            state = self.model.update_state(state, tokens)
+            logits, state = self.model(tokens.unsqueeze(-1), state)
+            logits = logits.squeeze(1)
+            tokens = torch.distributions.Categorical(logits=logits).sample((1,))[0]
            all_tokens.append(tokens)
-        return torch.stack(all_tokens, dim=2)
+        v = wrap(torch.stack(all_tokens, dim=1), sample_shape)
+        return torch.nn.functional.one_hot(v, self.n_classes)
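Taken together, a sketch of how the revised interface might be driven end to end. This is illustrative only and not part of the commit: the toy model is the hypothetical one sketched earlier (any module with the same forward contract would do), the shapes are assumptions, and the output shapes are taken from the docstrings above.

# Illustrative usage sketch only; not part of this commit.
# Assumes Autoregressive (this file) and the hypothetical ToyAutoregressiveModel
# from the earlier sketch are in scope.
import torch

batch, n_classes, n_length, hidden = 3, 5, 7, 64
model = ToyAutoregressiveModel(n_classes, hidden)
# Initial state is now a tuple of batch-first tensors (here: h0 and c0 for one LSTM layer).
init = (torch.zeros(batch, 1, hidden), torch.zeros(batch, 1, hidden))
dist = Autoregressive(model, init, n_classes, n_length)

argmax_path = dist.greedy_argmax()                        # batch x N x C
samples = dist.sample(torch.Size([4]))                    # 4 x batch x N x C one-hot samples
topk = dist.beam_topk(K=5)                                # 5 x batch x N x C
swor = dist.sample_without_replacement(torch.Size([5]))   # 5 x batch x N x C
ll = dist.log_prob(argmax_path.unsqueeze(0))              # 1 x batch scores of the greedy paths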