1
1
2
2
"""BaseSynthesizer unit testing module."""
3
3
4
- from unittest .mock import MagicMock , patch
4
+ from unittest .mock import MagicMock , call , patch
5
+
6
+ import numpy as np
7
+ import torch
5
8
6
9
from ctgan .synthesizers .base import BaseSynthesizer , random_state
7
10
8
11
9
12
@patch ('ctgan.synthesizers.base.torch' )
10
13
@patch ('ctgan.synthesizers.base.np.random' )
11
- def test_valid_random_seed (random_mock , torch_mock ):
12
- """Test the ``random_seed `` attribute with a valid random seed .
14
+ def test_valid_random_state (random_mock , torch_mock ):
15
+ """Test the ``random_state `` attribute with a valid random state .
13
16
14
17
Expect that the decorated function uses the random_state attribute.
15
18
"""
16
19
# Setup
17
20
my_function = MagicMock ()
18
21
instance = MagicMock ()
19
- instance ._random_seed = 42
22
+
23
+ random_state_mock = MagicMock ()
24
+ random_state_mock .get_state .return_value = 'desired numpy state'
25
+ torch_generator_mock = MagicMock ()
26
+ torch_generator_mock .get_state .return_value = 'desired torch state'
27
+ instance .random_states = (random_state_mock , torch_generator_mock )
20
28
21
29
args = {'some' , 'args' }
22
30
kwargs = {'keyword' : 'value' }
23
31
32
+ random_mock .RandomState .return_value = random_state_mock
24
33
random_mock .get_state .return_value = 'random state'
34
+ torch_mock .Generator .return_value = torch_generator_mock
25
35
torch_mock .get_rng_state .return_value = 'torch random state'
26
36
27
37
# Run
@@ -32,25 +42,27 @@ def test_valid_random_seed(random_mock, torch_mock):
32
42
my_function .assert_called_once_with (instance , * args , ** kwargs )
33
43
34
44
instance .assert_not_called
35
- random_mock .get_state .assert_called_once_with ()
36
- torch_mock .get_rng_state .assert_called_once_with ()
37
- random_mock .seed .assert_called_once_with (42 )
38
- random_mock .set_state .assert_called_once_with ('random state' )
39
- torch_mock .set_rng_state .assert_called_once_with ('torch random state' )
45
+ assert random_mock .get_state .call_count == 2
46
+ assert torch_mock .get_rng_state .call_count == 2
47
+ random_mock .RandomState .assert_has_calls (
48
+ [call ().get_state (), call (), call ().set_state ('random state' )])
49
+ random_mock .set_state .assert_has_calls ([call ('desired numpy state' ), call ('random state' )])
50
+ torch_mock .set_rng_state .assert_has_calls (
51
+ [call ('desired torch state' ), call ('torch random state' )])
40
52
41
53
42
54
@patch ('ctgan.synthesizers.base.torch' )
43
55
@patch ('ctgan.synthesizers.base.np.random' )
44
56
def test_no_random_seed (random_mock , torch_mock ):
45
- """Test the ``random_seed `` attribute with no random seed .
57
+ """Test the ``random_state `` attribute with no random state .
46
58
47
59
Expect that the decorated function calls the original function
48
60
when there is no random state.
49
61
"""
50
62
# Setup
51
63
my_function = MagicMock ()
52
64
instance = MagicMock ()
53
- instance ._random_seed = None
65
+ instance .random_states = None
54
66
55
67
args = {'some' , 'args' }
56
68
kwargs = {'keyword' : 'value' }
@@ -64,21 +76,24 @@ def test_no_random_seed(random_mock, torch_mock):
64
76
65
77
instance .assert_not_called
66
78
random_mock .get_state .assert_not_called ()
67
- random_mock .seed .assert_not_called ()
79
+ random_mock .RandomState .assert_not_called ()
68
80
random_mock .set_state .assert_not_called ()
69
81
torch_mock .get_rng_state .assert_not_called ()
82
+ torch_mock .Generator .assert_not_called ()
70
83
torch_mock .set_rng_state .assert_not_called ()
71
84
72
85
73
86
class TestBaseSynthesizer :
74
87
75
- def test_set_random_seed (self ):
76
- """Test ``set_random_seed `` works as expected."""
88
+ def test_set_random_state (self ):
89
+ """Test ``set_random_state `` works as expected."""
77
90
# Setup
78
91
instance = BaseSynthesizer ()
79
92
80
93
# Run
81
- instance .set_random_seed (3 )
94
+ instance .set_random_state (3 )
82
95
83
96
# Assert
84
- assert instance ._random_seed == 3
97
+ assert isinstance (instance .random_states , tuple )
98
+ assert isinstance (instance .random_states [0 ], np .random .RandomState )
99
+ assert isinstance (instance .random_states [1 ], torch .Generator )
0 commit comments