@@ -134,6 +134,13 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
134
134
return result
135
135
136
136
137
+ # "axis=-1" is an optional argument of `take_along_axis` but numpy has no default
138
+ def take_along_axis (x : Array , indices : Array , / , * , axis : int = - 1 ):
139
+ if axis is None :
140
+ axis = - 1
141
+ return np .take_along_axis (x , indices , axis = axis )
142
+
143
+
137
144
# These functions are completely new here. If the library already has them
138
145
# (i.e., numpy 2.0), use the library version instead of our wrapper.
139
146
if hasattr (np , 'vecdot' ):
@@ -155,6 +162,7 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
155
162
'acos' , 'acosh' , 'asin' , 'asinh' , 'atan' ,
156
163
'atan2' , 'atanh' , 'bitwise_left_shift' ,
157
164
'bitwise_invert' , 'bitwise_right_shift' ,
158
- 'bool' , 'concat' , 'count_nonzero' , 'pow' ]
165
+ 'bool' , 'concat' , 'count_nonzero' , 'pow' ,
166
+ 'take_along_axis' ]
159
167
160
168
_all_ignore = ['np' , 'get_xp' ]
0 commit comments