Skip to content

Commit 4f63b18

Browse files
authored
Merge pull request #2032 from IntelPython/update-indexing-for-array-api-2024
Update integer advanced indexing for array API 2024.12 spec
2 parents d1fb36e + 2fdac4c commit 4f63b18

File tree

4 files changed

+197
-74
lines changed

4 files changed

+197
-74
lines changed

dpctl/tensor/_copy_utils.py

+85-48
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616
import builtins
1717
import operator
18+
from numbers import Integral
1819

1920
import numpy as np
2021

@@ -799,6 +800,79 @@ def _nonzero_impl(ary):
799800
return res
800801

801802

803+
def _validate_indices(inds, queue_list, usm_type_list):
804+
"""
805+
Utility for validating indices are usm_ndarray of integral dtype or Python
806+
integers. At least one must be an array.
807+
808+
For each array, the queue and usm type are appended to `queue_list` and
809+
`usm_type_list`, respectively.
810+
"""
811+
any_usmarray = False
812+
for ind in inds:
813+
if isinstance(ind, dpt.usm_ndarray):
814+
any_usmarray = True
815+
if ind.dtype.kind not in "ui":
816+
raise IndexError(
817+
"arrays used as indices must be of integer (or boolean) "
818+
"type"
819+
)
820+
queue_list.append(ind.sycl_queue)
821+
usm_type_list.append(ind.usm_type)
822+
elif not isinstance(ind, Integral):
823+
raise TypeError(
824+
"all elements of `ind` expected to be usm_ndarrays "
825+
f"or integers, found {type(ind)}"
826+
)
827+
if not any_usmarray:
828+
raise TypeError(
829+
"at least one element of `inds` expected to be a usm_ndarray"
830+
)
831+
return inds
832+
833+
834+
def _prepare_indices_arrays(inds, q, usm_type):
835+
"""
836+
Utility taking a mix of usm_ndarray and possibly Python int scalar indices,
837+
a queue (assumed to be common to arrays in inds), and a usm type.
838+
839+
Python scalar integers are promoted to arrays on the provided queue and
840+
with the provided usm type. All arrays are then promoted to a common
841+
integral type (if possible) before being broadcast to a common shape.
842+
"""
843+
# scalar integers -> arrays
844+
inds = tuple(
845+
map(
846+
lambda ind: (
847+
ind
848+
if isinstance(ind, dpt.usm_ndarray)
849+
else dpt.asarray(ind, usm_type=usm_type, sycl_queue=q)
850+
),
851+
inds,
852+
)
853+
)
854+
855+
# promote to a common integral type if possible
856+
ind_dt = dpt.result_type(*inds)
857+
if ind_dt.kind not in "ui":
858+
raise ValueError(
859+
"cannot safely promote indices to an integer data type"
860+
)
861+
inds = tuple(
862+
map(
863+
lambda ind: (
864+
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
865+
),
866+
inds,
867+
)
868+
)
869+
870+
# broadcast
871+
inds = dpt.broadcast_arrays(*inds)
872+
873+
return inds
874+
875+
802876
def _take_multi_index(ary, inds, p, mode=0):
803877
if not isinstance(ary, dpt.usm_ndarray):
804878
raise TypeError(
@@ -819,15 +893,8 @@ def _take_multi_index(ary, inds, p, mode=0):
819893
]
820894
if not isinstance(inds, (list, tuple)):
821895
inds = (inds,)
822-
for ind in inds:
823-
if not isinstance(ind, dpt.usm_ndarray):
824-
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
825-
queues_.append(ind.sycl_queue)
826-
usm_types_.append(ind.usm_type)
827-
if ind.dtype.kind not in "ui":
828-
raise IndexError(
829-
"arrays used as indices must be of integer (or boolean) type"
830-
)
896+
897+
_validate_indices(inds, queues_, usm_types_)
831898
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
832899
exec_q = dpctl.utils.get_execution_queue(queues_)
833900
if exec_q is None:
@@ -837,22 +904,10 @@ def _take_multi_index(ary, inds, p, mode=0):
837904
"Use `usm_ndarray.to_device` method to migrate data to "
838905
"be associated with the same queue."
839906
)
907+
840908
if len(inds) > 1:
841-
ind_dt = dpt.result_type(*inds)
842-
# ind arrays have been checked to be of integer dtype
843-
if ind_dt.kind not in "ui":
844-
raise ValueError(
845-
"cannot safely promote indices to an integer data type"
846-
)
847-
inds = tuple(
848-
map(
849-
lambda ind: (
850-
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
851-
),
852-
inds,
853-
)
854-
)
855-
inds = dpt.broadcast_arrays(*inds)
909+
inds = _prepare_indices_arrays(inds, exec_q, res_usm_type)
910+
856911
ind0 = inds[0]
857912
ary_sh = ary.shape
858913
p_end = p + len(inds)
@@ -968,15 +1023,9 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
9681023
]
9691024
if not isinstance(inds, (list, tuple)):
9701025
inds = (inds,)
971-
for ind in inds:
972-
if not isinstance(ind, dpt.usm_ndarray):
973-
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
974-
queues_.append(ind.sycl_queue)
975-
usm_types_.append(ind.usm_type)
976-
if ind.dtype.kind not in "ui":
977-
raise IndexError(
978-
"arrays used as indices must be of integer (or boolean) type"
979-
)
1026+
1027+
_validate_indices(inds, queues_, usm_types_)
1028+
9801029
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
9811030
exec_q = dpctl.utils.get_execution_queue(queues_)
9821031
if exec_q is not None:
@@ -993,22 +1042,10 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
9931042
"Use `usm_ndarray.to_device` method to migrate data to "
9941043
"be associated with the same queue."
9951044
)
1045+
9961046
if len(inds) > 1:
997-
ind_dt = dpt.result_type(*inds)
998-
# ind arrays have been checked to be of integer dtype
999-
if ind_dt.kind not in "ui":
1000-
raise ValueError(
1001-
"cannot safely promote indices to an integer data type"
1002-
)
1003-
inds = tuple(
1004-
map(
1005-
lambda ind: (
1006-
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
1007-
),
1008-
inds,
1009-
)
1010-
)
1011-
inds = dpt.broadcast_arrays(*inds)
1047+
inds = _prepare_indices_arrays(inds, exec_q, vals_usm_type)
1048+
10121049
ind0 = inds[0]
10131050
ary_sh = ary.shape
10141051
p_end = p + len(inds)

dpctl/tensor/_slicing.pxi

+45-19
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
import numbers
18+
from operator import index
1819
from cpython.buffer cimport PyObject_CheckBuffer
1920

2021

@@ -64,7 +65,7 @@ cdef bint _is_integral(object x) except *:
6465
return False
6566
if callable(getattr(x, "__index__", None)):
6667
try:
67-
x.__index__()
68+
index(x)
6869
except (TypeError, ValueError):
6970
return False
7071
return True
@@ -136,7 +137,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
136137
else:
137138
return ((0,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos)
138139
elif _is_integral(ind):
139-
ind = ind.__index__()
140+
ind = index(ind)
140141
new_shape = shape[1:]
141142
new_strides = strides[1:]
142143
is_empty = any(sh_i == 0 for sh_i in new_shape)
@@ -179,10 +180,12 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
179180
if array_streak_started:
180181
array_streak_interrupted = True
181182
elif _is_integral(i):
182-
explicit_index += 1
183183
axes_referenced += 1
184-
if array_streak_started:
185-
array_streak_interrupted = True
184+
if array_streak_started and not array_streak_interrupted:
185+
# integers converted to arrays in this case
186+
array_count += 1
187+
else:
188+
explicit_index += 1
186189
elif isinstance(i, usm_ndarray):
187190
if not seen_arrays_yet:
188191
seen_arrays_yet = True
@@ -229,6 +232,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
229232
advanced_start_pos_set = False
230233
new_offset = offset
231234
is_empty = False
235+
array_streak = False
232236
for i in range(len(ind)):
233237
ind_i = ind[i]
234238
if (ind_i is Ellipsis):
@@ -239,9 +243,13 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
239243
is_empty = True
240244
new_offset = offset
241245
k = k_new
246+
if array_streak:
247+
array_streak = False
242248
elif ind_i is None:
243249
new_shape.append(1)
244250
new_strides.append(0)
251+
if array_streak:
252+
array_streak = False
245253
elif isinstance(ind_i, slice):
246254
k_new = k + 1
247255
sl_start, sl_stop, sl_step = ind_i.indices(shape[k])
@@ -255,26 +263,46 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
255263
is_empty = True
256264
new_offset = offset
257265
k = k_new
266+
if array_streak:
267+
array_streak = False
258268
elif _is_boolean(ind_i):
259269
new_shape.append(1 if ind_i else 0)
260270
new_strides.append(0)
271+
if array_streak:
272+
array_streak = False
261273
elif _is_integral(ind_i):
262-
ind_i = ind_i.__index__()
263-
if 0 <= ind_i < shape[k]:
274+
if array_streak:
275+
if not isinstance(ind_i, usm_ndarray):
276+
ind_i = index(ind_i)
277+
# integer will be converted to an array, still raise if OOB
278+
if not (0 <= ind_i < shape[k] or -shape[k] <= ind_i < 0):
279+
raise IndexError(
280+
("Index {0} is out of range for "
281+
"axes {1} with size {2}").format(ind_i, k, shape[k]))
282+
new_advanced_ind.append(ind_i)
264283
k_new = k + 1
265-
if not is_empty:
266-
new_offset = new_offset + ind_i * strides[k]
267-
k = k_new
268-
elif -shape[k] <= ind_i < 0:
269-
k_new = k + 1
270-
if not is_empty:
271-
new_offset = new_offset + (shape[k] + ind_i) * strides[k]
284+
new_shape.extend(shape[k:k_new])
285+
new_strides.extend(strides[k:k_new])
272286
k = k_new
273287
else:
274-
raise IndexError(
275-
("Index {0} is out of range for "
276-
"axes {1} with size {2}").format(ind_i, k, shape[k]))
288+
ind_i = index(ind_i)
289+
if 0 <= ind_i < shape[k]:
290+
k_new = k + 1
291+
if not is_empty:
292+
new_offset = new_offset + ind_i * strides[k]
293+
k = k_new
294+
elif -shape[k] <= ind_i < 0:
295+
k_new = k + 1
296+
if not is_empty:
297+
new_offset = new_offset + (shape[k] + ind_i) * strides[k]
298+
k = k_new
299+
else:
300+
raise IndexError(
301+
("Index {0} is out of range for "
302+
"axes {1} with size {2}").format(ind_i, k, shape[k]))
277303
elif isinstance(ind_i, usm_ndarray):
304+
if not array_streak:
305+
array_streak = True
278306
if not advanced_start_pos_set:
279307
new_advanced_start_pos = len(new_shape)
280308
advanced_start_pos_set = True
@@ -287,8 +315,6 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
287315
new_shape.extend(shape[k:k_new])
288316
new_strides.extend(strides[k:k_new])
289317
k = k_new
290-
else:
291-
raise IndexError
292318
new_shape.extend(shape[k:])
293319
new_strides.extend(strides[k:])
294320
new_shape_len += len(shape) - k

dpctl/tensor/_usmarray.pyx

+6-5
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@ cdef void _validate_and_use_stream(object stream, c_dpctl.SyclQueue self_queue)
161161
ev = self_queue.submit_barrier()
162162
stream.submit_barrier(dependent_events=[ev])
163163

164-
165164
cdef class usm_ndarray:
166165
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
167166
offset=0, order="C", buffer_ctor_kwargs=dict(), \
@@ -962,6 +961,8 @@ cdef class usm_ndarray:
962961
return res
963962

964963
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
964+
965+
# if len(adv_ind == 1), the (only) element is always an array
965966
if len(adv_ind) == 1 and adv_ind[0].dtype == dpt_bool:
966967
key_ = adv_ind[0]
967968
adv_ind_end_p = key_.ndim + adv_ind_start_p
@@ -979,10 +980,10 @@ cdef class usm_ndarray:
979980
res.flags_ = _copy_writable(res.flags_, self.flags_)
980981
return res
981982

982-
if any(ind.dtype == dpt_bool for ind in adv_ind):
983+
if any((isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool) for ind in adv_ind):
983984
adv_ind_int = list()
984985
for ind in adv_ind:
985-
if ind.dtype == dpt_bool:
986+
if isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool:
986987
adv_ind_int.extend(_nonzero_impl(ind))
987988
else:
988989
adv_ind_int.append(ind)
@@ -1433,10 +1434,10 @@ cdef class usm_ndarray:
14331434
_place_impl(Xv, adv_ind[0], rhs, axis=adv_ind_start_p)
14341435
return
14351436

1436-
if any(ind.dtype == dpt_bool for ind in adv_ind):
1437+
if any((isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool) for ind in adv_ind):
14371438
adv_ind_int = list()
14381439
for ind in adv_ind:
1439-
if ind.dtype == dpt_bool:
1440+
if isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool:
14401441
adv_ind_int.extend(_nonzero_impl(ind))
14411442
else:
14421443
adv_ind_int.append(ind)

0 commit comments

Comments
 (0)