Skip to content

Commit b2d7fcd

Browse files
author
Sasha
committed
Add some more tests
1 parent e3837bb commit b2d7fcd

File tree

8 files changed

+80
-26
lines changed

8 files changed

+80
-26
lines changed

examples/supervised.py

+3
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,3 @@
1+
import torchtext
2+
3+
torchtext.datasets.UDPOS

setup.py

+2-3
Original file line number | Diff line number | Diff line change
@@ -5,11 +5,10 @@
55
version="0.0.1",
66
author="Alexander Rush",
77
author_email="arush@cornell.edu",
8-
packages=["torch_struct", ],
8+
packages=["torch_struct"],
99
package_data={"torch_struct": []},
1010
url="https://github.com/harvardnlp/pytorch_struct",
1111
install_requires=["torch"],
1212
setup_requires=["pytest-runner"],
13-
tests_require=["pytest"]
14-
13+
tests_require=["pytest"],
1514
)

torch_struct/cky.py

+11-4
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,9 @@
55
A, B = 0, 1
66

77

8-
def cky_inside(terms, rules, roots, semiring=LogSemiring, lengths=None):
8+
def cky_inside(
9+
terms, rules, roots, semiring=LogSemiring, lengths=None, force_grad=False
10+
):
911
"""
1012
Compute the inside pass of a CFG using CKY.
1113
@@ -23,9 +25,14 @@ def cky_inside(terms, rules, roots, semiring=LogSemiring, lengths=None):
2325
_, NT, _, _ = rules.shape
2426
if lengths is None:
2527
lengths = torch.LongTensor([N] * batch)
26-
beta = [_make_chart((batch, N, N, NT + T), rules, semiring) for _ in range(2)]
27-
28-
span = [_make_chart((batch, N, NT + T), rules, semiring) for _ in range(N)]
28+
beta = [
29+
_make_chart((batch, N, N, NT + T), rules, semiring, force_grad)
30+
for _ in range(2)
31+
]
32+
33+
span = [
34+
_make_chart((batch, N, NT + T), rules, semiring, force_grad) for _ in range(N)
35+
]
2936
rule_use = [None for _ in range(N - 1)]
3037
term_use = terms.requires_grad_(True)
3138
beta[A][:, :, 0, NT:] = term_use

torch_struct/deptree.py

+13-5
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ def _unconvert(logits):
3434
A, B, R, C, L, I = 0, 1, 1, 1, 0, 0
3535

3636

37-
def deptree_inside(arc_scores, semiring=LogSemiring, lengths=None):
37+
def deptree_inside(arc_scores, semiring=LogSemiring, lengths=None, force_grad=False):
3838
"""
3939
Compute the inside pass of a projective dependency CRF.
4040
@@ -49,10 +49,12 @@ def deptree_inside(arc_scores, semiring=LogSemiring, lengths=None):
4949
5050
"""
5151
arc_scores = _convert(arc_scores)
52-
batch, N, _ = arc_scores.shape
52+
batch, N, N2 = arc_scores.shape
53+
assert N == N2, "Non-square potentials"
5354
DIRS = 2
5455
if lengths is None:
5556
lengths = torch.LongTensor([N] * batch)
57+
assert max(lengths) <= N, "Length longer than N"
5658

5759
def stack(a, b):
5860
return torch.stack([a, b])
@@ -61,10 +63,16 @@ def sstack(a):
6163
return torch.stack([a, a])
6264

6365
alpha = [
64-
[_make_chart((DIRS, batch, N, N), arc_scores, semiring) for _ in [I, C]]
66+
[
67+
_make_chart((DIRS, batch, N, N), arc_scores, semiring, force_grad)
68+
for _ in [I, C]
69+
]
6570
for _ in range(2)
6671
]
67-
arcs = [_make_chart((DIRS, batch, N), arc_scores, semiring) for _ in range(N)]
72+
arcs = [
73+
_make_chart((DIRS, batch, N), arc_scores, semiring, force_grad)
74+
for _ in range(N)
75+
]
6876

6977
# Inside step. assumes first token is root symbol
7078
alpha[A][C][:, :, :, 0].data.fill_(semiring.one())
@@ -108,7 +116,7 @@ def deptree(arc_scores, semiring=LogSemiring, lengths=None):
108116
"""
109117
batch, N, _ = arc_scores.shape
110118
N = N + 1
111-
v, arcs = deptree_inside(arc_scores, semiring, lengths)
119+
v, arcs = deptree_inside(arc_scores, semiring, lengths, force_grad=True)
112120
grads = torch.autograd.grad(
113121
v.sum(dim=0), arcs[1:], create_graph=True, only_inputs=True, allow_unused=False
114122
)

torch_struct/helpers.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
11
import torch
22

33

4-
def _make_chart(size, potentials, semiring):
4+
def _make_chart(size, potentials, semiring, force_grad):
55
return (
66
torch.zeros(*size)
77
.type_as(potentials)
88
.fill_(semiring.zero())
9-
.requires_grad_(True)
9+
.requires_grad_(force_grad and not potentials.requires_grad)
1010
)

torch_struct/linearchain.py

+12-6
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from .helpers import _make_chart
44

55

6-
def linearchain_forward(edge, semiring=LogSemiring, lengths=None):
6+
def linearchain_forward(edge, semiring=LogSemiring, lengths=None, force_grad=False):
77
"""
88
Compute the forward pass of a linear chain CRF.
99
@@ -18,10 +18,16 @@ def linearchain_forward(edge, semiring=LogSemiring, lengths=None):
1818
inside: list of N, b x C x C table
1919
2020
"""
21-
batch, N, C, _ = edge.shape
21+
batch, N, C, C2 = edge.shape
2222
if lengths is None:
2323
lengths = torch.LongTensor([N] * batch)
24-
alpha = [_make_chart((batch, C), edge, semiring) for n in range(N + 1)]
24+
assert max(lengths) <= N, "Length longer than edge scores"
25+
assert C == C2, "Transition shape doesn't match"
26+
27+
alpha = [
28+
_make_chart((batch, C), edge, semiring, force_grad=force_grad)
29+
for n in range(N + 1)
30+
]
2531
edge_store = [None for _ in range(N)]
2632
alpha[0].data.fill_(semiring.one())
2733
for n in range(1, N + 1):
@@ -33,20 +39,20 @@ def linearchain_forward(edge, semiring=LogSemiring, lengths=None):
3339
return v, edge_store
3440

3541

36-
def linearchain(edge, semiring=LogSemiring):
42+
def linearchain(edge, semiring=LogSemiring, lengths=None):
3743
"""
3844
Compute the marginals of a linear chain CRF.
3945
4046
Parameters:
4147
edge : b x N x C x C markov potentials
4248
(t x z_t x z_{t-1})
4349
semiring
44-
50+
lengths : None or b long tensor of sequence lengths
4551
Returns:
4652
marginals: b x N x C x C table
4753
4854
"""
49-
v, alpha = linearchain_forward(edge, semiring)
55+
v, alpha = linearchain_forward(edge, semiring, lengths=lengths, force_grad=True)
5056
marg = torch.autograd.grad(
5157
v.sum(dim=0), alpha, create_graph=True, only_inputs=True, allow_unused=False
5258
)

torch_struct/semimarkov.py

+11-6
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from .helpers import _make_chart
44

55

6-
def semimarkov_forward(edge, semiring=LogSemiring, lengths=None):
6+
def semimarkov_forward(edge, semiring=LogSemiring, lengths=None, force_grad=False):
77
"""
88
Compute the forward pass of a semimarkov CRF.
99
@@ -17,12 +17,17 @@ def semimarkov_forward(edge, semiring=LogSemiring, lengths=None):
1717
spans: list of N, b x K x C x C table
1818
1919
"""
20-
batch, N, K, C, _ = edge.shape
20+
batch, N, K, C, C2 = edge.shape
2121
if lengths is None:
2222
lengths = torch.LongTensor([N] * batch)
23+
assert max(lengths) <= N, "Length longer than edge scores"
24+
assert C == C2, "Transition shape doesn't match"
25+
2326
spans = [None for _ in range(N)]
24-
alpha = [_make_chart((batch, K, C), edge, semiring) for n in range(N + 1)]
25-
beta = [_make_chart((batch, C), edge, semiring) for n in range(N + 1)]
27+
alpha = [
28+
_make_chart((batch, K, C), edge, semiring, force_grad) for n in range(N + 1)
29+
]
30+
beta = [_make_chart((batch, C), edge, semiring, force_grad) for n in range(N + 1)]
2631
beta[0].data.fill_(semiring.one())
2732
for n in range(1, N + 1):
2833
spans[n - 1] = semiring.times(
@@ -33,7 +38,7 @@ def semimarkov_forward(edge, semiring=LogSemiring, lengths=None):
3338
f1 = torch.arange(n - 1, t, -1)
3439
f2 = torch.arange(1, len(f1) + 1)
3540
print(n - 1, f1, f2)
36-
beta[n] = semiring.sum(
41+
beta[n][:] = semiring.sum(
3742
torch.stack([alpha[a][:, b] for a, b in zip(f1, f2)]), dim=0
3843
)
3944
v = semiring.sum(torch.stack([beta[l][i] for i, l in enumerate(lengths)]), dim=1)
@@ -52,7 +57,7 @@ def semimarkov(edge, semiring=LogSemiring, lengths=None):
5257
marginals: b x N x K x C table
5358
5459
"""
55-
v, spans = semimarkov_forward(edge, semiring, lengths)
60+
v, spans = semimarkov_forward(edge, semiring, lengths, force_grad=True)
5661
marg = torch.autograd.grad(
5762
v.sum(dim=0), spans, create_graph=True, only_inputs=True, allow_unused=False
5863
)

torch_struct/test_algorithms.py

+26
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,14 @@ def test_linearchain(batch, N, C):
4141
assert torch.isclose(score.sum(), marginals.mul(vals).sum()).all()
4242

4343

44+
@given(smint, smint, smint)
45+
def test_params(batch, N, C):
46+
vals = torch.ones(batch, N, C, C, requires_grad=True)
47+
semiring = StdSemiring
48+
alpha, _ = linearchain_forward(vals, semiring)
49+
alpha.sum().backward()
50+
51+
4452
def test_hmm():
4553
C, V, batch, N = 5, 20, 2, 5
4654
transition = torch.rand(C, C)
@@ -84,6 +92,13 @@ def test_dep(N):
8492
assert torch.isclose(score.sum(), marginals.mul(scores).sum()).all()
8593

8694

95+
def test_dep_params():
96+
batch, N = 2, 2
97+
scores = torch.rand(batch, N, N, requires_grad=True)
98+
top, arcs = deptree_inside(scores)
99+
top.sum().backward()
100+
101+
87102
def test_dep_np():
88103
N = 5
89104
batch = 2
@@ -114,3 +129,14 @@ def test_cky(N, NT, T):
114129
+ m_root.mul(roots).sum()
115130
).sum(),
116131
).all()
132+
133+
134+
@given(smint, tint, tint)
135+
@settings(max_examples=3)
136+
def test_cky_params(N, NT, T):
137+
batch = 2
138+
terms = torch.rand(batch, N, T)
139+
rules = torch.rand(batch, NT, (NT + T), (NT + T), requires_grad=True)
140+
roots = torch.rand(batch, NT, requires_grad=True)
141+
v, _ = cky_inside(terms, rules, roots)
142+
v.sum().backward()

0 commit comments

Comments
 (0)