1
1
import torch
2
+ from torch .distributions import constraints
2
3
from torch .distributions .distribution import Distribution
3
4
from torch .distributions .utils import lazy_property
4
5
from .linearchain import LinearChain
@@ -36,15 +37,18 @@ class StructDistribution(Distribution):
36
37
log_potentials (tensor, batch_shape x event_shape) : log-potentials :math:`\phi`
37
38
lengths (long tensor, batch_shape) : integers for length masking
38
39
"""
39
- validate_args = False
40
+ arg_constraints = {
41
+ "log_potentials" : constraints .real ,
42
+ "lengths" : constraints .nonnegative_integer
43
+ }
40
44
41
- def __init__ (self , log_potentials , lengths = None , args = {}):
45
+ def __init__ (self , log_potentials , lengths = None , args = {}, validate_args = False ):
42
46
batch_shape = log_potentials .shape [:1 ]
43
47
event_shape = log_potentials .shape [1 :]
44
48
self .log_potentials = log_potentials
45
49
self .lengths = lengths
46
50
self .args = args
47
- super ().__init__ (batch_shape = batch_shape , event_shape = event_shape )
51
+ super ().__init__ (batch_shape = batch_shape , event_shape = event_shape , validate_args = validate_args )
48
52
49
53
def _new (self , * args , ** kwargs ):
50
54
return self ._param .new (* args , ** kwargs )
@@ -295,11 +299,17 @@ class AlignmentCRF(StructDistribution):
295
299
296
300
"""
297
301
struct = Alignment
298
-
299
- def __init__ (self , log_potentials , local = False , lengths = None , max_gap = None ):
302
+ arg_constraints = {
303
+ "log_potentials" : constraints .real ,
304
+ "local" : constraints .boolean ,
305
+ "max_gap" : constraints .nonnegative_integer ,
306
+ "lengths" : constraints .nonnegative_integer
307
+ }
308
+
309
+ def __init__ (self , log_potentials , local = False , lengths = None , max_gap = None , validate_args = False ):
300
310
self .local = local
301
311
self .max_gap = max_gap
302
- super ().__init__ (log_potentials , lengths )
312
+ super ().__init__ (log_potentials , lengths , validate_args = validate_args )
303
313
304
314
def _struct (self , sr = None ):
305
315
return self .struct (
@@ -324,9 +334,9 @@ class HMM(StructDistribution):
324
334
Implemented as a special case of linear chain CRF.
325
335
"""
326
336
327
- def __init__ (self , transition , emission , init , observations , lengths = None ):
337
+ def __init__ (self , transition , emission , init , observations , lengths = None , validate_args = False ):
328
338
log_potentials = HMM .struct .hmm (transition , emission , init , observations )
329
- super ().__init__ (log_potentials , lengths )
339
+ super ().__init__ (log_potentials , lengths , validate_args = validate_args )
330
340
331
341
struct = LinearChain
332
342
@@ -380,8 +390,8 @@ class DependencyCRF(StructDistribution):
380
390
381
391
"""
382
392
383
- def __init__ (self , log_potentials , lengths = None , args = {}, multiroot = True ):
384
- super (DependencyCRF , self ).__init__ (log_potentials , lengths , args )
393
+ def __init__ (self , log_potentials , lengths = None , args = {}, multiroot = True , validate_args = False ):
394
+ super (DependencyCRF , self ).__init__ (log_potentials , lengths , args , validate_args = validate_args )
385
395
self .struct = DepTree
386
396
setattr (self .struct , "multiroot" , multiroot )
387
397
@@ -436,13 +446,13 @@ class SentCFG(StructDistribution):
436
446
437
447
struct = CKY
438
448
439
- def __init__ (self , log_potentials , lengths = None ):
449
+ def __init__ (self , log_potentials , lengths = None , validate_args = False ):
440
450
batch_shape = log_potentials [0 ].shape [:1 ]
441
451
event_shape = log_potentials [0 ].shape [1 :]
442
452
self .log_potentials = log_potentials
443
453
self .lengths = lengths
444
454
super (StructDistribution , self ).__init__ (
445
- batch_shape = batch_shape , event_shape = event_shape
455
+ batch_shape = batch_shape , event_shape = event_shape , validate_args = validate_args
446
456
)
447
457
448
458
@@ -468,8 +478,12 @@ class NonProjectiveDependencyCRF(StructDistribution):
468
478
469
479
"""
470
480
471
- def __init__ (self , log_potentials , lengths = None , args = {}, multiroot = False ):
472
- super (NonProjectiveDependencyCRF , self ).__init__ (log_potentials , lengths , args )
481
+ arg_constraints = {
482
+ "log_potentials" : constraints .real
483
+ }
484
+
485
+ def __init__ (self , log_potentials , lengths = None , args = {}, multiroot = False , validate_args = False ):
486
+ super (NonProjectiveDependencyCRF , self ).__init__ (log_potentials , lengths , args , validate_args = validate_args )
473
487
self .multiroot = multiroot
474
488
475
489
@lazy_property
0 commit comments