Skip to content

Commit ff41afe

Browse files
committed
BUG: do not allow asarray of nested sequences of arrays
1 parent 25cc3d7 commit ff41afe

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

array_api_strict/_creation_functions.py

+4
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ def asarray(
121121

122122
if isinstance(obj, Array):
123123
return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype), device=device)
124+
elif isinstance(obj, list | tuple):
125+
if any(isinstance(x, Array) for x in obj):
126+
raise TypeError("Nested Arrays are not allowed. Use `stack` instead.")
127+
124128
if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
125129
# Give a better error message in this case. NumPy would convert this
126130
# to an object array. TODO: This won't handle large integers in lists.

array_api_strict/tests/test_creation_functions.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -97,18 +97,20 @@ def test_asarray_copy():
9797
a[0] = 0
9898
assert all(b[0] == 0)
9999

100+
100101
def test_asarray_list_of_lists():
101-
a = asarray(1, dtype=int16)
102-
b = asarray([1], dtype=int16)
103-
res = asarray([a, a])
104-
assert res.shape == (2,)
105-
assert res.dtype == int16
106-
assert all(res == asarray([1, 1]))
107-
108-
res = asarray([b, b])
109-
assert res.shape == (2, 1)
110-
assert res.dtype == int16
111-
assert all(res == asarray([[1], [1]]))
102+
lst = [[1, 2, 3], [4, 5, 6]]
103+
res = asarray(lst)
104+
assert res.shape == (2, 3)
105+
106+
107+
def test_asarray_nested_arrays():
108+
# do not allow arrays in nested sequences
109+
with pytest.raises(TypeError):
110+
asarray([[1, 2, 3], asarray([4, 5, 6])])
111+
112+
with pytest.raises(TypeError):
113+
asarray([1, asarray(1)])
112114

113115

114116
def test_asarray_device_inference():

0 commit comments

Comments
 (0)