@@ -53,6 +53,106 @@ def sum(xs, dim=-1):
         return _SampledLogSumExp.apply(xs, dim)


+def GumbelMaxSemiring(temp):
+    class _GumbelMaxLogSumExp(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, input, dim):
+            ctx.save_for_backward(input, torch.tensor(dim))
+            return torch.logsumexp(input, dim=dim)
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            logits, dim = ctx.saved_tensors
+            grad_input = None
+            if ctx.needs_input_grad[0]:
+
+                def sample(ls):
+                    pre_shape = ls.shape
+                    update = (
+                        ls + torch.distributions.Gumbel(0, 1).sample((ls.shape[-1],))
+                    ) / temp
+                    out = torch.nn.functional.one_hot(update.max(-1)[1], pre_shape[-1])
+                    return out
+
+                if dim == -1:
+                    s = sample(logits)
+                else:
+                    dim = dim if dim >= 0 else logits.dim() + dim
+                    perm = [i for i in range(logits.dim()) if i != dim] + [dim]
+                    rev_perm = [
+                        a for a, b in sorted(enumerate(perm), key=lambda a: a[1])
+                    ]
+                    s = sample(logits.permute(perm)).permute(rev_perm)
+
+                grad_input = grad_output.unsqueeze(dim).mul(s)
+            return grad_input, None
+
+    class _GumbelMaxSemiring(_BaseLog):
+        @staticmethod
+        def sum(xs, dim=-1):
+            return _GumbelMaxLogSumExp.apply(xs, dim)
+
+    return _GumbelMaxSemiring
+
+
+def GumbelCRFSemiring(temp):
+    class ST(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, logits, dim):
+            out = torch.nn.functional.one_hot(logits.max(-1)[1], dim)
+            out = out.type_as(logits)
+            ctx.save_for_backward(logits, out)
+            return out
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            logits, out = ctx.saved_tensors
+            with torch.enable_grad():
+                ret = torch.autograd.grad(
+                    logits.softmax(-1), logits, out * grad_output
+                )[0]
+            return ret, None
+
+    class _GumbelCRFLogSumExp(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, input, dim):
+            ctx.save_for_backward(input, torch.tensor(dim))
+            return torch.logsumexp(input, dim=dim)
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            logits, dim = ctx.saved_tensors
+            grad_input = None
+            if ctx.needs_input_grad[0]:
+
+                def sample(ls):
+                    update = (
+                        ls + torch.distributions.Gumbel(0, 1).sample((ls.shape[-1],))
+                    ) / temp
+                    out = ST.apply(update, ls.shape[-1])
+                    return out
+
+                if dim == -1:
+                    s = sample(logits)
+                else:
+                    dim = dim if dim >= 0 else logits.dim() + dim
+                    perm = [i for i in range(logits.dim()) if i != dim] + [dim]
+                    rev_perm = [
+                        a for a, b in sorted(enumerate(perm), key=lambda a: a[1])
+                    ]
+                    s = sample(logits.permute(perm)).permute(rev_perm)
+
+                grad_input = grad_output.unsqueeze(dim).mul(s)
+            return grad_input, None
+
+    class _GumbelCRFSemiring(_BaseLog):
+        @staticmethod
+        def sum(xs, dim=-1):
+            return _GumbelCRFLogSumExp.apply(xs, dim)
+
+    return _GumbelCRFSemiring
+
+
 bits = torch.tensor([pow(2, i) for i in range(1, 18)])

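For reference, here is a minimal usage sketch (not part of the commit) showing how the two factories added above might be exercised. It assumes they are importable from `torch_struct.semirings.sample` alongside the existing sampling semirings; that import path is an assumption, and only the `sum` static method defined in the diff is used.

```python
# Minimal sketch -- the import path below is an assumption, not verified.
import torch
from torch_struct.semirings.sample import GumbelMaxSemiring, GumbelCRFSemiring

torch.manual_seed(0)
logits = torch.randn(4, 6, requires_grad=True)

# Forward pass is an ordinary logsumexp over the requested dim; the Gumbel
# perturbation only enters through the custom backward.
K = GumbelCRFSemiring(temp=1.0)
val = K.sum(logits, dim=-1)   # shape (4,)
val.sum().backward()

# With GumbelCRFSemiring, each row of the gradient is a one-hot sample drawn
# via the Gumbel-max trick, wrapped in the straight-through Function ST.
print(logits.grad.shape)      # torch.Size([4, 6])

# GumbelMaxSemiring does the same but emits a plain hard one-hot in backward
# (no straight-through wrapper around the sample).
logits.grad = None
GumbelMaxSemiring(temp=1.0).sum(logits, dim=-1).sum().backward()
print(logits.grad)            # rows are one-hot vectors
```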