@@ -800,6 +800,79 @@ def _nonzero_impl(ary):
800
800
return res
801
801
802
802
803
+ def _validate_indices (inds , queue_list , usm_type_list ):
804
+ """
805
+ Utility for validating indices are usm_ndarray of integral dtype or Python
806
+ integers. At least one must be an array.
807
+
808
+ For each array, the queue and usm type are appended to `queue_list` and
809
+ `usm_type_list`, respectively.
810
+ """
811
+ any_usmarray = False
812
+ for ind in inds :
813
+ if isinstance (ind , dpt .usm_ndarray ):
814
+ any_usmarray = True
815
+ if ind .dtype .kind not in "ui" :
816
+ raise IndexError (
817
+ "arrays used as indices must be of integer (or boolean) "
818
+ "type"
819
+ )
820
+ queue_list .append (ind .sycl_queue )
821
+ usm_type_list .append (ind .usm_type )
822
+ elif not isinstance (ind , Integral ):
823
+ raise TypeError (
824
+ "all elements of `ind` expected to be usm_ndarrays "
825
+ f"or integers, found { type (ind )} "
826
+ )
827
+ if not any_usmarray :
828
+ raise TypeError (
829
+ "at least one element of `inds` expected to be a usm_ndarray"
830
+ )
831
+ return inds
832
+
833
+
834
+ def _prepare_indices_arrays (inds , q , usm_type ):
835
+ """
836
+ Utility taking a mix of usm_ndarray and possibly Python int scalar indices,
837
+ a queue (assumed to be common to arrays in inds), and a usm type.
838
+
839
+ Python scalar integers are promoted to arrays on the provided queue and
840
+ with the provided usm type. All arrays are then promoted to a common
841
+ integral type (if possible) before being broadcast to a common shape.
842
+ """
843
+ # scalar integers -> arrays
844
+ inds = tuple (
845
+ map (
846
+ lambda ind : (
847
+ ind
848
+ if isinstance (ind , dpt .usm_ndarray )
849
+ else dpt .asarray (ind , usm_type = usm_type , sycl_queue = q )
850
+ ),
851
+ inds ,
852
+ )
853
+ )
854
+
855
+ # promote to a common integral type if possible
856
+ ind_dt = dpt .result_type (* inds )
857
+ if ind_dt .kind not in "ui" :
858
+ raise ValueError (
859
+ "cannot safely promote indices to an integer data type"
860
+ )
861
+ inds = tuple (
862
+ map (
863
+ lambda ind : (
864
+ ind if ind .dtype == ind_dt else dpt .astype (ind , ind_dt )
865
+ ),
866
+ inds ,
867
+ )
868
+ )
869
+
870
+ # broadcast
871
+ inds = dpt .broadcast_arrays (* inds )
872
+
873
+ return inds
874
+
875
+
803
876
def _take_multi_index (ary , inds , p , mode = 0 ):
804
877
if not isinstance (ary , dpt .usm_ndarray ):
805
878
raise TypeError (
@@ -820,26 +893,8 @@ def _take_multi_index(ary, inds, p, mode=0):
820
893
]
821
894
if not isinstance (inds , (list , tuple )):
822
895
inds = (inds ,)
823
- any_usmarray = False
824
- for ind in inds :
825
- if isinstance (ind , dpt .usm_ndarray ):
826
- any_usmarray = True
827
- if ind .dtype .kind not in "ui" :
828
- raise IndexError (
829
- "arrays used as indices must be of integer (or boolean) "
830
- "type"
831
- )
832
- queues_ .append (ind .sycl_queue )
833
- usm_types_ .append (ind .usm_type )
834
- elif not isinstance (ind , Integral ):
835
- raise TypeError (
836
- "all elements of `ind` expected to be usm_ndarrays "
837
- "or integers"
838
- )
839
- if not any_usmarray :
840
- raise TypeError (
841
- "at least one element of `ind` expected to be a usm_ndarray"
842
- )
896
+
897
+ _validate_indices (inds , queues_ , usm_types_ )
843
898
res_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
844
899
exec_q = dpctl .utils .get_execution_queue (queues_ )
845
900
if exec_q is None :
@@ -849,34 +904,10 @@ def _take_multi_index(ary, inds, p, mode=0):
849
904
"Use `usm_ndarray.to_device` method to migrate data to "
850
905
"be associated with the same queue."
851
906
)
907
+
852
908
if len (inds ) > 1 :
853
- inds = tuple (
854
- map (
855
- lambda ind : (
856
- ind
857
- if isinstance (ind , dpt .usm_ndarray )
858
- else dpt .asarray (
859
- ind , usm_type = res_usm_type , sycl_queue = exec_q
860
- )
861
- ),
862
- inds ,
863
- )
864
- )
865
- ind_dt = dpt .result_type (* inds )
866
- # ind arrays have been checked to be of integer dtype
867
- if ind_dt .kind not in "ui" :
868
- raise ValueError (
869
- "cannot safely promote indices to an integer data type"
870
- )
871
- inds = tuple (
872
- map (
873
- lambda ind : (
874
- ind if ind .dtype == ind_dt else dpt .astype (ind , ind_dt )
875
- ),
876
- inds ,
877
- )
878
- )
879
- inds = dpt .broadcast_arrays (* inds )
909
+ inds = _prepare_indices_arrays (inds , exec_q , res_usm_type )
910
+
880
911
ind0 = inds [0 ]
881
912
ary_sh = ary .shape
882
913
p_end = p + len (inds )
@@ -992,26 +1023,9 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
992
1023
]
993
1024
if not isinstance (inds , (list , tuple )):
994
1025
inds = (inds ,)
995
- any_usmarray = False
996
- for ind in inds :
997
- if isinstance (ind , dpt .usm_ndarray ):
998
- any_usmarray = True
999
- if ind .dtype .kind not in "ui" :
1000
- raise IndexError (
1001
- "arrays used as indices must be of integer (or boolean) "
1002
- "type"
1003
- )
1004
- queues_ .append (ind .sycl_queue )
1005
- usm_types_ .append (ind .usm_type )
1006
- elif not isinstance (ind , Integral ):
1007
- raise TypeError (
1008
- "all elements of `ind` expected to be usm_ndarrays "
1009
- "or integers"
1010
- )
1011
- if not any_usmarray :
1012
- raise TypeError (
1013
- "at least one element of `ind` expected to be a usm_ndarray"
1014
- )
1026
+
1027
+ _validate_indices (inds , queues_ , usm_types_ )
1028
+
1015
1029
vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
1016
1030
exec_q = dpctl .utils .get_execution_queue (queues_ )
1017
1031
if exec_q is not None :
@@ -1028,34 +1042,10 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
1028
1042
"Use `usm_ndarray.to_device` method to migrate data to "
1029
1043
"be associated with the same queue."
1030
1044
)
1045
+
1031
1046
if len (inds ) > 1 :
1032
- inds = tuple (
1033
- map (
1034
- lambda ind : (
1035
- ind
1036
- if isinstance (ind , dpt .usm_ndarray )
1037
- else dpt .asarray (
1038
- ind , usm_type = vals_usm_type , sycl_queue = exec_q
1039
- )
1040
- ),
1041
- inds ,
1042
- )
1043
- )
1044
- ind_dt = dpt .result_type (* inds )
1045
- # ind arrays have been checked to be of integer dtype
1046
- if ind_dt .kind not in "ui" :
1047
- raise ValueError (
1048
- "cannot safely promote indices to an integer data type"
1049
- )
1050
- inds = tuple (
1051
- map (
1052
- lambda ind : (
1053
- ind if ind .dtype == ind_dt else dpt .astype (ind , ind_dt )
1054
- ),
1055
- inds ,
1056
- )
1057
- )
1058
- inds = dpt .broadcast_arrays (* inds )
1047
+ inds = _prepare_indices_arrays (inds , exec_q , vals_usm_type )
1048
+
1059
1049
ind0 = inds [0 ]
1060
1050
ary_sh = ary .shape
1061
1051
p_end = p + len (inds )
0 commit comments