Skip to content

Commit 84390f5

Browse files
authored
Merge pull request #224 from ev-br/2024.12
Add draft support for 2024.12 revision
2 parents d12e561 + 0bcb032 commit 84390f5

File tree

9 files changed

+82
-9
lines changed

9 files changed

+82
-9
lines changed

array_api_compat/common/_aliases.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,36 @@ def cumulative_sum(
292292
)
293293
return res
294294

295+
296+
def cumulative_prod(
297+
x: ndarray,
298+
/,
299+
xp,
300+
*,
301+
axis: Optional[int] = None,
302+
dtype: Optional[Dtype] = None,
303+
include_initial: bool = False,
304+
**kwargs
305+
) -> ndarray:
306+
wrapped_xp = array_namespace(x)
307+
308+
if axis is None:
309+
if x.ndim > 1:
310+
raise ValueError("axis must be specified in cumulative_prod for more than one dimension")
311+
axis = 0
312+
313+
res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs)
314+
315+
# np.cumprod does not support include_initial
316+
if include_initial:
317+
initial_shape = list(x.shape)
318+
initial_shape[axis] = 1
319+
res = xp.concatenate(
320+
[wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
321+
axis=axis,
322+
)
323+
return res
324+
295325
# The min and max argument names in clip are different and not optional in numpy, and type
296326
# promotion behavior is different.
297327
def clip(
@@ -544,7 +574,7 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
544574
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
545575
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
546576
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
547-
'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
577+
'std', 'var', 'cumulative_sum', 'cumulative_prod','clip', 'permute_dims',
548578
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
549579
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
550580
'unstack', 'sign']

array_api_compat/cupy/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
std = get_xp(cp)(_aliases.std)
5050
var = get_xp(cp)(_aliases.var)
5151
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
52+
cumulative_prod = get_xp(cp)(_aliases.cumulative_prod)
5253
clip = get_xp(cp)(_aliases.clip)
5354
permute_dims = get_xp(cp)(_aliases.permute_dims)
5455
reshape = get_xp(cp)(_aliases.reshape)

array_api_compat/cupy/_info.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def capabilities(self):
101101
"boolean indexing": True,
102102
"data-dependent shapes": True,
103103
# 'max rank' will be part of the 2024.12 standard
104-
# "max rank": 64,
104+
"max dimensions": 64,
105105
}
106106

107107
def default_device(self):

array_api_compat/dask/array/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def arange(
121121
std = get_xp(da)(_aliases.std)
122122
var = get_xp(da)(_aliases.var)
123123
cumulative_sum = get_xp(da)(_aliases.cumulative_sum)
124+
cumulative_prod = get_xp(da)(_aliases.cumulative_prod)
124125
empty = get_xp(da)(_aliases.empty)
125126
empty_like = get_xp(da)(_aliases.empty_like)
126127
full = get_xp(da)(_aliases.full)

array_api_compat/dask/array/_info.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def capabilities(self):
102102
"boolean indexing": False,
103103
"data-dependent shapes": False,
104104
# 'max rank' will be part of the 2024.12 standard
105-
# "max rank": 64,
105+
"max dimensions": 64,
106106
}
107107

108108
def default_device(self):

array_api_compat/numpy/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
std = get_xp(np)(_aliases.std)
5050
var = get_xp(np)(_aliases.var)
5151
cumulative_sum = get_xp(np)(_aliases.cumulative_sum)
52+
cumulative_prod = get_xp(np)(_aliases.cumulative_prod)
5253
clip = get_xp(np)(_aliases.clip)
5354
permute_dims = get_xp(np)(_aliases.permute_dims)
5455
reshape = get_xp(np)(_aliases.reshape)

array_api_compat/numpy/_info.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def capabilities(self):
101101
"boolean indexing": True,
102102
"data-dependent shapes": True,
103103
# 'max rank' will be part of the 2024.12 standard
104-
# "max rank": 64,
104+
"max dimensions": 64,
105105
}
106106

107107
def default_device(self):

array_api_compat/torch/_aliases.py

+44-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
clip as _aliases_clip,
99
unstack as _aliases_unstack,
1010
cumulative_sum as _aliases_cumulative_sum,
11+
cumulative_prod as _aliases_cumulative_prod,
1112
)
1213
from .._internal import get_xp
1314

@@ -124,7 +125,11 @@ def _fix_promotion(x1, x2, only_scalar=True):
124125
x1 = x1.to(dtype)
125126
return x1, x2
126127

127-
def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
128+
129+
_py_scalars = (bool, int, float, complex)
130+
131+
132+
def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype:
128133
if len(arrays_and_dtypes) == 0:
129134
raise TypeError("At least one array or dtype must be provided")
130135
if len(arrays_and_dtypes) == 1:
@@ -136,6 +141,9 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
136141
return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
137142

138143
x, y = arrays_and_dtypes
144+
if isinstance(x, _py_scalars) or isinstance(y, _py_scalars):
145+
return torch.result_type(x, y)
146+
139147
xdt = x.dtype if not isinstance(x, torch.dtype) else x
140148
ydt = y.dtype if not isinstance(y, torch.dtype) else y
141149

@@ -210,6 +218,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
210218
clip = get_xp(torch)(_aliases_clip)
211219
unstack = get_xp(torch)(_aliases_unstack)
212220
cumulative_sum = get_xp(torch)(_aliases_cumulative_sum)
221+
cumulative_prod = get_xp(torch)(_aliases_cumulative_prod)
213222

214223
# torch.sort also returns a tuple
215224
# https://github.com/pytorch/pytorch/issues/70921
@@ -504,6 +513,31 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
504513
raise ValueError("nonzero() does not support zero-dimensional arrays")
505514
return torch.nonzero(x, as_tuple=True, **kwargs)
506515

516+
517+
# torch uses `dim` instead of `axis`
518+
def diff(
519+
x: array,
520+
/,
521+
*,
522+
axis: int = -1,
523+
n: int = 1,
524+
prepend: Optional[array] = None,
525+
append: Optional[array] = None,
526+
) -> array:
527+
return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
528+
529+
530+
# torch uses `dim` instead of `axis`
531+
def count_nonzero(
532+
x: array,
533+
/,
534+
*,
535+
axis: Optional[Union[int, Tuple[int, ...]]] = None,
536+
keepdims: bool = False,
537+
) -> array:
538+
return torch.count_nonzero(x, dim=axis, keepdims=keepdims)
539+
540+
507541
def where(condition: array, x1: array, x2: array, /) -> array:
508542
x1, x2 = _fix_promotion(x1, x2)
509543
return torch.where(condition, x1, x2)
@@ -734,6 +768,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
734768
axis = 0
735769
return torch.index_select(x, axis, indices, **kwargs)
736770

771+
772+
def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array:
773+
return torch.take_along_dim(x, indices, dim=axis)
774+
775+
737776
def sign(x: array, /) -> array:
738777
# torch sign() does not support complex numbers and does not propagate
739778
# nans. See https://github.com/data-apis/array-api-compat/issues/136
@@ -752,18 +791,19 @@ def sign(x: array, /) -> array:
752791
__all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
753792
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
754793
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
755-
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide',
794+
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
795+
'diff', 'divide',
756796
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
757797
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
758798
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',
759-
'min', 'clip', 'unstack', 'cumulative_sum', 'sort', 'prod', 'sum',
799+
'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum',
760800
'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
761801
'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
762802
'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
763803
'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',
764804
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
765805
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
766806
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
767-
'take', 'sign']
807+
'take', 'take_along_axis', 'sign']
768808

769809
_all_ignore = ['torch', 'get_xp']

array_api_compat/torch/_info.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def capabilities(self):
8686
"boolean indexing": True,
8787
"data-dependent shapes": True,
8888
# 'max rank' will be part of the 2024.12 standard
89-
# "max rank": 64,
89+
"max dimensions": 64,
9090
}
9191

9292
def default_device(self):

0 commit comments

Comments
 (0)