Skip to content

Commit 91a1d48

Browse files
committed
create state from seed
1 parent 8243fbf commit 91a1d48

File tree

4 files changed

+98
-44
lines changed

4 files changed

+98
-44
lines changed

Diff for: ctgan/synthesizers/base.py

+43-18
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,34 @@
77

88

99
@contextlib.contextmanager
def set_random_states(random_state, set_model_random_state):
    """Context manager for managing the random state.

    While the context is active, the global NumPy and torch random states are
    replaced by the ones described by ``random_state``. On exit, the advanced
    global states are captured and handed to ``set_model_random_state`` (so
    the next call continues the sequence instead of repeating it), and the
    global states that were active before entering are restored.

    Args:
        random_state (int or tuple):
            The random seed or a tuple of (numpy.random.RandomState, torch.Generator).
            An int seeds a fresh pair of generators with that value.
        set_model_random_state (function):
            Function to set the random state on the model. It is called once,
            on exit, with a ``(numpy.random.RandomState, torch.Generator)``
            tuple holding the states reached inside the context.
    """
    original_np_state = np.random.get_state()
    original_torch_state = torch.get_rng_state()

    # The docstring allows a bare seed; build the generator pair from it so
    # the unpacking below does not fail with an opaque ``TypeError``.
    if isinstance(random_state, int):
        random_state = (
            np.random.RandomState(seed=random_state),
            torch.Generator().manual_seed(random_state),
        )

    random_np_state, random_torch_state = random_state

    try:
        np.random.set_state(random_np_state.get_state())
        torch.set_rng_state(random_torch_state.get_state())

        try:
            yield
        finally:
            # Snapshot the advanced global states into standalone generator
            # objects before the globals are put back.
            current_np_state = np.random.RandomState()
            current_np_state.set_state(np.random.get_state())
            current_torch_state = torch.Generator()
            current_torch_state.set_state(torch.get_rng_state())
            set_model_random_state((current_np_state, current_torch_state))
    finally:
        # Always restore the caller's global states, even if installing the
        # requested state or the model callback raised.
        np.random.set_state(original_np_state)
        torch.set_rng_state(original_torch_state)
2838

2939

3040
def random_state(function):
@@ -36,11 +46,11 @@ def random_state(function):
3646
"""
3747

3848
def wrapper(self, *args, **kwargs):
39-
if self._random_seed is None:
49+
if self.random_states is None:
4050
return function(self, *args, **kwargs)
4151

4252
else:
43-
with random_seed(self._random_seed):
53+
with set_random_states(self.random_states, self.set_random_state):
4454
return function(self, *args, **kwargs)
4555

4656
return wrapper
@@ -52,7 +62,7 @@ class BaseSynthesizer:
5262
This should contain the save/load methods.
5363
"""
5464

55-
_random_seed = None
65+
random_states = None
5666

5767
def save(self, path):
5868
"""Save the model in the passed `path`."""
@@ -69,11 +79,26 @@ def load(cls, path):
6979
model.set_device(device)
7080
return model
7181

72-
def set_random_state(self, random_state):
    """Set the random state.

    Args:
        random_state (tuple or int):
            Either a tuple containing the (numpy.random.RandomState, torch.Generator)
            or an int representing the random seed to use for both random states.

    Raises:
        ValueError:
            If ``random_state`` is neither an int nor a 2-tuple of
            (``numpy.random.RandomState``, ``torch.Generator``).
    """
    if isinstance(random_state, int):
        # One seed drives both generators.
        self.random_states = (
            np.random.RandomState(seed=random_state),
            torch.Generator().manual_seed(random_state),
        )
    elif (
        isinstance(random_state, tuple)
        # Check the length first so short tuples reach the ValueError below
        # instead of failing with an IndexError.
        and len(random_state) == 2
        and isinstance(random_state[0], np.random.RandomState)
        and isinstance(random_state[1], torch.Generator)
    ):
        self.random_states = random_state
    else:
        # f-string: the offending value must be interpolated into the message.
        raise ValueError(
            f'`random_state` {random_state} expected to be an int or a tuple of '
            '(`np.random.RandomState`, `torch.Generator`)')

Diff for: tests/integration/synthesizer/test_ctgan.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,20 @@ def test_fixed_random_seed():
215215

216216
# Run
217217
sampled_random = ctgan.sample(10)
218-
ctgan.set_random_seed(0)
219-
sampled_0 = ctgan.sample(10)
220-
sampled_1 = ctgan.sample(10)
218+
219+
ctgan.set_random_state(0)
220+
sampled_0_0 = ctgan.sample(10)
221+
sampled_0_1 = ctgan.sample(10)
222+
223+
ctgan.set_random_state(0)
224+
sampled_1_0 = ctgan.sample(10)
225+
sampled_1_1 = ctgan.sample(10)
221226

222227
# Assert
223-
assert not np.array_equal(sampled_random, sampled_0)
224-
np.testing.assert_array_equal(sampled_0, sampled_1)
228+
assert not np.array_equal(sampled_random, sampled_0_0)
229+
assert not np.array_equal(sampled_random, sampled_0_1)
230+
np.testing.assert_array_equal(sampled_0_0, sampled_1_0)
231+
np.testing.assert_array_equal(sampled_0_1, sampled_1_1)
225232

226233

227234
# Below are CTGAN tests that should be implemented in the future

Diff for: tests/integration/synthesizer/test_tvae.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,17 @@ def test_fixed_random_seed():
115115

116116
# Run
117117
sampled_random = tvae.sample(10)
118-
tvae.set_random_seed(0)
119-
sampled_0 = tvae.sample(10)
120-
sampled_1 = tvae.sample(10)
118+
119+
tvae.set_random_state(0)
120+
sampled_0_0 = tvae.sample(10)
121+
sampled_0_1 = tvae.sample(10)
122+
123+
tvae.set_random_state(0)
124+
sampled_1_0 = tvae.sample(10)
125+
sampled_1_1 = tvae.sample(10)
121126

122127
# Assert
123-
assert not np.array_equal(sampled_random, sampled_0)
124-
np.testing.assert_array_equal(sampled_0, sampled_1)
128+
assert not np.array_equal(sampled_random, sampled_0_0)
129+
assert not np.array_equal(sampled_random, sampled_0_1)
130+
np.testing.assert_array_equal(sampled_0_0, sampled_1_0)
131+
np.testing.assert_array_equal(sampled_0_1, sampled_1_1)

Diff for: tests/unit/synthesizer/test_base.py

+31-16
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,37 @@
11

22
"""BaseSynthesizer unit testing module."""
33

4-
from unittest.mock import MagicMock, patch
4+
from unittest.mock import MagicMock, call, patch
5+
6+
import numpy as np
7+
import torch
58

69
from ctgan.synthesizers.base import BaseSynthesizer, random_state
710

811

912
@patch('ctgan.synthesizers.base.torch')
1013
@patch('ctgan.synthesizers.base.np.random')
11-
def test_valid_random_seed(random_mock, torch_mock):
12-
"""Test the ``random_seed`` attribute with a valid random seed.
14+
def test_valid_random_state(random_mock, torch_mock):
15+
"""Test the ``random_state`` attribute with a valid random state.
1316
1417
Expect that the decorated function uses the ``random_states`` attribute.
1518
"""
1619
# Setup
1720
my_function = MagicMock()
1821
instance = MagicMock()
19-
instance._random_seed = 42
22+
23+
random_state_mock = MagicMock()
24+
random_state_mock.get_state.return_value = 'desired numpy state'
25+
torch_generator_mock = MagicMock()
26+
torch_generator_mock.get_state.return_value = 'desired torch state'
27+
instance.random_states = (random_state_mock, torch_generator_mock)
2028

2129
args = {'some', 'args'}
2230
kwargs = {'keyword': 'value'}
2331

32+
random_mock.RandomState.return_value = random_state_mock
2433
random_mock.get_state.return_value = 'random state'
34+
torch_mock.Generator.return_value = torch_generator_mock
2535
torch_mock.get_rng_state.return_value = 'torch random state'
2636

2737
# Run
@@ -32,25 +42,27 @@ def test_valid_random_seed(random_mock, torch_mock):
3242
my_function.assert_called_once_with(instance, *args, **kwargs)
3343

3444
instance.assert_not_called
35-
random_mock.get_state.assert_called_once_with()
36-
torch_mock.get_rng_state.assert_called_once_with()
37-
random_mock.seed.assert_called_once_with(42)
38-
random_mock.set_state.assert_called_once_with('random state')
39-
torch_mock.set_rng_state.assert_called_once_with('torch random state')
45+
assert random_mock.get_state.call_count == 2
46+
assert torch_mock.get_rng_state.call_count == 2
47+
random_mock.RandomState.assert_has_calls(
48+
[call().get_state(), call(), call().set_state('random state')])
49+
random_mock.set_state.assert_has_calls([call('desired numpy state'), call('random state')])
50+
torch_mock.set_rng_state.assert_has_calls(
51+
[call('desired torch state'), call('torch random state')])
4052

4153

4254
@patch('ctgan.synthesizers.base.torch')
4355
@patch('ctgan.synthesizers.base.np.random')
4456
def test_no_random_seed(random_mock, torch_mock):
45-
"""Test the ``random_seed`` attribute with no random seed.
57+
"""Test the ``random_state`` attribute with no random state.
4658
4759
Expect that the decorated function calls the original function
4860
when there is no random state.
4961
"""
5062
# Setup
5163
my_function = MagicMock()
5264
instance = MagicMock()
53-
instance._random_seed = None
65+
instance.random_states = None
5466

5567
args = {'some', 'args'}
5668
kwargs = {'keyword': 'value'}
@@ -64,21 +76,24 @@ def test_no_random_seed(random_mock, torch_mock):
6476

6577
instance.assert_not_called
6678
random_mock.get_state.assert_not_called()
67-
random_mock.seed.assert_not_called()
79+
random_mock.RandomState.assert_not_called()
6880
random_mock.set_state.assert_not_called()
6981
torch_mock.get_rng_state.assert_not_called()
82+
torch_mock.Generator.assert_not_called()
7083
torch_mock.set_rng_state.assert_not_called()
7184

7285

7386
class TestBaseSynthesizer:
    """Unit tests for ``BaseSynthesizer.set_random_state``."""

    def test_set_random_state(self):
        """Test ``set_random_state`` works as expected."""
        # Setup
        instance = BaseSynthesizer()

        # Run
        instance.set_random_state(3)

        # Assert
        assert isinstance(instance.random_states, tuple)
        assert isinstance(instance.random_states[0], np.random.RandomState)
        assert isinstance(instance.random_states[1], torch.Generator)
        # The numpy generator must actually be seeded with the given value.
        expected = np.random.RandomState(seed=3).randint(1000)
        assert instance.random_states[0].randint(1000) == expected

    def test_set_random_state_with_tuple(self):
        """Test ``set_random_state`` stores a valid tuple unchanged."""
        # Setup
        instance = BaseSynthesizer()
        states = (np.random.RandomState(seed=3), torch.Generator().manual_seed(3))

        # Run
        instance.set_random_state(states)

        # Assert
        assert instance.random_states is states

    def test_set_random_state_invalid(self):
        """Test ``set_random_state`` rejects unsupported input with ``ValueError``."""
        # Setup
        instance = BaseSynthesizer()

        # Run
        try:
            instance.set_random_state('not a random state')
        except ValueError:
            raised = True
        else:
            raised = False

        # Assert
        assert raised
        assert instance.random_states is None

0 commit comments

Comments
 (0)