Skip to content

[MAINT] Run array API conformity with 2024.12 spec #2021

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
@@ -771,6 +771,7 @@ jobs:
cd /home/runner/work/array-api-tests
${CONDA_PREFIX}/bin/python -c "import dpctl; dpctl.lsplatform()"
export ARRAY_API_TESTS_MODULE=dpctl.tensor
export ARRAY_API_TESTS_VERSION=2024.12
${CONDA_PREFIX}/bin/python -m pytest --json-report --json-report-file=$FILE --disable-deadline --skips-file ${GITHUB_WORKSPACE}/.github/workflows/array-api-skips.txt array_api_tests/ || true
- name: Set Github environment variables
shell: bash -l {0}
133 changes: 85 additions & 48 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
# limitations under the License.
import builtins
import operator
from numbers import Integral

import numpy as np

@@ -799,6 +800,79 @@ def _nonzero_impl(ary):
return res


def _validate_indices(inds, queue_list, usm_type_list):
    """
    Check that every entry of `inds` is either a usm_ndarray of signed or
    unsigned integer data type or a Python integer, and that at least one
    entry is a usm_ndarray.

    As a side effect, the `sycl_queue` and `usm_type` of each array entry
    are appended (in iteration order) to `queue_list` and `usm_type_list`.

    Returns `inds` unchanged.

    Raises:
        IndexError: an array entry has a data type whose kind is neither
            "u" nor "i".
        TypeError: an entry is neither a usm_ndarray nor an instance of
            `numbers.Integral`, or no entry is a usm_ndarray.
    """
    n_arrays = 0
    for ind in inds:
        if not isinstance(ind, dpt.usm_ndarray):
            # non-array entries are only acceptable when they are integers
            if isinstance(ind, Integral):
                continue
            raise TypeError(
                "all elements of `ind` expected to be usm_ndarrays "
                f"or integers, found {type(ind)}"
            )
        if ind.dtype.kind not in "ui":
            raise IndexError(
                "arrays used as indices must be of integer (or boolean) "
                "type"
            )
        n_arrays += 1
        queue_list.append(ind.sycl_queue)
        usm_type_list.append(ind.usm_type)
    if n_arrays == 0:
        raise TypeError(
            "at least one element of `inds` expected to be a usm_ndarray"
        )
    return inds


def _prepare_indices_arrays(inds, q, usm_type):
    """
    Turn a mix of usm_ndarray and Python integer indices into arrays that
    share one integral data type and one shape.

    Entries of `inds` that are not usm_ndarray are converted with
    `dpt.asarray`, using queue `q` and USM type `usm_type`; array entries
    are used as they are (`q` is assumed to be the queue common to them).
    The common data type is computed by `dpt.result_type`; arrays whose
    data type differs from it are cast with `dpt.astype`. Finally all
    arrays are broadcast against each other.

    Returns the result of `dpt.broadcast_arrays` on the prepared arrays.

    Raises:
        ValueError: the common data type is not of kind "u" or "i".
    """
    # Python scalars become arrays; existing arrays pass through untouched
    as_arrays = tuple(
        ind
        if isinstance(ind, dpt.usm_ndarray)
        else dpt.asarray(ind, usm_type=usm_type, sycl_queue=q)
        for ind in inds
    )

    # a single integral type must be able to represent every index array
    common_dt = dpt.result_type(*as_arrays)
    if common_dt.kind not in "ui":
        raise ValueError(
            "cannot safely promote indices to an integer data type"
        )

    # cast only the arrays that are not already of the common type
    same_dtype = tuple(
        arr if arr.dtype == common_dt else dpt.astype(arr, common_dt)
        for arr in as_arrays
    )

    return dpt.broadcast_arrays(*same_dtype)


def _take_multi_index(ary, inds, p, mode=0):
if not isinstance(ary, dpt.usm_ndarray):
raise TypeError(
@@ -819,15 +893,8 @@ def _take_multi_index(ary, inds, p, mode=0):
]
if not isinstance(inds, (list, tuple)):
inds = (inds,)
for ind in inds:
if not isinstance(ind, dpt.usm_ndarray):
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
queues_.append(ind.sycl_queue)
usm_types_.append(ind.usm_type)
if ind.dtype.kind not in "ui":
raise IndexError(
"arrays used as indices must be of integer (or boolean) type"
)

_validate_indices(inds, queues_, usm_types_)
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
exec_q = dpctl.utils.get_execution_queue(queues_)
if exec_q is None:
@@ -837,22 +904,10 @@ def _take_multi_index(ary, inds, p, mode=0):
"Use `usm_ndarray.to_device` method to migrate data to "
"be associated with the same queue."
)

if len(inds) > 1:
ind_dt = dpt.result_type(*inds)
# ind arrays have been checked to be of integer dtype
if ind_dt.kind not in "ui":
raise ValueError(
"cannot safely promote indices to an integer data type"
)
inds = tuple(
map(
lambda ind: (
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
),
inds,
)
)
inds = dpt.broadcast_arrays(*inds)
inds = _prepare_indices_arrays(inds, exec_q, res_usm_type)

ind0 = inds[0]
ary_sh = ary.shape
p_end = p + len(inds)
@@ -968,15 +1023,9 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
]
if not isinstance(inds, (list, tuple)):
inds = (inds,)
for ind in inds:
if not isinstance(ind, dpt.usm_ndarray):
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
queues_.append(ind.sycl_queue)
usm_types_.append(ind.usm_type)
if ind.dtype.kind not in "ui":
raise IndexError(
"arrays used as indices must be of integer (or boolean) type"
)

_validate_indices(inds, queues_, usm_types_)

vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
exec_q = dpctl.utils.get_execution_queue(queues_)
if exec_q is not None:
@@ -993,22 +1042,10 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
"Use `usm_ndarray.to_device` method to migrate data to "
"be associated with the same queue."
)

if len(inds) > 1:
ind_dt = dpt.result_type(*inds)
# ind arrays have been checked to be of integer dtype
if ind_dt.kind not in "ui":
raise ValueError(
"cannot safely promote indices to an integer data type"
)
inds = tuple(
map(
lambda ind: (
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
),
inds,
)
)
inds = dpt.broadcast_arrays(*inds)
inds = _prepare_indices_arrays(inds, exec_q, vals_usm_type)

ind0 = inds[0]
ary_sh = ary.shape
p_end = p + len(inds)
64 changes: 45 additions & 19 deletions dpctl/tensor/_slicing.pxi
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
# limitations under the License.

import numbers
from operator import index
from cpython.buffer cimport PyObject_CheckBuffer


@@ -64,7 +65,7 @@ cdef bint _is_integral(object x) except *:
return False
if callable(getattr(x, "__index__", None)):
try:
x.__index__()
index(x)
except (TypeError, ValueError):
return False
return True
@@ -136,7 +137,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
else:
return ((0,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos)
elif _is_integral(ind):
ind = ind.__index__()
ind = index(ind)
new_shape = shape[1:]
new_strides = strides[1:]
is_empty = any(sh_i == 0 for sh_i in new_shape)
@@ -179,10 +180,12 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
if array_streak_started:
array_streak_interrupted = True
elif _is_integral(i):
explicit_index += 1
axes_referenced += 1
if array_streak_started:
array_streak_interrupted = True
if array_streak_started and not array_streak_interrupted:
# integers converted to arrays in this case
array_count += 1
else:
explicit_index += 1
elif isinstance(i, usm_ndarray):
if not seen_arrays_yet:
seen_arrays_yet = True
@@ -229,6 +232,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
advanced_start_pos_set = False
new_offset = offset
is_empty = False
array_streak = False
for i in range(len(ind)):
ind_i = ind[i]
if (ind_i is Ellipsis):
@@ -239,9 +243,13 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
is_empty = True
new_offset = offset
k = k_new
if array_streak:
array_streak = False
elif ind_i is None:
new_shape.append(1)
new_strides.append(0)
if array_streak:
array_streak = False
elif isinstance(ind_i, slice):
k_new = k + 1
sl_start, sl_stop, sl_step = ind_i.indices(shape[k])
@@ -255,26 +263,46 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
is_empty = True
new_offset = offset
k = k_new
if array_streak:
array_streak = False
elif _is_boolean(ind_i):
new_shape.append(1 if ind_i else 0)
new_strides.append(0)
if array_streak:
array_streak = False
elif _is_integral(ind_i):
ind_i = ind_i.__index__()
if 0 <= ind_i < shape[k]:
if array_streak:
if not isinstance(ind_i, usm_ndarray):
ind_i = index(ind_i)
# integer will be converted to an array, still raise if OOB
if not (0 <= ind_i < shape[k] or -shape[k] <= ind_i < 0):
raise IndexError(
("Index {0} is out of range for "
"axes {1} with size {2}").format(ind_i, k, shape[k]))
new_advanced_ind.append(ind_i)
k_new = k + 1
if not is_empty:
new_offset = new_offset + ind_i * strides[k]
k = k_new
elif -shape[k] <= ind_i < 0:
k_new = k + 1
if not is_empty:
new_offset = new_offset + (shape[k] + ind_i) * strides[k]
new_shape.extend(shape[k:k_new])
new_strides.extend(strides[k:k_new])
k = k_new
else:
raise IndexError(
("Index {0} is out of range for "
"axes {1} with size {2}").format(ind_i, k, shape[k]))
ind_i = index(ind_i)
if 0 <= ind_i < shape[k]:
k_new = k + 1
if not is_empty:
new_offset = new_offset + ind_i * strides[k]
k = k_new
elif -shape[k] <= ind_i < 0:
k_new = k + 1
if not is_empty:
new_offset = new_offset + (shape[k] + ind_i) * strides[k]
k = k_new
else:
raise IndexError(
("Index {0} is out of range for "
"axes {1} with size {2}").format(ind_i, k, shape[k]))
elif isinstance(ind_i, usm_ndarray):
if not array_streak:
array_streak = True
if not advanced_start_pos_set:
new_advanced_start_pos = len(new_shape)
advanced_start_pos_set = True
@@ -287,8 +315,6 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
new_shape.extend(shape[k:k_new])
new_strides.extend(strides[k:k_new])
k = k_new
else:
raise IndexError
new_shape.extend(shape[k:])
new_strides.extend(strides[k:])
new_shape_len += len(shape) - k
11 changes: 6 additions & 5 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
@@ -161,7 +161,6 @@ cdef void _validate_and_use_stream(object stream, c_dpctl.SyclQueue self_queue)
ev = self_queue.submit_barrier()
stream.submit_barrier(dependent_events=[ev])


cdef class usm_ndarray:
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
offset=0, order="C", buffer_ctor_kwargs=dict(), \
@@ -962,6 +961,8 @@ cdef class usm_ndarray:
return res

from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index

# if len(adv_ind) == 1, the (only) element is always an array
if len(adv_ind) == 1 and adv_ind[0].dtype == dpt_bool:
key_ = adv_ind[0]
adv_ind_end_p = key_.ndim + adv_ind_start_p
@@ -979,10 +980,10 @@ cdef class usm_ndarray:
res.flags_ = _copy_writable(res.flags_, self.flags_)
return res

if any(ind.dtype == dpt_bool for ind in adv_ind):
if any((isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool) for ind in adv_ind):
adv_ind_int = list()
for ind in adv_ind:
if ind.dtype == dpt_bool:
if isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool:
adv_ind_int.extend(_nonzero_impl(ind))
else:
adv_ind_int.append(ind)
@@ -1433,10 +1434,10 @@ cdef class usm_ndarray:
_place_impl(Xv, adv_ind[0], rhs, axis=adv_ind_start_p)
return

if any(ind.dtype == dpt_bool for ind in adv_ind):
if any((isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool) for ind in adv_ind):
adv_ind_int = list()
for ind in adv_ind:
if ind.dtype == dpt_bool:
if isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool:
adv_ind_int.extend(_nonzero_impl(ind))
else:
adv_ind_int.append(ind)
63 changes: 61 additions & 2 deletions dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
import dpctl
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
from dpctl.tensor._copy_utils import _take_multi_index
from dpctl.utils import ExecutionPlacementError

from .helper import get_queue_or_skip, skip_if_dtype_not_supported
@@ -252,8 +253,14 @@ def test_advanced_slice5():
q = get_queue_or_skip()
ii = dpt.asarray([1, 2], sycl_queue=q)
x = _make_3d("i4", q)
with pytest.raises(IndexError):
x[ii, 0, ii]
y = x[ii, 0, ii]
assert isinstance(y, dpt.usm_ndarray)
# 0 broadcast to [0, 0] per array API
assert y.shape == ii.shape
assert _all_equal(
(x[ii[i], 0, ii[i]] for i in range(ii.shape[0])),
(y[i] for i in range(ii.shape[0])),
)


def test_advanced_slice6():
@@ -395,6 +402,44 @@ def test_advanced_slice13():
assert (dpt.asnumpy(y) == dpt.asnumpy(expected)).all()


def test_advanced_slice14():
    """Integers mixed with index arrays, followed by a trailing slice."""
    q = get_queue_or_skip()
    ii = dpt.asarray([1, 2], sycl_queue=q)
    flat = dpt.arange(3**5, dtype="i4", sycl_queue=q)
    x = dpt.reshape(flat, (3,) * 5)
    y = x[ii, 0, ii, 1, :]
    assert isinstance(y, dpt.usm_ndarray)
    # integers broadcast to ii.shape per array API
    assert y.shape == ii.shape + x.shape[-1:]
    n_ind = ii.shape[0]
    n_last = x.shape[-1]
    # element-by-element comparison against scalar indexing
    assert _all_equal(
        (
            x[ii[i], 0, ii[i], 1, k]
            for i in range(n_ind)
            for k in range(n_last)
        ),
        (y[i, k] for i in range(n_ind) for k in range(n_last)),
    )


def test_advanced_slice15():
    """A slice placed between integral index arrays raises IndexError."""
    q = get_queue_or_skip()
    ii = dpt.asarray([1, 2], sycl_queue=q)
    flat = dpt.arange(3**5, dtype="i4", sycl_queue=q)
    x = dpt.reshape(flat, (3,) * 5)
    # : cannot appear between two integral arrays
    key = (ii, 0, ii, slice(None), ii)
    with pytest.raises(IndexError):
        x[key]


def test_advanced_slice16():
    """Indexing with 0-d integer and 0-d boolean arrays returns an array."""
    q = get_queue_or_skip()
    ii = dpt.asarray(1, sycl_queue=q)
    i0 = dpt.asarray(False, sycl_queue=q)
    i1 = dpt.asarray(True, sycl_queue=q)
    flat = dpt.arange(3**5, dtype="i4", sycl_queue=q)
    x = dpt.reshape(flat, (3,) * 5)
    key = (ii, i0, ii, i1, slice(None))
    y = x[key]
    # TODO: add a shape check here when discrepancy with NumPy is investigated
    assert isinstance(y, dpt.usm_ndarray)


def test_boolean_indexing_validation():
get_queue_or_skip()
x = dpt.zeros(10, dtype="i4")
@@ -1956,3 +2001,17 @@ def test_take_out_errors():
out_bad_q = dpt.empty(ind.shape, dtype=x.dtype, sycl_queue=q2)
with pytest.raises(dpctl.utils.ExecutionPlacementError):
dpt.take(x, ind, out=out_bad_q)


def test_getitem_impl_fn_invalid_inp():
    """`_take_multi_index` raises TypeError for malformed index tuples."""
    get_queue_or_skip()

    x = dpt.ones((10, 10), dtype="i4")

    # a float scalar next to an array is not an acceptable index
    array_and_float = (dpt.ones((), dtype="i4"), 2.0)
    with pytest.raises(TypeError):
        _take_multi_index(x, array_and_float, 0, 0)

    # a tuple of only Python integers contains no usm_ndarray
    ints_only = (2, 3)
    with pytest.raises(TypeError):
        _take_multi_index(x, ints_only, 0, 0)