Skip to content

Commit 30ec9bb

Browse files
committed
TST: add a test for torch.argsort defaulting to stable=True
cross-ref data-apis#356 which wrapped torch.argsort to fix the default, and data-apis/array-api-tests#390 which made a matching change in the array-api-test suite.
1 parent f8b9dc4 commit 30ec9bb

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tests/test_torch.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,14 @@ def test_meshgrid():
117117

118118
assert Y.shape == Y_xy.shape
119119
assert xp.all(Y == Y_xy)
120+
121+
122+
def test_argsort_stable():
123+
"""Verify that argsort defaults to a stable sort."""
124+
# Bare pytorch defaults to an unstable sort, and the array_api_compat wrapper
125+
# enforces the stable=True default.
126+
# cf https://github.com/data-apis/array-api-compat/pull/356 and
127+
# https://github.com/data-apis/array-api-tests/pull/390#issuecomment-3452868329
128+
129+
t = xp.zeros(50) # should be >16
130+
assert xp.all(xp.argsort(t) == xp.arange(50))

0 commit comments

Comments
 (0)