Skip to content

Commit d66919e

Browse files
committed
ENH: add "repro_snippets" to test_array_object.py
1 parent 445c172 commit d66919e

File tree

1 file changed

+99
-77
lines changed

1 file changed

+99
-77
lines changed

array_api_tests/test_array_object.py

Lines changed: 99 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ def test_getitem(shape, dtype, data):
8686
key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key")
8787

8888
repro_snippet = ph.format_snippet(f"{x!r}[{key!r}]")
89-
9089
try:
9190
out = x[key]
9291

@@ -109,6 +108,7 @@ def test_getitem(shape, dtype, data):
109108
exc.add_note(repro_snippet)
110109
raise
111110

111+
112112
@pytest.mark.unvectorized
113113
@given(
114114
shape=hh.shapes(),
@@ -133,28 +133,34 @@ def test_setitem(shape, dtypes, data):
133133
value = data.draw(value_strat, label="value")
134134

135135
res = xp.asarray(x, copy=True)
136-
res[key] = value
137-
138-
ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
139-
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape")
140-
f_res = sh.fmt_idx("x", key)
141-
if isinstance(value, get_args(Scalar)):
142-
msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
143-
if cmath.isnan(value):
144-
assert xp.isnan(res[key]), msg
136+
137+
repro_snippet = ph.format_snippet(f"{res!r}[{key!r}] = {value!r}")
138+
try:
139+
res[key] = value
140+
141+
ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
142+
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape")
143+
f_res = sh.fmt_idx("x", key)
144+
if isinstance(value, get_args(Scalar)):
145+
msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
146+
if cmath.isnan(value):
147+
assert xp.isnan(res[key]), msg
148+
else:
149+
assert res[key] == value, msg
145150
else:
146-
assert res[key] == value, msg
147-
else:
148-
ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res)
149-
unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices))
150-
for idx in unaffected_indices:
151-
ph.assert_0d_equals(
152-
"__setitem__",
153-
x_repr=f"old {f_res}",
154-
x_val=x[idx],
155-
out_repr=f"modified {f_res}",
156-
out_val=res[idx],
157-
)
151+
ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res)
152+
unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices))
153+
for idx in unaffected_indices:
154+
ph.assert_0d_equals(
155+
"__setitem__",
156+
x_repr=f"old {f_res}",
157+
x_val=x[idx],
158+
out_repr=f"modified {f_res}",
159+
out_val=res[idx],
160+
)
161+
except Exception as exc:
162+
exc.add_note(repro_snippet)
163+
raise
158164

159165

160166
@pytest.mark.unvectorized
@@ -178,29 +184,34 @@ def test_getitem_masking(shape, data):
178184
x[key]
179185
return
180186

181-
out = x[key]
187+
repro_snippet = ph.format_snippet(f"out = {x!r}[{key!r}]")
188+
try:
189+
out = x[key]
182190

183-
ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
184-
if key.ndim == 0:
185-
expected_shape = (1,) if key else (0,)
186-
expected_shape += x.shape
187-
else:
188-
size = int(xp.sum(xp.astype(key, xp.uint8)))
189-
expected_shape = (size,) + x.shape[key.ndim :]
190-
ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
191-
if not any(s == 0 for s in key.shape):
192-
assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios
193-
out_indices = sh.ndindex(out.shape)
194-
for x_idx in sh.ndindex(x.shape):
195-
if key[x_idx]:
196-
out_idx = next(out_indices)
197-
ph.assert_0d_equals(
198-
"__getitem__",
199-
x_repr=f"x[{x_idx}]",
200-
x_val=x[x_idx],
201-
out_repr=f"out[{out_idx}]",
202-
out_val=out[out_idx],
203-
)
191+
ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
192+
if key.ndim == 0:
193+
expected_shape = (1,) if key else (0,)
194+
expected_shape += x.shape
195+
else:
196+
size = int(xp.sum(xp.astype(key, xp.uint8)))
197+
expected_shape = (size,) + x.shape[key.ndim :]
198+
ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
199+
if not any(s == 0 for s in key.shape):
200+
assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios
201+
out_indices = sh.ndindex(out.shape)
202+
for x_idx in sh.ndindex(x.shape):
203+
if key[x_idx]:
204+
out_idx = next(out_indices)
205+
ph.assert_0d_equals(
206+
"__getitem__",
207+
x_repr=f"x[{x_idx}]",
208+
x_val=x[x_idx],
209+
out_repr=f"out[{out_idx}]",
210+
out_val=out[out_idx],
211+
)
212+
except Exception as exc:
213+
exc.add_note(repro_snippet)
214+
raise
204215

205216

206217
@pytest.mark.unvectorized
@@ -213,38 +224,44 @@ def test_setitem_masking(shape, data):
213224
)
214225

215226
res = xp.asarray(x, copy=True)
216-
res[key] = value
217-
218-
ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
219-
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.dtype")
220-
scalar_type = dh.get_scalar_type(x.dtype)
221-
for idx in sh.ndindex(x.shape):
222-
if key[idx]:
223-
if isinstance(value, get_args(Scalar)):
224-
ph.assert_scalar_equals(
225-
"__setitem__",
226-
type_=scalar_type,
227-
idx=idx,
228-
out=scalar_type(res[idx]),
229-
expected=value,
230-
repr_name="modified x",
231-
)
227+
228+
repro_snippet = ph.format_snippet(f"{res}[{key!r}] = {value!r}")
229+
try:
230+
res[key] = value
231+
232+
ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
233+
ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.dtype")
234+
scalar_type = dh.get_scalar_type(x.dtype)
235+
for idx in sh.ndindex(x.shape):
236+
if key[idx]:
237+
if isinstance(value, get_args(Scalar)):
238+
ph.assert_scalar_equals(
239+
"__setitem__",
240+
type_=scalar_type,
241+
idx=idx,
242+
out=scalar_type(res[idx]),
243+
expected=value,
244+
repr_name="modified x",
245+
)
246+
else:
247+
ph.assert_0d_equals(
248+
"__setitem__",
249+
x_repr="value",
250+
x_val=value,
251+
out_repr=f"modified x[{idx}]",
252+
out_val=res[idx]
253+
)
232254
else:
233255
ph.assert_0d_equals(
234256
"__setitem__",
235-
x_repr="value",
236-
x_val=value,
257+
x_repr=f"old x[{idx}]",
258+
x_val=x[idx],
237259
out_repr=f"modified x[{idx}]",
238260
out_val=res[idx]
239261
)
240-
else:
241-
ph.assert_0d_equals(
242-
"__setitem__",
243-
x_repr=f"old x[{idx}]",
244-
x_val=x[idx],
245-
out_repr=f"modified x[{idx}]",
246-
out_val=res[idx]
247-
)
262+
except Exception as exc:
263+
exc.add_note(repro_snippet)
264+
raise
248265

249266

250267
# ### Fancy indexing ###
@@ -309,15 +326,20 @@ def _test_getitem_arrays_and_ints(shape, data, idx_max_dims):
309326
key.append(data.draw(st.integers(-shape[i], shape[i]-1)))
310327

311328
key = tuple(key)
312-
out = x[key]
329+
repro_snippet = ph.format_snippet(f"out = {x!r}[{key!r}]")
330+
try:
331+
out = x[key]
313332

314-
arrays = [xp.asarray(k) for k in key]
315-
bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays])
316-
bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays]
333+
arrays = [xp.asarray(k) for k in key]
334+
bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays])
335+
bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays]
317336

318-
for idx in sh.ndindex(bcast_shape):
319-
tpl = tuple(k[idx] for k in bcast_key)
320-
assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }"
337+
for idx in sh.ndindex(bcast_shape):
338+
tpl = tuple(k[idx] for k in bcast_key)
339+
assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }"
340+
except Exception as exc:
341+
exc.add_note(repro_snippet)
342+
raise
321343

322344

323345
def make_scalar_casting_param(

0 commit comments

Comments
 (0)