diff --git a/CHANGELOG.md b/CHANGELOG.md index 30d60fc9898..f8aaae542ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -83,6 +83,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum * Fixed `.data.ptr` property on array views to correctly return the pointer to the view's data location instead of the base allocation pointer [#2812](https://github.com/IntelPython/dpnp/pull/2812) * Resolved an issue with strides calculation in `dpnp.diagonal` to return correct values for empty diagonals [#2814](https://github.com/IntelPython/dpnp/pull/2814) * Fixed test tolerance issues for float16 intermediate precision that became visible when testing against conda-forge's NumPy [#2828](https://github.com/IntelPython/dpnp/pull/2828) +* Ensured device aware dtype handling in `dpnp.identity` and `dpnp.gradient` [#2835](https://github.com/IntelPython/dpnp/pull/2835) ### Security diff --git a/dpnp/dpnp_iface_arraycreation.py b/dpnp/dpnp_iface_arraycreation.py index e7b90264718..5bcf5ea19b8 100644 --- a/dpnp/dpnp_iface_arraycreation.py +++ b/dpnp/dpnp_iface_arraycreation.py @@ -2664,10 +2664,9 @@ def identity( dpnp.check_limitations(like=like) - _dtype = dpnp.default_float_type() if dtype is None else dtype return dpnp.eye( n, - dtype=_dtype, + dtype=dtype, device=device, usm_type=usm_type, sycl_queue=sycl_queue, diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 366a3363404..e06904a57bd 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -141,7 +141,9 @@ def _gradient_build_dx(f, axes, *varargs): if dpnp.issubdtype(distances.dtype, dpnp.integer): # Convert integer types to default float type to avoid modular # arithmetic in dpnp.diff(distances). - distances = distances.astype(dpnp.default_float_type()) + distances = distances.astype( + dpnp.default_float_type(sycl_queue=f.sycl_queue) + ) diffx = dpnp.diff(distances) # if distances are constant reduce to the scalar case @@ -2707,9 +2709,9 @@ def gradient(f, *varargs, axis=None, edge_order=1): # All other types convert to floating point. # First check if f is a dpnp integer type; if so, convert f to default # float type to avoid modular arithmetic when computing changes in f. - if dpnp.issubdtype(otype, dpnp.integer): - f = f.astype(dpnp.default_float_type()) - otype = dpnp.default_float_type() + otype = dpnp.default_float_type(sycl_queue=f.sycl_queue) + if dpnp.issubdtype(f.dtype, dpnp.integer): + f = f.astype(otype) for axis_, ax_dx in zip(axes, dx): if f.shape[axis_] < edge_order + 1: diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index b0f746720af..e4b9403df8a 100644 --- a/dpnp/tests/test_sycl_queue.py +++ b/dpnp/tests/test_sycl_queue.py @@ -54,6 +54,23 @@ def assert_sycl_queue_equal(result, expected): assert exec_queue is not None +def get_all_dev_dtypes(no_float16=True, no_none=True): + """ + Build a list of (device, dtype) combinations for each device's + supported dtype. + + """ + + device_dtype_pairs = [] + for device in valid_dev: + dtypes = get_all_dtypes( + no_float16=no_float16, no_none=no_none, device=device + ) + for dtype in dtypes: + device_dtype_pairs.append((device, dtype)) + return device_dtype_pairs + + @pytest.mark.parametrize( "func, arg, kwargs", [ @@ -1082,11 +1099,10 @@ def test_array_creation_from_dpctl(copy, device): assert isinstance(result, dpnp_array) -@pytest.mark.parametrize("device", valid_dev, ids=dev_ids) -@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True)) +@pytest.mark.parametrize("device, dt", get_all_dev_dtypes()) @pytest.mark.parametrize("shape", [tuple(), (2,), (3, 0, 1), (2, 2, 2)]) -def test_from_dlpack(arr_dtype, shape, device): - X = dpnp.ones(shape=shape, dtype=arr_dtype, device=device) +def test_from_dlpack(shape, device, dt): + X = dpnp.ones(shape=shape, dtype=dt, device=device) Y = dpnp.from_dlpack(X) assert_array_equal(X, Y) assert X.__dlpack_device__() == Y.__dlpack_device__() @@ -1098,10 +1114,9 @@ def test_from_dlpack(arr_dtype, shape, device): assert V.strides == W.strides -@pytest.mark.parametrize("device", valid_dev, ids=dev_ids) -@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True)) -def test_from_dlpack_with_dpt(arr_dtype, device): - X = dpt.ones((64,), dtype=arr_dtype, device=device) +@pytest.mark.parametrize("device, dt", get_all_dev_dtypes()) +def test_from_dlpack_with_dpt(device, dt): + X = dpt.ones((64,), dtype=dt, device=device) Y = dpnp.from_dlpack(X) assert_array_equal(X, Y) assert isinstance(Y, dpnp.dpnp_array.dpnp_array)