Skip to content

Commit 09d4613

Browse files
sakchalChaluvadi
and
Chaluvadi
authored
Rebased master branch (#47)
* utility functions for tests * added unit tests for range function * Readability changes to cosntants tests * readability changes pt.2 --------- Co-authored-by: Chaluvadi <saketh.chaluvadi@intel.com>
1 parent dc6ba81 commit 09d4613

File tree

2 files changed

+113
-27
lines changed

2 files changed

+113
-27
lines changed

tests/test_constants.py

+52-27
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,23 @@
22

33
import pytest
44

5-
import arrayfire_wrapper.dtypes as dtypes
65
import arrayfire_wrapper.lib as wrapper
6+
from arrayfire_wrapper.dtypes import (
7+
Dtype,
8+
c32,
9+
c64,
10+
c_api_value_to_dtype,
11+
f16,
12+
f32,
13+
f64,
14+
s16,
15+
s32,
16+
s64,
17+
u8,
18+
u16,
19+
u32,
20+
u64,
21+
)
722

823
invalid_shape = (
924
random.randint(1, 10),
@@ -14,6 +29,9 @@
1429
)
1530

1631

32+
all_types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]
33+
34+
1735
@pytest.mark.parametrize(
1836
"shape",
1937
[
@@ -27,7 +45,7 @@
2745
def test_constant_shape(shape: tuple) -> None:
2846
"""Test if constant creates an array with the correct shape."""
2947
number = 5.0
30-
dtype = dtypes.s16
48+
dtype = s16
3149

3250
result = wrapper.constant(number, shape, dtype)
3351

@@ -46,9 +64,8 @@ def test_constant_shape(shape: tuple) -> None:
4664
)
4765
def test_constant_complex_shape(shape: tuple) -> None:
4866
"""Test if constant_complex creates an array with the correct shape."""
49-
dtype = dtypes.c32
67+
dtype = c32
5068

51-
dtype = dtypes.c32
5269
rand_array = wrapper.randu((1, 1), dtype)
5370
number = wrapper.get_scalar(rand_array, dtype)
5471

@@ -71,7 +88,7 @@ def test_constant_complex_shape(shape: tuple) -> None:
7188
)
7289
def test_constant_long_shape(shape: tuple) -> None:
7390
"""Test if constant_long creates an array with the correct shape."""
74-
dtype = dtypes.s64
91+
dtype = s64
7592
rand_array = wrapper.randu((1, 1), dtype)
7693
number = wrapper.get_scalar(rand_array, dtype)
7794

@@ -93,7 +110,7 @@ def test_constant_long_shape(shape: tuple) -> None:
93110
)
94111
def test_constant_ulong_shape(shape: tuple) -> None:
95112
"""Test if constant_ulong creates an array with the correct shape."""
96-
dtype = dtypes.u64
113+
dtype = u64
97114
rand_array = wrapper.randu((1, 1), dtype)
98115
number = wrapper.get_scalar(rand_array, dtype)
99116

@@ -109,15 +126,15 @@ def test_constant_shape_invalid() -> None:
109126
"""Test if constant handles a shape with greater than 4 dimensions"""
110127
with pytest.raises(TypeError):
111128
number = 5.0
112-
dtype = dtypes.s16
129+
dtype = s16
113130

114131
wrapper.constant(number, invalid_shape, dtype)
115132

116133

117134
def test_constant_complex_shape_invalid() -> None:
118135
"""Test if constant_complex handles a shape with greater than 4 dimensions"""
119136
with pytest.raises(TypeError):
120-
dtype = dtypes.c32
137+
dtype = c32
121138
rand_array = wrapper.randu((1, 1), dtype)
122139
number = wrapper.get_scalar(rand_array, dtype)
123140

@@ -128,7 +145,7 @@ def test_constant_complex_shape_invalid() -> None:
128145
def test_constant_long_shape_invalid() -> None:
129146
"""Test if constant_long handles a shape with greater than 4 dimensions"""
130147
with pytest.raises(TypeError):
131-
dtype = dtypes.s64
148+
dtype = s64
132149
rand_array = wrapper.randu((1, 1), dtype)
133150
number = wrapper.get_scalar(rand_array, dtype)
134151

@@ -139,7 +156,7 @@ def test_constant_long_shape_invalid() -> None:
139156
def test_constant_ulong_shape_invalid() -> None:
140157
"""Test if constant_ulong handles a shape with greater than 4 dimensions"""
141158
with pytest.raises(TypeError):
142-
dtype = dtypes.u64
159+
dtype = u64
143160
rand_array = wrapper.randu((1, 1), dtype)
144161
number = wrapper.get_scalar(rand_array, dtype)
145162

@@ -148,50 +165,47 @@ def test_constant_ulong_shape_invalid() -> None:
148165

149166

150167
@pytest.mark.parametrize(
151-
"dtype_index",
152-
[i for i in range(13)],
168+
"dtype",
169+
all_types,
153170
)
154-
def test_constant_dtype(dtype_index: int) -> None:
171+
def test_constant_dtype(dtype: Dtype) -> None:
155172
"""Test if constant creates an array with the correct dtype."""
156-
if dtype_index in [1, 3] or (dtype_index == 2 and not wrapper.get_dbl_support()):
173+
if is_cmplx_type(dtype) or not is_system_supported(dtype):
157174
pytest.skip()
158175

159-
dtype = dtypes.c_api_value_to_dtype(dtype_index)
160-
161176
rand_array = wrapper.randu((1, 1), dtype)
162177
value = wrapper.get_scalar(rand_array, dtype)
163178
shape = (2, 2)
164179
if isinstance(value, (int, float)):
165180
result = wrapper.constant(value, shape, dtype)
166-
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
181+
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
167182
else:
168183
pytest.skip()
169184

170185

171186
@pytest.mark.parametrize(
172-
"dtype_index",
173-
[i for i in range(13)],
187+
"dtype",
188+
all_types,
174189
)
175-
def test_constant_complex_dtype(dtype_index: int) -> None:
190+
def test_constant_complex_dtype(dtype: Dtype) -> None:
176191
"""Test if constant_complex creates an array with the correct dtype."""
177-
if dtype_index not in [1, 3] or (dtype_index == 3 and not wrapper.get_dbl_support()):
192+
if not is_cmplx_type(dtype) or not is_system_supported(dtype):
178193
pytest.skip()
179194

180-
dtype = dtypes.c_api_value_to_dtype(dtype_index)
181195
rand_array = wrapper.randu((1, 1), dtype)
182196
value = wrapper.get_scalar(rand_array, dtype)
183197
shape = (2, 2)
184198

185199
if isinstance(value, (int, float, complex)):
186200
result = wrapper.constant_complex(value, shape, dtype)
187-
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
201+
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
188202
else:
189203
pytest.skip()
190204

191205

192206
def test_constant_long_dtype() -> None:
193207
"""Test if constant_long creates an array with the correct dtype."""
194-
dtype = dtypes.s64
208+
dtype = s64
195209

196210
rand_array = wrapper.randu((1, 1), dtype)
197211
value = wrapper.get_scalar(rand_array, dtype)
@@ -200,14 +214,14 @@ def test_constant_long_dtype() -> None:
200214
if isinstance(value, (int, float)):
201215
result = wrapper.constant_long(value, shape, dtype)
202216

203-
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
217+
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
204218
else:
205219
pytest.skip()
206220

207221

208222
def test_constant_ulong_dtype() -> None:
209223
"""Test if constant_ulong creates an array with the correct dtype."""
210-
dtype = dtypes.u64
224+
dtype = u64
211225

212226
rand_array = wrapper.randu((1, 1), dtype)
213227
value = wrapper.get_scalar(rand_array, dtype)
@@ -216,6 +230,17 @@ def test_constant_ulong_dtype() -> None:
216230
if isinstance(value, (int, float)):
217231
result = wrapper.constant_ulong(value, shape, dtype)
218232

219-
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
233+
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
220234
else:
221235
pytest.skip()
236+
237+
238+
def is_cmplx_type(dtype: Dtype) -> bool:
239+
return dtype == c32 or dtype == c64
240+
241+
242+
def is_system_supported(dtype: Dtype) -> bool:
243+
if dtype in [f64, c64] and not wrapper.get_dbl_support():
244+
return False
245+
246+
return True

tests/test_range.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import random
2+
3+
import pytest
4+
5+
import arrayfire_wrapper.dtypes as dtypes
6+
import arrayfire_wrapper.lib as wrapper
7+
8+
9+
@pytest.mark.parametrize(
10+
"shape",
11+
[
12+
(),
13+
(random.randint(1, 10), 1),
14+
(random.randint(1, 10), random.randint(1, 10)),
15+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
16+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
17+
],
18+
)
19+
def test_range_shape(shape: tuple) -> None:
20+
"""Test if the range function output an AFArray with the correct shape"""
21+
dim = 2
22+
dtype = dtypes.s16
23+
24+
result = wrapper.range(shape, dim, dtype)
25+
26+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203
27+
28+
29+
def test_range_invalid_shape() -> None:
30+
"""Test if range function correctly handles an invalid shape"""
31+
with pytest.raises(TypeError):
32+
shape = (
33+
random.randint(1, 10),
34+
random.randint(1, 10),
35+
random.randint(1, 10),
36+
random.randint(1, 10),
37+
random.randint(1, 10),
38+
)
39+
dim = 2
40+
dtype = dtypes.s16
41+
42+
wrapper.range(shape, dim, dtype)
43+
44+
45+
@pytest.mark.parametrize(
46+
"shape",
47+
[
48+
(),
49+
(random.randint(1, 10), 1),
50+
(random.randint(1, 10), random.randint(1, 10)),
51+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
52+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
53+
],
54+
)
55+
def test_range_invalid_dim(shape: tuple) -> None:
56+
"""Test if the range function can properly handle and invalid dimension given"""
57+
with pytest.raises(RuntimeError):
58+
dim = random.randint(4, 10)
59+
dtype = dtypes.s16
60+
61+
wrapper.range(shape, dim, dtype)

0 commit comments

Comments
 (0)