Skip to content

Commit 48839ac

Browse files
committed
ENH: fancy indexing __setitem__ is not allowed
1 parent 24778de commit 48839ac

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

array_api_strict/_array_object.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:
321321

322322
# Note: A large fraction of allowed indices are disallowed here (see the
323323
# docstring below)
324-
def _validate_index(self, key):
324+
def _validate_index(self, key, op="getitem"):
325325
"""
326326
Validate an index according to the array API.
327327
@@ -384,6 +384,9 @@ def _validate_index(self, key):
384384
"zero-dimensional integer arrays and boolean arrays "
385385
"are specified in the Array API."
386386
)
387+
if op == "setitem":
388+
if isinstance(i, Array) and i.dtype in _integer_dtypes:
389+
raise IndexError("Fancy indexing __setitem__ is not supported.")
387390

388391
nonexpanding_key = []
389392
single_axes = []
@@ -908,7 +911,7 @@ def __setitem__(
908911
"""
909912
# Note: Only indices required by the spec are allowed. See the
910913
# docstring of _validate_index
911-
self._validate_index(key)
914+
self._validate_index(key, op="setitem")
912915
if isinstance(key, Array):
913916
# Indexing self._array with array_api_strict arrays can be erroneous
914917
key = key._array

array_api_strict/tests/test_array_object.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,9 @@ def test_indexing_arrays():
117117
a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])])
118118
assert all(a_idx == a_idx_loop)
119119

120-
# setitem with arrays is not allowed # XXX
121-
# with assert_raises(IndexError):
122-
# a[idx] = 42
120+
# setitem with arrays is not allowed
121+
with assert_raises(IndexError):
122+
a[idx] = 42
123123

124124
# mixed array and integer indexing
125125
a = reshape(arange(3*4), (3, 4))
@@ -129,12 +129,15 @@ def test_indexing_arrays():
129129
a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])])
130130
assert all(a_idx == a_idx_loop)
131131

132-
133132
# index with two arrays
134133
a_idx = a[idx, idx]
135134
a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])])
136135
assert all(a_idx == a_idx_loop)
137136

137+
# setitem with arrays is not allowed
138+
with assert_raises(IndexError):
139+
a[idx, idx] = 42
140+
138141

139142
def test_promoted_scalar_inherits_device():
140143
device1 = Device("device1")

0 commit comments

Comments
 (0)