Skip to content

Commit 0ea0562

Browse files
committed
Parametrize test_from_dlpack with device-specific dtypes
1 parent 6c10ea7 commit 0ea0562

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

dpnp/tests/test_sycl_queue.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,21 @@ def assert_sycl_queue_equal(result, expected):
5454
assert exec_queue is not None
5555

5656

57+
def get_all_dev_dtypes(no_float16=True, no_none=True):
58+
"""
59+
Build a list of (device, dtype) combinations for each device's
60+
supported dtype.
61+
62+
"""
63+
64+
device_dtype_pairs = []
65+
for device in valid_dev:
66+
dtypes = get_all_dtypes(no_float16=True, no_none=True, device=device)
67+
for dtype in dtypes:
68+
device_dtype_pairs.append((device, dtype))
69+
return device_dtype_pairs
70+
71+
5772
@pytest.mark.parametrize(
5873
"func, arg, kwargs",
5974
[
@@ -1082,11 +1097,10 @@ def test_array_creation_from_dpctl(copy, device):
10821097
assert isinstance(result, dpnp_array)
10831098

10841099

1085-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
1086-
@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True))
1100+
@pytest.mark.parametrize("device, dt", get_all_dev_dtypes())
10871101
@pytest.mark.parametrize("shape", [tuple(), (2,), (3, 0, 1), (2, 2, 2)])
1088-
def test_from_dlpack(arr_dtype, shape, device):
1089-
X = dpnp.ones(shape=shape, dtype=arr_dtype, device=device)
1102+
def test_from_dlpack(shape, device, dt):
1103+
X = dpnp.ones(shape=shape, dtype=dt, device=device)
10901104
Y = dpnp.from_dlpack(X)
10911105
assert_array_equal(X, Y)
10921106
assert X.__dlpack_device__() == Y.__dlpack_device__()
@@ -1098,10 +1112,9 @@ def test_from_dlpack(arr_dtype, shape, device):
10981112
assert V.strides == W.strides
10991113

11001114

1101-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
1102-
@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True))
1103-
def test_from_dlpack_with_dpt(arr_dtype, device):
1104-
X = dpt.ones((64,), dtype=arr_dtype, device=device)
1115+
@pytest.mark.parametrize("device, dt", get_all_dev_dtypes())
1116+
def test_from_dlpack_with_dpt(device, dt):
1117+
X = dpt.ones((64,), dtype=dt, device=device)
11051118
Y = dpnp.from_dlpack(X)
11061119
assert_array_equal(X, Y)
11071120
assert isinstance(Y, dpnp.dpnp_array.dpnp_array)

0 commit comments

Comments
 (0)