Skip to content

Commit 24778de

Browse files
committed
ENH: allow 1D integer array indices
1 parent 1a4fecb commit 24778de

File tree

2 files changed

+64
-7
lines changed

2 files changed

+64
-7
lines changed

array_api_strict/_array_object.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,8 @@ def _validate_index(self, key):
389389
single_axes = []
390390
n_ellipsis = 0
391391
key_has_mask = False
392+
key_has_index_array = False
393+
key_has_slices = False
392394
for i in _key:
393395
if i is not None:
394396
nonexpanding_key.append(i)
@@ -397,13 +399,17 @@ def _validate_index(self, key):
397399
if isinstance(i, Array):
398400
if i.dtype in _boolean_dtypes:
399401
key_has_mask = True
402+
elif i.dtype in _integer_dtypes:
403+
key_has_index_array = True
400404
single_axes.append(i)
401405
else:
402406
# i must not be an array here, to avoid elementwise equals
403407
if i == Ellipsis:
404408
n_ellipsis += 1
405409
else:
406410
single_axes.append(i)
411+
if isinstance(i, slice):
412+
key_has_slices = True
407413

408414
n_single_axes = len(single_axes)
409415
if n_ellipsis > 1:
@@ -421,6 +427,12 @@ def _validate_index(self, key):
421427
"specified in the Array API."
422428
)
423429

430+
if (key_has_index_array and (n_ellipsis > 0 or key_has_slices or key_has_mask)):
431+
raise IndexError(
432+
"Integer index arrays are only allowed with integer indices; "
433+
f"got {key}."
434+
)
435+
424436
if n_ellipsis == 0:
425437
indexed_shape = self.shape
426438
else:
@@ -479,11 +491,11 @@ def _validate_index(self, key):
479491
if not get_array_api_strict_flags()['boolean_indexing']:
480492
raise RuntimeError("The boolean_indexing flag has been disabled for array-api-strict")
481493

482-
elif i.dtype in _integer_dtypes and i.ndim != 0:
494+
elif i.dtype in _integer_dtypes and i.ndim > 1:
483495
raise IndexError(
484-
f"Single-axes index {i} is a non-zero-dimensional "
485-
"integer array, but advanced integer indexing is not "
486-
"specified in the Array API."
496+
f"Single-axes index {i} is a multi-dimensional "
497+
"integer array, but advanced integer indexing is only "
498+
"specified in the Array API for 1D index arrays."
487499
)
488500
elif isinstance(i, tuple):
489501
raise IndexError(

array_api_strict/tests/test_array_object.py

+48-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import pytest
77

8-
from .. import ones, asarray, result_type, all, equal
8+
from .. import ones, arange, reshape, asarray, result_type, all, equal
99
from .._array_object import Array, CPU_DEVICE, Device
1010
from .._dtypes import (
1111
_all_dtypes,
@@ -70,11 +70,25 @@ def test_validate_index():
7070
assert_raises(IndexError, lambda: a[[True, True, True]])
7171
assert_raises(IndexError, lambda: a[(True, True, True),])
7272

73-
# Integer array indices are not allowed (except for 0-D)
74-
idx = asarray([0, 1])
73+
# Integer array indices are not allowed (except for 0-D or 1D)
74+
idx = asarray([[0, 1]]) # idx.ndim == 2
7575
assert_raises(IndexError, lambda: a[idx, 0])
7676
assert_raises(IndexError, lambda: a[0, idx])
7777

78+
# Mixing 1D integer array indices with slices, ellipsis or booleans is not allowed
79+
idx = asarray([0, 1])
80+
assert_raises(IndexError, lambda: a[..., idx])
81+
assert_raises(IndexError, lambda: a[:, idx])
82+
assert_raises(IndexError, lambda: a[asarray([True, True]), idx])
83+
84+
# 1D integer array indices must have the same length
85+
idx1 = asarray([0, 1])
86+
idx2 = asarray([0, 1, 1])
87+
assert_raises(IndexError, lambda: a[idx1, idx2])
88+
89+
# Non-integer array indices are not allowed
90+
assert_raises(IndexError, lambda: a[ones(2), 0])
91+
7892
# Array-likes (lists, tuples) are not allowed as indices
7993
assert_raises(IndexError, lambda: a[[0, 1]])
8094
assert_raises(IndexError, lambda: a[(0, 1), (0, 1)])
@@ -91,6 +105,37 @@ def test_validate_index():
91105
assert_raises(IndexError, lambda: a[:])
92106
assert_raises(IndexError, lambda: a[idx])
93107

108+
109+
def test_indexing_arrays():
110+
# indexing with 1D integer arrays and mixes of integers and 1D integer are allowed
111+
112+
# 1D array
113+
a = arange(5)
114+
idx = asarray([1, 0, 1, 2, -1])
115+
a_idx = a[idx]
116+
117+
a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])])
118+
assert all(a_idx == a_idx_loop)
119+
120+
# setitem with arrays is not allowed # XXX
121+
# with assert_raises(IndexError):
122+
# a[idx] = 42
123+
124+
# mixed array and integer indexing
125+
a = reshape(arange(3*4), (3, 4))
126+
idx = asarray([1, 0, 1, 2, -1])
127+
a_idx = a[idx, 1]
128+
129+
a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])])
130+
assert all(a_idx == a_idx_loop)
131+
132+
133+
# index with two arrays
134+
a_idx = a[idx, idx]
135+
a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])])
136+
assert all(a_idx == a_idx_loop)
137+
138+
94139
def test_promoted_scalar_inherits_device():
95140
device1 = Device("device1")
96141
x = asarray([1., 2, 3], device=device1)

0 commit comments

Comments
 (0)