Commit 0339886

Gumbel-CRF Semiring (#81)
1 parent f8f46ee commit 0339886

File tree

tests/test_algorithms.py
torch_struct/distributions.py
torch_struct/helpers.py
torch_struct/semirings/sample.py
torch_struct/semirings/semirings.py

5 files changed: +135 -3

tests/test_algorithms.py

+14

@@ -3,6 +3,7 @@
     LogSemiring,
     CheckpointSemiring,
     CheckpointShardSemiring,
+    GumbelCRFSemiring,
     KMaxSemiring,
     SparseMaxSemiring,
     MaxSemiring,
@@ -511,3 +512,16 @@ def test_lc_custom():
     # s2 = struct.sum(vals)
     # assert torch.isclose(s, s2).all()
     # assert torch.isclose(marginals, marginals2).all()
+
+
+@given(data())
+def test_gumbel(data):
+    model = data.draw(sampled_from([LinearChain, SemiMarkov, DepTree]))
+    semiring = GumbelCRFSemiring(1.0)
+    test = test_lookup[model]()
+    struct = model(semiring)
+    vals, (batch, N) = test._rand()
+    vals.requires_grad_(True)
+    alpha = struct.marginals(vals)
+    print(alpha[0])
+    print(torch.autograd.grad(alpha, vals, alpha.detach())[0][0])
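For orientation, a minimal standalone sketch of what this test exercises, outside the hypothesis harness. The batch/length/class sizes here are made up, and it assumes GumbelCRFSemiring is importable at package level, as the distributions.py import below suggests:

    import torch
    from torch_struct import LinearChain, GumbelCRFSemiring

    batch, N, C = 2, 6, 3
    log_potentials = torch.randn(batch, N - 1, C, C, requires_grad=True)

    # Under the Gumbel-CRF semiring, marginals() returns a relaxed
    # one-hot sample of the edges rather than exact marginals.
    struct = LinearChain(GumbelCRFSemiring(1.0))
    sample = struct.marginals(log_potentials)

    # Gradients reach the potentials through the straight-through
    # estimator; weight the sample so the gradient is non-trivial.
    (sample * torch.randn_like(sample)).sum().backward()
    print(log_potentials.grad.shape)  # matches log_potentials.shape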

torch_struct/distributions.py

+8

@@ -16,6 +16,7 @@
     MultiSampledSemiring,
     KMaxSemiring,
     StdSemiring,
+    GumbelCRFSemiring,
 )


@@ -183,6 +184,13 @@ def count(self):
         ones[self.log_potentials.eq(-float("inf"))] = 0
         return self._struct(StdSemiring).sum(ones, self.lengths)

+    def gumbel_crf(self, temperature=1.0):
+        with torch.enable_grad():
+            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
+                self.log_potentials, self.lengths
+            )
+        return st_gumbel
+
     # @constraints.dependent_property
     # def support(self):
     #     pass
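A sketch of the new method in use; it assumes the library's LinearChainCRF distribution wrapper, and the shapes and temperature are illustrative:

    import torch
    from torch_struct import LinearChainCRF

    log_potentials = torch.randn(2, 6, 3, 3, requires_grad=True)
    dist = LinearChainCRF(log_potentials)

    # Relaxed, differentiable sample of edge indicators; lower
    # temperatures are closer to a hard sample.
    relaxed = dist.gumbel_crf(temperature=0.5)
    (relaxed * torch.randn_like(relaxed)).sum().backward()
    print(log_potentials.grad.shape)  # gradients flow to the potentials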

torch_struct/helpers.py

+8 -1

@@ -104,7 +104,7 @@ def backward(ctx, grad_v):

         return DPManual.apply(edge)

-    def marginals(self, edge, lengths=None, _autograd=True, _raw=False):
+    def marginals(self, edge, lengths=None, _autograd=True, _raw=False, _combine=False):
         """
         Compute the marginals of a structured model.

@@ -135,6 +135,13 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False):
                     )
                     all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
                 return torch.stack(all_m, dim=0)
+            elif _combine:
+                obj = v.sum(dim=0).sum(dim=0)
+                marg = torch.autograd.grad(
+                    obj, edges, create_graph=True, only_inputs=True, allow_unused=False
+                )
+                a_m = self._arrange_marginals(marg)
+                return a_m
             else:
                 obj = self.semiring.unconvert(v).sum(dim=0)
                 marg = torch.autograd.grad(
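The new `_combine` branch mirrors the `else` branch below it: differentiating the summed log-partition with respect to the edge potentials yields marginals, and `create_graph=True` keeps those marginals differentiable so relaxed samples can carry gradients. A toy illustration of the idea (not library code; shapes are made up):

    import torch

    edge = torch.randn(2, 4, 3, requires_grad=True)
    # Toy per-batch "partition": logsumexp over all edge scores.
    v = torch.logsumexp(edge.reshape(2, -1), dim=-1)

    # The gradient of logsumexp is a softmax, i.e. "marginals" over
    # entries; create_graph=True makes the result differentiable in turn.
    marg = torch.autograd.grad(v.sum(), edge, create_graph=True)[0]
    print(marg.reshape(2, -1).sum(-1))  # each row sums to 1.0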

torch_struct/semirings/sample.py

+100

@@ -53,6 +53,106 @@ def sum(xs, dim=-1):
         return _SampledLogSumExp.apply(xs, dim)


+def GumbelMaxSemiring(temp):
+    class _GumbelMaxLogSumExp(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, input, dim):
+            ctx.save_for_backward(input, torch.tensor(dim))
+            return torch.logsumexp(input, dim=dim)
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            logits, dim = ctx.saved_tensors
+            grad_input = None
+            if ctx.needs_input_grad[0]:
+
+                def sample(ls):
+                    pre_shape = ls.shape
+                    update = (
+                        ls + torch.distributions.Gumbel(0, 1).sample((ls.shape[-1],))
+                    ) / temp
+                    out = torch.nn.functional.one_hot(update.max(-1)[1], pre_shape[-1])
+                    return out
+
+                if dim == -1:
+                    s = sample(logits)
+                else:
+                    dim = dim if dim >= 0 else logits.dim() + dim
+                    perm = [i for i in range(logits.dim()) if i != dim] + [dim]
+                    rev_perm = [
+                        a for a, b in sorted(enumerate(perm), key=lambda a: a[1])
+                    ]
+                    s = sample(logits.permute(perm)).permute(rev_perm)
+
+                grad_input = grad_output.unsqueeze(dim).mul(s)
+            return grad_input, None
+
+    class _GumbelMaxSemiring(_BaseLog):
+        @staticmethod
+        def sum(xs, dim=-1):
+            return _GumbelMaxLogSumExp.apply(xs, dim)
+
+    return _GumbelMaxSemiring
+
+
+def GumbelCRFSemiring(temp):
+    class ST(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, logits, dim):
+            out = torch.nn.functional.one_hot(logits.max(-1)[1], dim)
+            out = out.type_as(logits)
+            ctx.save_for_backward(logits, out)
+            return out
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            logits, out = ctx.saved_tensors
+            with torch.enable_grad():
+                ret = torch.autograd.grad(
+                    logits.softmax(-1), logits, out * grad_output
+                )[0]
+            return ret, None
+
+    class _GumbelCRFLogSumExp(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, input, dim):
+            ctx.save_for_backward(input, torch.tensor(dim))
+            return torch.logsumexp(input, dim=dim)
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            logits, dim = ctx.saved_tensors
+            grad_input = None
+            if ctx.needs_input_grad[0]:
+
+                def sample(ls):
+                    update = (
+                        ls + torch.distributions.Gumbel(0, 1).sample((ls.shape[-1],))
+                    ) / temp
+                    out = ST.apply(update, ls.shape[-1])
+                    return out
+
+                if dim == -1:
+                    s = sample(logits)
+                else:
+                    dim = dim if dim >= 0 else logits.dim() + dim
+                    perm = [i for i in range(logits.dim()) if i != dim] + [dim]
+                    rev_perm = [
+                        a for a, b in sorted(enumerate(perm), key=lambda a: a[1])
+                    ]
+                    s = sample(logits.permute(perm)).permute(rev_perm)
+
+                grad_input = grad_output.unsqueeze(dim).mul(s)
+            return grad_input, None
+
+    class _GumbelCRFSemiring(_BaseLog):
+        @staticmethod
+        def sum(xs, dim=-1):
+            return _GumbelCRFLogSumExp.apply(xs, dim)
+
+    return _GumbelCRFSemiring
+
+
 bits = torch.tensor([pow(2, i) for i in range(1, 18)])


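Isolated from the semiring machinery, the core trick in `ST` above is a straight-through Gumbel step: perturb logits with Gumbel noise, return a hard one-hot in the forward pass, and route the backward pass through a softmax. A minimal free-standing sketch of that pattern (the library's `ST` differs in detail, weighting the softmax gradient by the hard sample):

    import torch

    def st_gumbel_sample(logits, temp=1.0):
        # Perturb with Gumbel(0, 1) noise and anneal by temperature.
        g = torch.distributions.Gumbel(0.0, 1.0).sample(logits.shape)
        soft = ((logits + g) / temp).softmax(-1)
        hard = torch.nn.functional.one_hot(
            soft.argmax(-1), logits.shape[-1]
        ).type_as(soft)
        # Straight-through: the forward value is hard, the gradient
        # is taken through the soft relaxation.
        return (hard - soft).detach() + soft

    logits = torch.randn(2, 4, requires_grad=True)
    sample = st_gumbel_sample(logits, temp=0.5)
    (sample * torch.randn_like(sample)).sum().backward()
    print(logits.grad)  # non-zero, flows through the softmax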
torch_struct/semirings/semirings.py

+5 -2

@@ -277,9 +277,12 @@ class KLDivergenceSemiring(Semiring):

     Based on descriptions in:

-    * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
-    * First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
+    * Parameter estimation for probabilistic finite-state
+      transducers :cite:`eisner2002parameter`
+    * First- and second-order expectation semirings with applications to
+      minimum-risk training on translation forests :cite:`li2009first`
     * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
+
     """

     zero = 0
