Skip to content

Commit 6fe1bc4

Browse files
committed
Add a test to showcase the differences in the set_dims behavior
xref: #9462
1 parent 0759405 commit 6fe1bc4

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

xarray/tests/test_variable.py

+60
Original file line numberDiff line numberDiff line change
@@ -1653,6 +1653,66 @@ def test_set_dims_object_dtype(self):
16531653
expected = Variable(["x"], exp_values)
16541654
assert_identical(actual, expected)
16551655

1656+
def test_set_dims_without_broadcast(self):
1657+
class ArrayWithoutBroadcastTo(NDArrayMixin, indexing.ExplicitlyIndexed):
1658+
def __init__(self, array):
1659+
self.array = array
1660+
1661+
# Broadcasting with __getitem__ is "easier" to implement
1662+
# especially for dims of 1
1663+
def __getitem__(self, key):
1664+
return self.array[key]
1665+
1666+
def __array_function__(self, *args, **kwargs):
1667+
raise NotImplementedError(
1668+
"Not we don't want to use broadcast_to here "
1669+
"https://github.com/pydata/xarray/issues/9462"
1670+
)
1671+
1672+
arr = ArrayWithoutBroadcastTo(np.zeros((3, 4)))
1673+
# We should be able to add a new axis without broadcasting
1674+
assert arr[np.newaxis, :, :].shape == (1, 3, 4)
1675+
with pytest.raises(NotImplementedError):
1676+
np.broadcast_to(arr, (1, 3, 4))
1677+
1678+
v = Variable(["x", "y"], arr)
1679+
v_expanded = v.set_dims(["z", "x", "y"])
1680+
assert v_expanded.dims == ("z", "x", "y")
1681+
assert v_expanded.shape == (1, 3, 4)
1682+
1683+
# Explicitly asking for a shape of 1 triggers a different
1684+
# codepath in set_dims
1685+
# https://github.com/pydata/xarray/issues/9462
1686+
v_expanded = v.set_dims(["z", "x", "y"], shape=(1, 3, 4))
1687+
assert v_expanded.dims == ("z", "x", "y")
1688+
assert v_expanded.shape == (1, 3, 4)
1689+
1690+
v_expanded = v.set_dims(["x", "z", "y"], shape=(3, 1, 4))
1691+
assert v_expanded.dims == ("x", "z", "y")
1692+
assert v_expanded.shape == (3, 1, 4)
1693+
1694+
v_expanded = v.set_dims(["x", "y", "z"], shape=(3, 4, 1))
1695+
assert v_expanded.dims == ("x", "y", "z")
1696+
assert v_expanded.shape == (3, 4, 1)
1697+
1698+
v_expanded = v.set_dims({"z": 1, "x": 3, "y": 4})
1699+
assert v_expanded.dims == ("z", "x", "y")
1700+
assert v_expanded.shape == (1, 3, 4)
1701+
1702+
v_expanded = v.set_dims({"x": 3, "z": 1, "y": 4})
1703+
assert v_expanded.dims == ("x", "z", "y")
1704+
assert v_expanded.shape == (3, 1, 4)
1705+
1706+
v_expanded = v.set_dims({"x": 3, "y": 4, "z": 1})
1707+
assert v_expanded.dims == ("x", "y", "z")
1708+
assert v_expanded.shape == (3, 4, 1)
1709+
1710+
with pytest.raises(NotImplementedError):
1711+
v.set_dims({"z": 2, "x": 3, "y": 4})
1712+
1713+
with pytest.raises(NotImplementedError):
1714+
v.set_dims(["z", "x", "y"], shape=(2, 3, 4))
1715+
16561716
def test_stack(self):
16571717
v = Variable(["x", "y"], [[0, 1], [2, 3]], {"foo": "bar"})
16581718
actual = v.stack(z=("x", "y"))

0 commit comments

Comments
 (0)