15
15
# limitations under the License.
16
16
import builtins
17
17
import operator
18
+ from numbers import Integral
18
19
19
20
import numpy as np
20
21
@@ -799,6 +800,79 @@ def _nonzero_impl(ary):
799
800
return res
800
801
801
802
803
+ def _validate_indices (inds , queue_list , usm_type_list ):
804
+ """
805
+ Utility for validating indices are usm_ndarray of integral dtype or Python
806
+ integers. At least one must be an array.
807
+
808
+ For each array, the queue and usm type are appended to `queue_list` and
809
+ `usm_type_list`, respectively.
810
+ """
811
+ any_usmarray = False
812
+ for ind in inds :
813
+ if isinstance (ind , dpt .usm_ndarray ):
814
+ any_usmarray = True
815
+ if ind .dtype .kind not in "ui" :
816
+ raise IndexError (
817
+ "arrays used as indices must be of integer (or boolean) "
818
+ "type"
819
+ )
820
+ queue_list .append (ind .sycl_queue )
821
+ usm_type_list .append (ind .usm_type )
822
+ elif not isinstance (ind , Integral ):
823
+ raise TypeError (
824
+ "all elements of `ind` expected to be usm_ndarrays "
825
+ f"or integers, found { type (ind )} "
826
+ )
827
+ if not any_usmarray :
828
+ raise TypeError (
829
+ "at least one element of `inds` expected to be a usm_ndarray"
830
+ )
831
+ return inds
832
+
833
+
834
+ def _prepare_indices_arrays (inds , q , usm_type ):
835
+ """
836
+ Utility taking a mix of usm_ndarray and possibly Python int scalar indices,
837
+ a queue (assumed to be common to arrays in inds), and a usm type.
838
+
839
+ Python scalar integers are promoted to arrays on the provided queue and
840
+ with the provided usm type. All arrays are then promoted to a common
841
+ integral type (if possible) before being broadcast to a common shape.
842
+ """
843
+ # scalar integers -> arrays
844
+ inds = tuple (
845
+ map (
846
+ lambda ind : (
847
+ ind
848
+ if isinstance (ind , dpt .usm_ndarray )
849
+ else dpt .asarray (ind , usm_type = usm_type , sycl_queue = q )
850
+ ),
851
+ inds ,
852
+ )
853
+ )
854
+
855
+ # promote to a common integral type if possible
856
+ ind_dt = dpt .result_type (* inds )
857
+ if ind_dt .kind not in "ui" :
858
+ raise ValueError (
859
+ "cannot safely promote indices to an integer data type"
860
+ )
861
+ inds = tuple (
862
+ map (
863
+ lambda ind : (
864
+ ind if ind .dtype == ind_dt else dpt .astype (ind , ind_dt )
865
+ ),
866
+ inds ,
867
+ )
868
+ )
869
+
870
+ # broadcast
871
+ inds = dpt .broadcast_arrays (* inds )
872
+
873
+ return inds
874
+
875
+
802
876
def _take_multi_index (ary , inds , p , mode = 0 ):
803
877
if not isinstance (ary , dpt .usm_ndarray ):
804
878
raise TypeError (
@@ -819,15 +893,8 @@ def _take_multi_index(ary, inds, p, mode=0):
819
893
]
820
894
if not isinstance (inds , (list , tuple )):
821
895
inds = (inds ,)
822
- for ind in inds :
823
- if not isinstance (ind , dpt .usm_ndarray ):
824
- raise TypeError ("all elements of `ind` expected to be usm_ndarrays" )
825
- queues_ .append (ind .sycl_queue )
826
- usm_types_ .append (ind .usm_type )
827
- if ind .dtype .kind not in "ui" :
828
- raise IndexError (
829
- "arrays used as indices must be of integer (or boolean) type"
830
- )
896
+
897
+ _validate_indices (inds , queues_ , usm_types_ )
831
898
res_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
832
899
exec_q = dpctl .utils .get_execution_queue (queues_ )
833
900
if exec_q is None :
@@ -837,22 +904,10 @@ def _take_multi_index(ary, inds, p, mode=0):
837
904
"Use `usm_ndarray.to_device` method to migrate data to "
838
905
"be associated with the same queue."
839
906
)
907
+
840
908
if len (inds ) > 1 :
841
- ind_dt = dpt .result_type (* inds )
842
- # ind arrays have been checked to be of integer dtype
843
- if ind_dt .kind not in "ui" :
844
- raise ValueError (
845
- "cannot safely promote indices to an integer data type"
846
- )
847
- inds = tuple (
848
- map (
849
- lambda ind : (
850
- ind if ind .dtype == ind_dt else dpt .astype (ind , ind_dt )
851
- ),
852
- inds ,
853
- )
854
- )
855
- inds = dpt .broadcast_arrays (* inds )
909
+ inds = _prepare_indices_arrays (inds , exec_q , res_usm_type )
910
+
856
911
ind0 = inds [0 ]
857
912
ary_sh = ary .shape
858
913
p_end = p + len (inds )
@@ -968,15 +1023,9 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
968
1023
]
969
1024
if not isinstance (inds , (list , tuple )):
970
1025
inds = (inds ,)
971
- for ind in inds :
972
- if not isinstance (ind , dpt .usm_ndarray ):
973
- raise TypeError ("all elements of `ind` expected to be usm_ndarrays" )
974
- queues_ .append (ind .sycl_queue )
975
- usm_types_ .append (ind .usm_type )
976
- if ind .dtype .kind not in "ui" :
977
- raise IndexError (
978
- "arrays used as indices must be of integer (or boolean) type"
979
- )
1026
+
1027
+ _validate_indices (inds , queues_ , usm_types_ )
1028
+
980
1029
vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
981
1030
exec_q = dpctl .utils .get_execution_queue (queues_ )
982
1031
if exec_q is not None :
@@ -993,22 +1042,10 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
993
1042
"Use `usm_ndarray.to_device` method to migrate data to "
994
1043
"be associated with the same queue."
995
1044
)
1045
+
996
1046
if len (inds ) > 1 :
997
- ind_dt = dpt .result_type (* inds )
998
- # ind arrays have been checked to be of integer dtype
999
- if ind_dt .kind not in "ui" :
1000
- raise ValueError (
1001
- "cannot safely promote indices to an integer data type"
1002
- )
1003
- inds = tuple (
1004
- map (
1005
- lambda ind : (
1006
- ind if ind .dtype == ind_dt else dpt .astype (ind , ind_dt )
1007
- ),
1008
- inds ,
1009
- )
1010
- )
1011
- inds = dpt .broadcast_arrays (* inds )
1047
+ inds = _prepare_indices_arrays (inds , exec_q , vals_usm_type )
1048
+
1012
1049
ind0 = inds [0 ]
1013
1050
ary_sh = ary .shape
1014
1051
p_end = p + len (inds )
0 commit comments