@@ -54,6 +54,21 @@ def assert_sycl_queue_equal(result, expected):
5454 assert exec_queue is not None
5555
5656
57+ def get_all_dev_dtypes (no_float16 = True , no_none = True ):
58+ """
59+ Build a list of (device, dtype) combinations for each device's
60+ supported dtype.
61+
62+ """
63+
64+ device_dtype_pairs = []
65+ for device in valid_dev :
66+ dtypes = get_all_dtypes (no_float16 = True , no_none = True , device = device )
67+ for dtype in dtypes :
68+ device_dtype_pairs .append ((device , dtype ))
69+ return device_dtype_pairs
70+
71+
5772@pytest .mark .parametrize (
5873 "func, arg, kwargs" ,
5974 [
@@ -1082,11 +1097,10 @@ def test_array_creation_from_dpctl(copy, device):
10821097 assert isinstance (result , dpnp_array )
10831098
10841099
1085- @pytest .mark .parametrize ("device" , valid_dev , ids = dev_ids )
1086- @pytest .mark .parametrize ("arr_dtype" , get_all_dtypes (no_float16 = True ))
1100+ @pytest .mark .parametrize ("device, dt" , get_all_dev_dtypes ())
10871101@pytest .mark .parametrize ("shape" , [tuple (), (2 ,), (3 , 0 , 1 ), (2 , 2 , 2 )])
1088- def test_from_dlpack (arr_dtype , shape , device ):
1089- X = dpnp .ones (shape = shape , dtype = arr_dtype , device = device )
1102+ def test_from_dlpack (shape , device , dt ):
1103+ X = dpnp .ones (shape = shape , dtype = dt , device = device )
10901104 Y = dpnp .from_dlpack (X )
10911105 assert_array_equal (X , Y )
10921106 assert X .__dlpack_device__ () == Y .__dlpack_device__ ()
@@ -1098,10 +1112,9 @@ def test_from_dlpack(arr_dtype, shape, device):
10981112 assert V .strides == W .strides
10991113
11001114
1101- @pytest .mark .parametrize ("device" , valid_dev , ids = dev_ids )
1102- @pytest .mark .parametrize ("arr_dtype" , get_all_dtypes (no_float16 = True ))
1103- def test_from_dlpack_with_dpt (arr_dtype , device ):
1104- X = dpt .ones ((64 ,), dtype = arr_dtype , device = device )
1115+ @pytest .mark .parametrize ("device, dt" , get_all_dev_dtypes ())
1116+ def test_from_dlpack_with_dpt (device , dt ):
1117+ X = dpt .ones ((64 ,), dtype = dt , device = device )
11051118 Y = dpnp .from_dlpack (X )
11061119 assert_array_equal (X , Y )
11071120 assert isinstance (Y , dpnp .dpnp_array .dpnp_array )
0 commit comments