Skip to content

Commit a65cbc4

Browse files
authored
Merge pull request data-apis#341 from ev-br/fix_meshgrid
BUG: torch/meshgrid: stop ignoring the "indexing" argument
2 parents f8b9dc4 + fa35e90 commit a65cbc4

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ def sign(x: Array, /) -> Array:
826826
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]:
827827
# enforce the default of 'xy'
828828
# TODO: is the return type a list or a tuple
829-
return list(torch.meshgrid(*arrays, indexing='xy'))
829+
return list(torch.meshgrid(*arrays, indexing=indexing))
830830

831831

832832
__all__ = ['asarray', 'result_type', 'can_cast',

tests/test_dask.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from contextlib import contextmanager
23

34
import numpy as np
@@ -167,6 +168,10 @@ def test_sort_argsort_chunk_size(xp, func, shape, chunks):
167168
)
168169

169170

171+
@pytest.mark.skipif(
172+
sys.version_info.major*100 + sys.version_info.minor < 312,
173+
reason="dask interop requires numpy >= 3.12"
174+
)
170175
@pytest.mark.parametrize("func", ["sort", "argsort"])
171176
def test_sort_argsort_meta(xp, func):
172177
"""Test meta-namespace other than numpy"""

tests/test_torch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,16 @@ def test_meshgrid():
117117

118118
assert Y.shape == Y_xy.shape
119119
assert xp.all(Y == Y_xy)
120+
121+
# repeat with an explicit indexing
122+
X, Y = xp.meshgrid(x, y, indexing='ij')
123+
124+
# output of torch.meshgrid(x, y, indexing='ij')
125+
X_ij, Y_ij = xp.asarray([[1], [2]]), xp.asarray([[4], [4]])
126+
127+
assert X.shape == X_ij.shape
128+
assert xp.all(X == X_ij)
129+
130+
assert Y.shape == Y_ij.shape
131+
assert xp.all(Y == Y_ij)
132+

0 commit comments

Comments
 (0)