Skip to content

Commit 1d375d5

Browse files
committed
Add common utilities for _put_multi_index and _take_multi_index
Reduces code duplication
1 parent 916b548 commit 1d375d5

File tree

1 file changed

+84
-94
lines changed

1 file changed

+84
-94
lines changed

dpctl/tensor/_copy_utils.py

+84-94
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,79 @@ def _nonzero_impl(ary):
800800
return res
801801

802802

803+
def _validate_indices(inds, queue_list, usm_type_list):
804+
"""
805+
Utility for validating indices are usm_ndarray of integral dtype or Python
806+
integers. At least one must be an array.
807+
808+
For each array, the queue and usm type are appended to `queue_list` and
809+
`usm_type_list`, respectively.
810+
"""
811+
any_usmarray = False
812+
for ind in inds:
813+
if isinstance(ind, dpt.usm_ndarray):
814+
any_usmarray = True
815+
if ind.dtype.kind not in "ui":
816+
raise IndexError(
817+
"arrays used as indices must be of integer (or boolean) "
818+
"type"
819+
)
820+
queue_list.append(ind.sycl_queue)
821+
usm_type_list.append(ind.usm_type)
822+
elif not isinstance(ind, Integral):
823+
raise TypeError(
824+
"all elements of `ind` expected to be usm_ndarrays "
825+
f"or integers, found {type(ind)}"
826+
)
827+
if not any_usmarray:
828+
raise TypeError(
829+
"at least one element of `inds` expected to be a usm_ndarray"
830+
)
831+
return inds
832+
833+
834+
def _prepare_indices_arrays(inds, q, usm_type):
835+
"""
836+
Utility taking a mix of usm_ndarray and possibly Python int scalar indices,
837+
a queue (assumed to be common to arrays in inds), and a usm type.
838+
839+
Python scalar integers are promoted to arrays on the provided queue and
840+
with the provided usm type. All arrays are then promoted to a common
841+
integral type (if possible) before being broadcast to a common shape.
842+
"""
843+
# scalar integers -> arrays
844+
inds = tuple(
845+
map(
846+
lambda ind: (
847+
ind
848+
if isinstance(ind, dpt.usm_ndarray)
849+
else dpt.asarray(ind, usm_type=usm_type, sycl_queue=q)
850+
),
851+
inds,
852+
)
853+
)
854+
855+
# promote to a common integral type if possible
856+
ind_dt = dpt.result_type(*inds)
857+
if ind_dt.kind not in "ui":
858+
raise ValueError(
859+
"cannot safely promote indices to an integer data type"
860+
)
861+
inds = tuple(
862+
map(
863+
lambda ind: (
864+
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
865+
),
866+
inds,
867+
)
868+
)
869+
870+
# broadcast
871+
inds = dpt.broadcast_arrays(*inds)
872+
873+
return inds
874+
875+
803876
def _take_multi_index(ary, inds, p, mode=0):
804877
if not isinstance(ary, dpt.usm_ndarray):
805878
raise TypeError(
@@ -820,26 +893,8 @@ def _take_multi_index(ary, inds, p, mode=0):
820893
]
821894
if not isinstance(inds, (list, tuple)):
822895
inds = (inds,)
823-
any_usmarray = False
824-
for ind in inds:
825-
if isinstance(ind, dpt.usm_ndarray):
826-
any_usmarray = True
827-
if ind.dtype.kind not in "ui":
828-
raise IndexError(
829-
"arrays used as indices must be of integer (or boolean) "
830-
"type"
831-
)
832-
queues_.append(ind.sycl_queue)
833-
usm_types_.append(ind.usm_type)
834-
elif not isinstance(ind, Integral):
835-
raise TypeError(
836-
"all elements of `ind` expected to be usm_ndarrays "
837-
"or integers"
838-
)
839-
if not any_usmarray:
840-
raise TypeError(
841-
"at least one element of `ind` expected to be a usm_ndarray"
842-
)
896+
897+
_validate_indices(inds, queues_, usm_types_)
843898
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
844899
exec_q = dpctl.utils.get_execution_queue(queues_)
845900
if exec_q is None:
@@ -849,34 +904,10 @@ def _take_multi_index(ary, inds, p, mode=0):
849904
"Use `usm_ndarray.to_device` method to migrate data to "
850905
"be associated with the same queue."
851906
)
907+
852908
if len(inds) > 1:
853-
inds = tuple(
854-
map(
855-
lambda ind: (
856-
ind
857-
if isinstance(ind, dpt.usm_ndarray)
858-
else dpt.asarray(
859-
ind, usm_type=res_usm_type, sycl_queue=exec_q
860-
)
861-
),
862-
inds,
863-
)
864-
)
865-
ind_dt = dpt.result_type(*inds)
866-
# ind arrays have been checked to be of integer dtype
867-
if ind_dt.kind not in "ui":
868-
raise ValueError(
869-
"cannot safely promote indices to an integer data type"
870-
)
871-
inds = tuple(
872-
map(
873-
lambda ind: (
874-
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
875-
),
876-
inds,
877-
)
878-
)
879-
inds = dpt.broadcast_arrays(*inds)
909+
inds = _prepare_indices_arrays(inds, exec_q, res_usm_type)
910+
880911
ind0 = inds[0]
881912
ary_sh = ary.shape
882913
p_end = p + len(inds)
@@ -992,26 +1023,9 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
9921023
]
9931024
if not isinstance(inds, (list, tuple)):
9941025
inds = (inds,)
995-
any_usmarray = False
996-
for ind in inds:
997-
if isinstance(ind, dpt.usm_ndarray):
998-
any_usmarray = True
999-
if ind.dtype.kind not in "ui":
1000-
raise IndexError(
1001-
"arrays used as indices must be of integer (or boolean) "
1002-
"type"
1003-
)
1004-
queues_.append(ind.sycl_queue)
1005-
usm_types_.append(ind.usm_type)
1006-
elif not isinstance(ind, Integral):
1007-
raise TypeError(
1008-
"all elements of `ind` expected to be usm_ndarrays "
1009-
"or integers"
1010-
)
1011-
if not any_usmarray:
1012-
raise TypeError(
1013-
"at least one element of `ind` expected to be a usm_ndarray"
1014-
)
1026+
1027+
_validate_indices(inds, queues_, usm_types_)
1028+
10151029
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
10161030
exec_q = dpctl.utils.get_execution_queue(queues_)
10171031
if exec_q is not None:
@@ -1028,34 +1042,10 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
10281042
"Use `usm_ndarray.to_device` method to migrate data to "
10291043
"be associated with the same queue."
10301044
)
1045+
10311046
if len(inds) > 1:
1032-
inds = tuple(
1033-
map(
1034-
lambda ind: (
1035-
ind
1036-
if isinstance(ind, dpt.usm_ndarray)
1037-
else dpt.asarray(
1038-
ind, usm_type=vals_usm_type, sycl_queue=exec_q
1039-
)
1040-
),
1041-
inds,
1042-
)
1043-
)
1044-
ind_dt = dpt.result_type(*inds)
1045-
# ind arrays have been checked to be of integer dtype
1046-
if ind_dt.kind not in "ui":
1047-
raise ValueError(
1048-
"cannot safely promote indices to an integer data type"
1049-
)
1050-
inds = tuple(
1051-
map(
1052-
lambda ind: (
1053-
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
1054-
),
1055-
inds,
1056-
)
1057-
)
1058-
inds = dpt.broadcast_arrays(*inds)
1047+
inds = _prepare_indices_arrays(inds, exec_q, vals_usm_type)
1048+
10591049
ind0 = inds[0]
10601050
ary_sh = ary.shape
10611051
p_end = p + len(inds)

0 commit comments

Comments
 (0)