@@ -8,6 +8,7 @@
     clip as _aliases_clip,
     unstack as _aliases_unstack,
     cumulative_sum as _aliases_cumulative_sum,
+    cumulative_prod as _aliases_cumulative_prod,
 )
 from .._internal import get_xp
 
@@ -124,7 +125,11 @@ def _fix_promotion(x1, x2, only_scalar=True):
         x1 = x1.to(dtype)
     return x1, x2
 
-def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
+
+_py_scalars = (bool, int, float, complex)
+
+
+def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype:
     if len(arrays_and_dtypes) == 0:
         raise TypeError("At least one array or dtype must be provided")
     if len(arrays_and_dtypes) == 1:
@@ -136,6 +141,9 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
         return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
 
     x, y = arrays_and_dtypes
+    if isinstance(x, _py_scalars) or isinstance(y, _py_scalars):
+        return torch.result_type(x, y)
+
     xdt = x.dtype if not isinstance(x, torch.dtype) else x
     ydt = y.dtype if not isinstance(y, torch.dtype) else y
 
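With `_py_scalars` handled up front, `result_type` can mix Python scalars with tensors by deferring to `torch.result_type`'s scalar promotion rules. A minimal sketch of the new behaviour, assuming this file is array-api-compat's torch wrapper so that `result_type` is importable from `array_api_compat.torch` (note that a scalar paired with a bare dtype would still reach `torch.result_type`, which accepts only tensors and numbers):

```python
import torch
from array_api_compat.torch import result_type

t = torch.ones(3, dtype=torch.float32)
# Python scalars are "weak": they adopt the tensor's dtype family
result_type(t, 1.0)   # torch.float32
result_type(t, True)  # torch.float32
result_type(torch.ones(3, dtype=torch.int8), 2)  # torch.int8
```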
@@ -210,6 +218,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
 clip = get_xp(torch)(_aliases_clip)
 unstack = get_xp(torch)(_aliases_unstack)
 cumulative_sum = get_xp(torch)(_aliases_cumulative_sum)
+cumulative_prod = get_xp(torch)(_aliases_cumulative_prod)
 
 # torch.sort also returns a tuple
 # https://github.com/pytorch/pytorch/issues/70921
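`cumulative_prod` is wired up the same way as `cumulative_sum`: the shared alias from the common `_aliases` module is bound to torch via `get_xp`. A hedged usage sketch, assuming the wrapped namespace is obtained through `array_api_compat.array_namespace` and that the common alias dispatches to `torch.cumprod`:

```python
import torch
from array_api_compat import array_namespace

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
xp = array_namespace(x)
xp.cumulative_prod(x)  # tensor([ 1.,  2.,  6., 24.])
```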
@@ -504,6 +513,39 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
         raise ValueError("nonzero() does not support zero-dimensional arrays")
     return torch.nonzero(x, as_tuple=True, **kwargs)
 
+
+# torch uses `dim` instead of `axis`
+def diff(
+    x: array,
+    /,
+    *,
+    axis: int = -1,
+    n: int = 1,
+    prepend: Optional[array] = None,
+    append: Optional[array] = None,
+) -> array:
+    return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
+
+
+# torch uses `dim` instead of `axis`; torch.count_nonzero itself has no
+# keepdims argument, so emulate it by reshaping the result
+def count_nonzero(
+    x: array,
+    /,
+    *,
+    axis: Optional[Union[int, Tuple[int, ...]]] = None,
+    keepdims: bool = False,
+) -> array:
+    result = torch.count_nonzero(x, dim=axis)
+    if keepdims:
+        if axis is None:
+            return result.reshape([1] * x.ndim)
+        axes = (axis,) if isinstance(axis, int) else axis
+        axes = [a % x.ndim for a in axes]
+        return result.reshape([1 if i in axes else x.shape[i] for i in range(x.ndim)])
+    return result
+
+
 def where(condition: array, x1: array, x2: array, /) -> array:
     x1, x2 = _fix_promotion(x1, x2)
     return torch.where(condition, x1, x2)
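These wrappers mostly translate the standard's `axis` keyword into torch's `dim` (plus, for `count_nonzero`, a `keepdims` emulation). A quick sketch of the underlying torch calls:

```python
import torch

x = torch.tensor([[1, 0, 3], [0, 5, 0]])

# first differences along the last axis, like np.diff
torch.diff(x, dim=-1)          # tensor([[-1,  3], [ 5, -5]])

# nonzero count per column; keepdims=True keeps a size-1 row axis
torch.count_nonzero(x, dim=0)  # tensor([1, 1, 1])
```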
@@ -734,6 +776,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
         axis = 0
     return torch.index_select(x, axis, indices, **kwargs)
 
+
+def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array:
+    return torch.take_along_dim(x, indices, dim=axis)
+
+
 def sign(x: array, /) -> array:
     # torch sign() does not support complex numbers and does not propagate
     # nans. See https://github.com/data-apis/array-api-compat/issues/136
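`torch.take_along_dim` is already a one-to-one match for the standard's `take_along_axis`, so the wrapper only renames the keyword. For instance, it can reorder values using indices from an argsort:

```python
import torch

x = torch.tensor([[30, 10, 20]])
idx = torch.argsort(x, dim=-1)        # tensor([[1, 2, 0]])
torch.take_along_dim(x, idx, dim=-1)  # tensor([[10, 20, 30]])
```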
@@ -752,18 +799,19 @@ def sign(x: array, /) -> array:
 __all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
            'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
            'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
-           'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide',
+           'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
+           'diff', 'divide',
            'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
            'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
            'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',
-           'min', 'clip', 'unstack', 'cumulative_sum', 'sort', 'prod', 'sum',
+           'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum',
            'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
            'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
            'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
            'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',
            'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
            'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
            'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
-           'take', 'sign']
+           'take', 'take_along_axis', 'sign']
 
 _all_ignore = ['torch', 'get_xp']