Mini-batch setting with Semi Markov CRF #110
Thanks. Can you by any chance provide an example? I will take a look.

The problem also occurs during inference:

import torch, torch_struct
import matplotlib.pyplot as plt

torch.manual_seed(1)
batch, N, C, K = 3, 10, 2, 6

def show_sm(chain):
    plt.imshow(chain.detach().sum(1).sum(-1).transpose(0, 1))

log_potentials = torch.randn(batch, N, K, C, C)

# dists with and without length masking (we do not pad the 0th element of the batch)
dist_1 = torch_struct.SemiMarkovCRF(log_potentials)
dist_2 = torch_struct.SemiMarkovCRF(log_potentials, lengths=torch.LongTensor([N+1, 5, 1]))
dist_3 = torch_struct.SemiMarkovCRF(log_potentials, lengths=torch.LongTensor([N+1, 5, 4]))

# the argmax for the 0th example should be the same for every dist, since that example is not padded
assert torch.allclose(dist_1.argmax[0], dist_2.argmax[0])
assert torch.allclose(dist_1.argmax[0], dist_3.argmax[0])
assert torch.allclose(dist_2.argmax[0], dist_3.argmax[0])
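
For reference, a per-example cross-check (a minimal sketch, assuming only that batch entries are independent; dist_single is an illustrative name, not from the original report): running the unpadded 0th example through its own batch-of-one distribution should reproduce dist_1.argmax[0].

# Hypothetical cross-check: batch entries are independent, so the argmax of
# example 0 computed on its own should match the batched, unmasked argmax.
dist_single = torch_struct.SemiMarkovCRF(log_potentials[0:1])
assert torch.allclose(dist_1.argmax[0], dist_single.argmax[0])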

Oh thanks, this is a useful test (and it sounds like a bug). @da03, we should fix this. Any chance you could take a first look?

@urchade Thanks for pointing this out! It's fixed in PR #114. The issue was due to this line: pytorch-struct/torch_struct/semimarkov.py, line 67 (commit 5328ec5). In addition, I added back another implementation.

Oh wow, impressive @da03! This code is really complex. Long term, let's make SemiMarkovParallel and SemiMarkovFlat their own classes and let the CRF pick which one to use.
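
For concreteness, a rough sketch of that split (hypothetical; the class skeletons, method names, and the selection heuristic below are illustrative, not the library's actual API):

# Hypothetical sketch of the suggested refactor: two implementations behind the
# same interface, with the CRF picking one, e.g. by problem size.
class SemiMarkovParallel:
    def logpartition(self, log_potentials, lengths=None):
        ...  # parallel-scan style implementation

class SemiMarkovFlat:
    def logpartition(self, log_potentials, lengths=None):
        ...  # flattened / sequential implementation

def pick_semimarkov_backend(log_potentials, lengths=None, max_parallel_length=512):
    # placeholder heuristic: fall back to the flat version for long sequences
    N = log_potentials.shape[1]
    return SemiMarkovParallel() if N <= max_parallel_length else SemiMarkovFlat()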

I encounter learning instability when using a batch size > 1 with the semi-Markovian CRF (the loss goes to a very large negative number), even when explicitly providing "lengths". I think the bug comes from the masking.
The model trains well with batch size 1.
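
One way to check that the masking (rather than the model) is to blame: the log-probability of any structure can never be positive, so a loss that drifts to a large negative value points at a broken partition computation. A minimal sketch of that check (the lengths values are illustrative):

import torch, torch_struct

torch.manual_seed(0)
batch, N, C, K = 3, 10, 2, 6
log_potentials = torch.randn(batch, N, K, C, C)
lengths = torch.LongTensor([N + 1, 5, 1])

dist = torch_struct.SemiMarkovCRF(log_potentials, lengths=lengths)
best = dist.argmax  # highest-scoring structure per example

# For a proper distribution, log P(best) <= 0. If the length masking is broken,
# the structure score and the log-partition disagree on the padded positions and
# this can come out positive, i.e. the training loss (-log_prob) goes very negative.
print(dist.log_prob(best))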