
Mini-batch setting with Semi Markov CRF #110


Closed
urchade opened this issue Oct 2, 2021 · 5 comments · Fixed by #114

Comments

@urchade
Contributor

urchade commented Oct 2, 2021

I encounter learning instability when using a batch size > 1 with the semi-Markov CRF (the loss goes to a very large negative number), even when explicitly providing "lengths". I think the bug comes from the masking.
The model trains well with batch size 1.
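
For reference, a minimal sketch of the kind of training step where the loss diverges (toy shapes; the "gold" parts here are just the model's own argmax, standing in for real labels):

import torch, torch_struct

torch.manual_seed(0)

batch, N, C, K = 4, 10, 2, 6  # batch > 1 is where the loss diverges

log_potentials = torch.randn(batch, N, K, C, C, requires_grad=True)
lengths = torch.LongTensor([N + 1, 8, 6, 4])

dist = torch_struct.SemiMarkovCRF(log_potentials, lengths=lengths)

# stand-in gold labels: the distribution's own argmax parts
gold_parts = dist.argmax.detach()

# negative log-likelihood; with batch > 1 this can come out large and negative
loss = -dist.log_prob(gold_parts).mean()
loss.backward()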

@srush
Collaborator

srush commented Oct 4, 2021

Thanks. Can you by any chance provide an example? I will take a look.

@urchade
Contributor Author

urchade commented Oct 5, 2021

The problem also occurs during inference:

import torch, torch_struct
import matplotlib.pyplot as plt

torch.manual_seed(1)

batch, N, C, K = 3, 10, 2, 6

# helper to visualize a semi-Markov chain (not called below; kept for debugging)
def show_sm(chain):
    plt.imshow(chain.detach().sum(1).sum(-1).transpose(0, 1))

log_potentials = torch.randn(batch, N, K, C, C)

# dists with and without length masking (the 0th element of the batch is not padded)
dist_1 = torch_struct.SemiMarkovCRF(log_potentials)
dist_2 = torch_struct.SemiMarkovCRF(log_potentials, lengths=torch.LongTensor([N+1, 5, 1]))
dist_3 = torch_struct.SemiMarkovCRF(log_potentials, lengths=torch.LongTensor([N+1, 5, 4]))

# the argmax for batch index 0 should be the same for every dist, since that index has no padding
assert torch.allclose(dist_1.argmax[0], dist_2.argmax[0])
assert torch.allclose(dist_1.argmax[0], dist_3.argmax[0])
assert torch.allclose(dist_2.argmax[0], dist_3.argmax[0])

@srush
Collaborator

srush commented Oct 6, 2021

Oh thanks, this is a useful test (and it sounds like a bug).

@da03 we should fix this. Any chance you could take a first look?

@da03
Contributor

da03 commented Oct 14, 2021

@urchade Thanks for pointing this out! It's fixed in PR #114. The issue was due to this line:

mask[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True)

which did not consider different ending positions for sentences of different lengths.
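
To illustrate the idea with a toy sketch (this is not the actual code in the PR): a valid-position mask has to be derived from each sequence's own length, not from one shared end position:

import torch

batch, N = 3, 10
lengths = torch.LongTensor([10, 5, 1])

positions = torch.arange(N).unsqueeze(0)   # (1, N)

# per-example mask: True only up to each sequence's own length
mask = positions < lengths.unsqueeze(1)    # (batch, N)

# shared-end mask (the buggy pattern): every row treated like the longest one
shared = (positions < lengths.max()).expand(batch, N)

assert not torch.equal(mask, shared)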

I also added back an alternative implementation, _dp_standard, for computing the log partition. It is more memory-efficient and can be used as below:

import torch, torch_struct

torch.manual_seed(1)

batch, N, C, K = 3, 10, 2, 6

log_potentials = torch.randn(batch, N, K, C, C)
lengths = torch.LongTensor([N+1, 5, 1])

# compare the memory-efficient standard DP against the default partition
struct = torch_struct.SemiMarkov()
dist = torch_struct.SemiMarkovCRF(log_potentials, lengths=lengths)

assert torch.allclose(struct._dp_standard(log_potentials, lengths=lengths)[0], dist.partition)

@srush
Collaborator

srush commented Oct 14, 2021

Oh wow, impressive @da03! This code is really complex.

Long term, let's make SemiMarkovParallel and SemiMarkovFlat their own classes and let the CRF pick which one to use.
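
Roughly something like this chooser, as a hypothetical sketch (the heuristic, threshold, and wiring are made up; it only reuses the two code paths already shown in this thread):

import torch
import torch_struct

# hypothetical sketch: dispatch between the memory-efficient standard DP
# ("flat") and the default parallel-scan partition ("parallel")
def semimarkov_partition(log_potentials, lengths=None, max_elems=10**7):
    dist = torch_struct.SemiMarkovCRF(log_potentials, lengths=lengths)
    if log_potentials.numel() > max_elems:
        # large inputs: the _dp_standard path added in PR #114
        return torch_struct.SemiMarkov()._dp_standard(log_potentials, lengths=lengths)[0]
    # small inputs: the default partition
    return dist.partition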
