-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathLoading_Utils.py
81 lines (70 loc) · 3.9 KB
/
Loading_Utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
from pykeen.datasets import FB15k237, WN18RR
from pykeen.models import BoxE, RotatE
from pykeen.triples import TriplesFactory
import json
from ExpressivEModel import ExpressivE
def load_checkpoint(config_path, checkpoint_path, map_location=None):
    """Rebuild a trained KGE model from a JSON run config and a torch checkpoint.

    The config selects the dataset, loss and model class; the checkpoint
    supplies the entity/relation id mappings and the model weights.

    Args:
        config_path: Path to a JSON config. Keys read here: ``dataset``,
            ``dataset_kwargs.create_inverse_triples``, ``loss``, ``model``,
            ``model_kwargs`` and, for ExpressivE/BoxE, ``loss_kwargs``
            (``reduction``, ``margin``, ``adversarial_temperature``).
        checkpoint_path: Path to a file readable by ``torch.load`` that holds
            ``entity_to_id_dict``, ``relation_to_id_dict`` and
            ``model_state_dict``.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps
            torch's default behaviour; pass e.g. ``'cpu'`` to load a
            checkpoint saved on a GPU onto a CPU-only machine.

    Returns:
        Tuple ``(config, dataset, trained_model)``.

    Raises:
        ValueError: If the dataset, loss or model named in the config is not
            supported. This is raised before the checkpoint is loaded and
            before any dataset is instantiated.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Validate every name up front so a bad config fails fast, without first
    # loading the checkpoint or instantiating (and possibly downloading) a dataset.
    dataset_name = config['dataset']
    if dataset_name not in ('FB15k237', 'WN18RR'):
        raise ValueError('Dataset %s unknown!' % dataset_name)
    if config['loss'] not in ('NSALoss', 'NSSALoss'):
        raise ValueError('Unknown loss \'%s\'' % config['loss'])
    model_name = config['model']
    if model_name not in ('ExpressivEModel', 'BoxE', 'RotatE'):
        # Without this check an unknown model would fall through every branch
        # below and crash with UnboundLocalError on ``trained_model``.
        raise ValueError('Unknown model \'%s\'' % model_name)
    # Both accepted loss names are rebuilt as pykeen's 'nssa' loss.
    loss_str = 'nssa'

    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    dataset = FB15k237() if dataset_name == 'FB15k237' else WN18RR()
    # NOTE(review): mapped_triples come from the freshly built dataset while the
    # id mappings come from the checkpoint; this assumes both agree on the ids.
    triples_factory = TriplesFactory(
        mapped_triples=dataset.training.mapped_triples,
        relation_to_id=checkpoint['relation_to_id_dict'],
        entity_to_id=checkpoint['entity_to_id_dict'],
        create_inverse_triples=config['dataset_kwargs']['create_inverse_triples'],
    )

    model_kwargs = config['model_kwargs']
    if model_name == 'RotatE':
        # RotatE is rebuilt without loss_kwargs, so config['loss_kwargs'] is not read.
        trained_model = RotatE(
            triples_factory=triples_factory,
            embedding_dim=model_kwargs['embedding_dim'],
            loss=loss_str,
        )
    else:
        loss_cfg = config['loss_kwargs']
        loss_kwargs = dict(
            reduction=loss_cfg['reduction'],
            margin=loss_cfg['margin'],
            adversarial_temperature=loss_cfg['adversarial_temperature'],
        )
        if model_name == 'BoxE':
            trained_model = BoxE(
                triples_factory=triples_factory,
                embedding_dim=model_kwargs['embedding_dim'],
                p=model_kwargs['p'],
                loss=loss_str,
                loss_kwargs=loss_kwargs,
            )
        else:  # 'ExpressivEModel'
            # interactionMode is optional in the config; when absent the keyword
            # is not passed, so ExpressivE falls back to its own default.
            optional_kwargs = {}
            if 'interactionMode' in model_kwargs:
                optional_kwargs['interactionMode'] = model_kwargs['interactionMode']
            trained_model = ExpressivE(
                triples_factory=triples_factory,
                embedding_dim=model_kwargs['embedding_dim'],
                p=model_kwargs['p'],
                min_denom=model_kwargs['min_denom'],
                tanh_map=model_kwargs['tanh_map'],
                loss=loss_str,
                loss_kwargs=loss_kwargs,
                **optional_kwargs,
            )

    trained_model.load_state_dict(checkpoint['model_state_dict'])
    return config, dataset, trained_model