Skip to content

Commit fdf9489

Browse files
committed
BUG: take_along_axis: numpy requires an axis
1 parent 621494b commit fdf9489

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

array_api_compat/numpy/_aliases.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,13 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
134134
return result
135135

136136

137+
# "axis=-1" is an optional argument of `take_along_axis` but numpy has no default
138+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
139+
if axis is None:
140+
axis = -1
141+
return np.take_along_axis(x, indices, axis=axis)
142+
143+
137144
# These functions are completely new here. If the library already has them
138145
# (i.e., numpy 2.0), use the library version instead of our wrapper.
139146
if hasattr(np, 'vecdot'):
@@ -155,6 +162,7 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
155162
'acos', 'acosh', 'asin', 'asinh', 'atan',
156163
'atan2', 'atanh', 'bitwise_left_shift',
157164
'bitwise_invert', 'bitwise_right_shift',
158-
'bool', 'concat', 'count_nonzero', 'pow']
165+
'bool', 'concat', 'count_nonzero', 'pow',
166+
'take_along_axis']
159167

160168
_all_ignore = ['np', 'get_xp']

0 commit comments

Comments
 (0)