2
2
3
3
import pytest
4
4
5
- import arrayfire_wrapper .dtypes as dtypes
6
5
import arrayfire_wrapper .lib as wrapper
6
+ from arrayfire_wrapper .dtypes import (
7
+ Dtype ,
8
+ c32 ,
9
+ c64 ,
10
+ c_api_value_to_dtype ,
11
+ f16 ,
12
+ f32 ,
13
+ f64 ,
14
+ s16 ,
15
+ s32 ,
16
+ s64 ,
17
+ u8 ,
18
+ u16 ,
19
+ u32 ,
20
+ u64 ,
21
+ )
7
22
8
23
invalid_shape = (
9
24
random .randint (1 , 10 ),
14
29
)
15
30
16
31
32
+ all_types = [s16 , s32 , s64 , u8 , u16 , u32 , u64 , f16 , f32 , f64 , c32 , c64 ]
33
+
34
+
17
35
@pytest .mark .parametrize (
18
36
"shape" ,
19
37
[
27
45
def test_constant_shape (shape : tuple ) -> None :
28
46
"""Test if constant creates an array with the correct shape."""
29
47
number = 5.0
30
- dtype = dtypes . s16
48
+ dtype = s16
31
49
32
50
result = wrapper .constant (number , shape , dtype )
33
51
@@ -46,9 +64,8 @@ def test_constant_shape(shape: tuple) -> None:
46
64
)
47
65
def test_constant_complex_shape (shape : tuple ) -> None :
48
66
"""Test if constant_complex creates an array with the correct shape."""
49
- dtype = dtypes . c32
67
+ dtype = c32
50
68
51
- dtype = dtypes .c32
52
69
rand_array = wrapper .randu ((1 , 1 ), dtype )
53
70
number = wrapper .get_scalar (rand_array , dtype )
54
71
@@ -71,7 +88,7 @@ def test_constant_complex_shape(shape: tuple) -> None:
71
88
)
72
89
def test_constant_long_shape (shape : tuple ) -> None :
73
90
"""Test if constant_long creates an array with the correct shape."""
74
- dtype = dtypes . s64
91
+ dtype = s64
75
92
rand_array = wrapper .randu ((1 , 1 ), dtype )
76
93
number = wrapper .get_scalar (rand_array , dtype )
77
94
@@ -93,7 +110,7 @@ def test_constant_long_shape(shape: tuple) -> None:
93
110
)
94
111
def test_constant_ulong_shape (shape : tuple ) -> None :
95
112
"""Test if constant_ulong creates an array with the correct shape."""
96
- dtype = dtypes . u64
113
+ dtype = u64
97
114
rand_array = wrapper .randu ((1 , 1 ), dtype )
98
115
number = wrapper .get_scalar (rand_array , dtype )
99
116
@@ -109,15 +126,15 @@ def test_constant_shape_invalid() -> None:
109
126
"""Test if constant handles a shape with greater than 4 dimensions"""
110
127
with pytest .raises (TypeError ):
111
128
number = 5.0
112
- dtype = dtypes . s16
129
+ dtype = s16
113
130
114
131
wrapper .constant (number , invalid_shape , dtype )
115
132
116
133
117
134
def test_constant_complex_shape_invalid () -> None :
118
135
"""Test if constant_complex handles a shape with greater than 4 dimensions"""
119
136
with pytest .raises (TypeError ):
120
- dtype = dtypes . c32
137
+ dtype = c32
121
138
rand_array = wrapper .randu ((1 , 1 ), dtype )
122
139
number = wrapper .get_scalar (rand_array , dtype )
123
140
@@ -128,7 +145,7 @@ def test_constant_complex_shape_invalid() -> None:
128
145
def test_constant_long_shape_invalid () -> None :
129
146
"""Test if constant_long handles a shape with greater than 4 dimensions"""
130
147
with pytest .raises (TypeError ):
131
- dtype = dtypes . s64
148
+ dtype = s64
132
149
rand_array = wrapper .randu ((1 , 1 ), dtype )
133
150
number = wrapper .get_scalar (rand_array , dtype )
134
151
@@ -139,7 +156,7 @@ def test_constant_long_shape_invalid() -> None:
139
156
def test_constant_ulong_shape_invalid () -> None :
140
157
"""Test if constant_ulong handles a shape with greater than 4 dimensions"""
141
158
with pytest .raises (TypeError ):
142
- dtype = dtypes . u64
159
+ dtype = u64
143
160
rand_array = wrapper .randu ((1 , 1 ), dtype )
144
161
number = wrapper .get_scalar (rand_array , dtype )
145
162
@@ -148,50 +165,47 @@ def test_constant_ulong_shape_invalid() -> None:
148
165
149
166
150
167
@pytest .mark .parametrize (
151
- "dtype_index " ,
152
- [ i for i in range ( 13 )] ,
168
+ "dtype " ,
169
+ all_types ,
153
170
)
154
- def test_constant_dtype (dtype_index : int ) -> None :
171
+ def test_constant_dtype (dtype : Dtype ) -> None :
155
172
"""Test if constant creates an array with the correct dtype."""
156
- if dtype_index in [ 1 , 3 ] or ( dtype_index == 2 and not wrapper . get_dbl_support () ):
173
+ if is_cmplx_type ( dtype ) or not is_system_supported ( dtype ):
157
174
pytest .skip ()
158
175
159
- dtype = dtypes .c_api_value_to_dtype (dtype_index )
160
-
161
176
rand_array = wrapper .randu ((1 , 1 ), dtype )
162
177
value = wrapper .get_scalar (rand_array , dtype )
163
178
shape = (2 , 2 )
164
179
if isinstance (value , (int , float )):
165
180
result = wrapper .constant (value , shape , dtype )
166
- assert dtypes . c_api_value_to_dtype (wrapper .get_type (result )) == dtype
181
+ assert c_api_value_to_dtype (wrapper .get_type (result )) == dtype
167
182
else :
168
183
pytest .skip ()
169
184
170
185
171
186
@pytest .mark .parametrize (
172
- "dtype_index " ,
173
- [ i for i in range ( 13 )] ,
187
+ "dtype " ,
188
+ all_types ,
174
189
)
175
- def test_constant_complex_dtype (dtype_index : int ) -> None :
190
+ def test_constant_complex_dtype (dtype : Dtype ) -> None :
176
191
"""Test if constant_complex creates an array with the correct dtype."""
177
- if dtype_index not in [ 1 , 3 ] or ( dtype_index == 3 and not wrapper . get_dbl_support () ):
192
+ if not is_cmplx_type ( dtype ) or not is_system_supported ( dtype ):
178
193
pytest .skip ()
179
194
180
- dtype = dtypes .c_api_value_to_dtype (dtype_index )
181
195
rand_array = wrapper .randu ((1 , 1 ), dtype )
182
196
value = wrapper .get_scalar (rand_array , dtype )
183
197
shape = (2 , 2 )
184
198
185
199
if isinstance (value , (int , float , complex )):
186
200
result = wrapper .constant_complex (value , shape , dtype )
187
- assert dtypes . c_api_value_to_dtype (wrapper .get_type (result )) == dtype
201
+ assert c_api_value_to_dtype (wrapper .get_type (result )) == dtype
188
202
else :
189
203
pytest .skip ()
190
204
191
205
192
206
def test_constant_long_dtype () -> None :
193
207
"""Test if constant_long creates an array with the correct dtype."""
194
- dtype = dtypes . s64
208
+ dtype = s64
195
209
196
210
rand_array = wrapper .randu ((1 , 1 ), dtype )
197
211
value = wrapper .get_scalar (rand_array , dtype )
@@ -200,14 +214,14 @@ def test_constant_long_dtype() -> None:
200
214
if isinstance (value , (int , float )):
201
215
result = wrapper .constant_long (value , shape , dtype )
202
216
203
- assert dtypes . c_api_value_to_dtype (wrapper .get_type (result )) == dtype
217
+ assert c_api_value_to_dtype (wrapper .get_type (result )) == dtype
204
218
else :
205
219
pytest .skip ()
206
220
207
221
208
222
def test_constant_ulong_dtype () -> None :
209
223
"""Test if constant_ulong creates an array with the correct dtype."""
210
- dtype = dtypes . u64
224
+ dtype = u64
211
225
212
226
rand_array = wrapper .randu ((1 , 1 ), dtype )
213
227
value = wrapper .get_scalar (rand_array , dtype )
@@ -216,6 +230,17 @@ def test_constant_ulong_dtype() -> None:
216
230
if isinstance (value , (int , float )):
217
231
result = wrapper .constant_ulong (value , shape , dtype )
218
232
219
- assert dtypes . c_api_value_to_dtype (wrapper .get_type (result )) == dtype
233
+ assert c_api_value_to_dtype (wrapper .get_type (result )) == dtype
220
234
else :
221
235
pytest .skip ()
236
+
237
+
238
+ def is_cmplx_type (dtype : Dtype ) -> bool :
239
+ return dtype == c32 or dtype == c64
240
+
241
+
242
+ def is_system_supported (dtype : Dtype ) -> bool :
243
+ if dtype in [f64 , c64 ] and not wrapper .get_dbl_support ():
244
+ return False
245
+
246
+ return True
0 commit comments