Skip to content

Commit 7146de5

Browse files
authored
add initial pass at arg_constraints and arg validation. (#123)
1 parent 29faea2 commit 7146de5

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

Diff for: torch_struct/distributions.py

+28-14
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from torch.distributions import constraints
23
from torch.distributions.distribution import Distribution
34
from torch.distributions.utils import lazy_property
45
from .linearchain import LinearChain
@@ -36,15 +37,18 @@ class StructDistribution(Distribution):
3637
log_potentials (tensor, batch_shape x event_shape) : log-potentials :math:`\phi`
3738
lengths (long tensor, batch_shape) : integers for length masking
3839
"""
39-
validate_args = False
40+
arg_constraints = {
41+
"log_potentials": constraints.real,
42+
"lengths": constraints.nonnegative_integer
43+
}
4044

41-
def __init__(self, log_potentials, lengths=None, args={}):
45+
def __init__(self, log_potentials, lengths=None, args={}, validate_args=False):
4246
batch_shape = log_potentials.shape[:1]
4347
event_shape = log_potentials.shape[1:]
4448
self.log_potentials = log_potentials
4549
self.lengths = lengths
4650
self.args = args
47-
super().__init__(batch_shape=batch_shape, event_shape=event_shape)
51+
super().__init__(batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args)
4852

4953
def _new(self, *args, **kwargs):
5054
return self._param.new(*args, **kwargs)
@@ -295,11 +299,17 @@ class AlignmentCRF(StructDistribution):
295299
296300
"""
297301
struct = Alignment
298-
299-
def __init__(self, log_potentials, local=False, lengths=None, max_gap=None):
302+
arg_constraints = {
303+
"log_potentials": constraints.real,
304+
"local": constraints.boolean,
305+
"max_gap": constraints.nonnegative_integer,
306+
"lengths": constraints.nonnegative_integer
307+
}
308+
309+
def __init__(self, log_potentials, local=False, lengths=None, max_gap=None, validate_args=False):
300310
self.local = local
301311
self.max_gap = max_gap
302-
super().__init__(log_potentials, lengths)
312+
super().__init__(log_potentials, lengths, validate_args=validate_args)
303313

304314
def _struct(self, sr=None):
305315
return self.struct(
@@ -324,9 +334,9 @@ class HMM(StructDistribution):
324334
Implemented as a special case of linear chain CRF.
325335
"""
326336

327-
def __init__(self, transition, emission, init, observations, lengths=None):
337+
def __init__(self, transition, emission, init, observations, lengths=None, validate_args=False):
328338
log_potentials = HMM.struct.hmm(transition, emission, init, observations)
329-
super().__init__(log_potentials, lengths)
339+
super().__init__(log_potentials, lengths, validate_args=validate_args)
330340

331341
struct = LinearChain
332342

@@ -380,8 +390,8 @@ class DependencyCRF(StructDistribution):
380390
381391
"""
382392

383-
def __init__(self, log_potentials, lengths=None, args={}, multiroot=True):
384-
super(DependencyCRF, self).__init__(log_potentials, lengths, args)
393+
def __init__(self, log_potentials, lengths=None, args={}, multiroot=True, validate_args=False):
394+
super(DependencyCRF, self).__init__(log_potentials, lengths, args, validate_args=validate_args)
385395
self.struct = DepTree
386396
setattr(self.struct, "multiroot", multiroot)
387397

@@ -436,13 +446,13 @@ class SentCFG(StructDistribution):
436446

437447
struct = CKY
438448

439-
def __init__(self, log_potentials, lengths=None):
449+
def __init__(self, log_potentials, lengths=None, validate_args=False):
440450
batch_shape = log_potentials[0].shape[:1]
441451
event_shape = log_potentials[0].shape[1:]
442452
self.log_potentials = log_potentials
443453
self.lengths = lengths
444454
super(StructDistribution, self).__init__(
445-
batch_shape=batch_shape, event_shape=event_shape
455+
batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args
446456
)
447457

448458

@@ -468,8 +478,12 @@ class NonProjectiveDependencyCRF(StructDistribution):
468478
469479
"""
470480

471-
def __init__(self, log_potentials, lengths=None, args={}, multiroot=False):
472-
super(NonProjectiveDependencyCRF, self).__init__(log_potentials, lengths, args)
481+
arg_constraints = {
482+
"log_potentials": constraints.real
483+
}
484+
485+
def __init__(self, log_potentials, lengths=None, args={}, multiroot=False, validate_args=False):
486+
super(NonProjectiveDependencyCRF, self).__init__(log_potentials, lengths, args, validate_args=validate_args)
473487
self.multiroot = multiroot
474488

475489
@lazy_property

0 commit comments

Comments
 (0)