Skip to content

Do not rely on np.broadcast_to to perform trivial dimension insertion #10277

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ Internal Changes
~~~~~~~~~~~~~~~~
- Avoid stacking when grouping by a chunked array. This can be a large performance improvement.
By `Deepak Cherian <https://github.com/dcherian>`_.
- The implementation of ``Variable.set_dims`` has changed to use array indexing syntax
instead of ``np.broadcast_to`` to perform dimension expansions where
all new dimensions have a size of 1. This should improve compatibility with
duck arrays that do not support broadcasting (:issue:`9462`, :pull:`10277`).
By `Mark Harfouche <https://github.com/hmaarrfk>`_.

.. _whats-new.2025.03.1:

Expand Down
15 changes: 10 additions & 5 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,7 +1355,7 @@ def set_dims(self, dim, shape=None):
dim = [dim]

if shape is None and is_dict_like(dim):
shape = dim.values()
shape = tuple(dim.values())

missing_dims = set(self.dims) - set(dim)
if missing_dims:
Expand All @@ -1371,13 +1371,18 @@ def set_dims(self, dim, shape=None):
# don't use broadcast_to unless necessary so the result remains
# writeable if possible
expanded_data = self.data
elif shape is not None:
elif shape is None or all(
s == 1 for s, e in zip(shape, dim, strict=True) if e not in self_dims
):
# "Trivial" broadcasting, i.e. simply inserting a new dimension
# This is typically easier for duck arrays to implement
# than the full "broadcast_to" semantics
indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,)
expanded_data = self.data[indexer]
else: # elif shape is not None:
dims_map = dict(zip(dim, shape, strict=True))
tmp_shape = tuple(dims_map[d] for d in expanded_dims)
expanded_data = duck_array_ops.broadcast_to(self._data, tmp_shape)
else:
indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,)
expanded_data = self.data[indexer]

expanded_var = Variable(
expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True
Expand Down
68 changes: 68 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1653,6 +1653,74 @@ def test_set_dims_object_dtype(self):
expected = Variable(["x"], exp_values)
assert_identical(actual, expected)

def test_set_dims_without_broadcast(self):
    """Check that ``set_dims`` can insert size-1 dimensions via indexing alone.

    The wrapped array below supports ``__getitem__`` but raises
    ``NotImplementedError`` for every ``__array_function__`` dispatch, so
    any attempt to call ``np.broadcast_to`` on it fails.  Expanding with
    new dimensions of size 1 must therefore succeed, while expanding to a
    new dimension of size 2 must surface the ``NotImplementedError``.

    See https://github.com/pydata/xarray/issues/9462
    """

    class ArrayWithoutBroadcastTo(NDArrayMixin, indexing.ExplicitlyIndexed):
        """Duck array that can be indexed but rejects NumPy function dispatch."""

        def __init__(self, array):
            self.array = array

        def __getitem__(self, key):
            # Inserting axes through indexing is "easier" for a duck array
            # to implement than broadcasting, especially for dims of 1.
            return self.array[key]

        def __array_function__(self, *args, **kwargs):
            raise NotImplementedError(
                "Not we don't want to use broadcast_to here "
                "https://github.com/pydata/xarray/issues/9462"
            )

    wrapped = ArrayWithoutBroadcastTo(np.zeros((3, 4)))

    # Sanity checks on the helper itself: a new axis can be added through
    # indexing, but broadcasting is refused.
    assert wrapped[np.newaxis, :, :].shape == (1, 3, 4)
    with pytest.raises(NotImplementedError):
        np.broadcast_to(wrapped, (1, 3, 4))

    var = Variable(["x", "y"], wrapped)

    # Each entry: (target dims, expected shape) with "z" inserted at a
    # different position and always of size 1.
    size_one_cases = [
        (("z", "x", "y"), (1, 3, 4)),
        (("x", "z", "y"), (3, 1, 4)),
        (("x", "y", "z"), (3, 4, 1)),
    ]

    # Dimension names only, no explicit shape.
    for target_dims, target_shape in size_one_cases:
        result = var.set_dims(list(target_dims))
        assert result.dims == target_dims
        assert result.shape == target_shape

    # Explicitly asking for a shape of 1 triggers a different
    # codepath in set_dims
    # https://github.com/pydata/xarray/issues/9462
    for target_dims, target_shape in size_one_cases:
        result = var.set_dims(list(target_dims), shape=target_shape)
        assert result.dims == target_dims
        assert result.shape == target_shape

    # Same requests expressed as a {dim: size} mapping.
    for target_dims, target_shape in size_one_cases:
        result = var.set_dims(dict(zip(target_dims, target_shape)))
        assert result.dims == target_dims
        assert result.shape == target_shape

    # A new dimension larger than 1 needs real broadcasting, which this
    # array refuses.
    with pytest.raises(NotImplementedError):
        var.set_dims({"z": 2, "x": 3, "y": 4})

    with pytest.raises(NotImplementedError):
        var.set_dims(["z", "x", "y"], shape=(2, 3, 4))

def test_stack(self):
v = Variable(["x", "y"], [[0, 1], [2, 3]], {"foo": "bar"})
actual = v.stack(z=("x", "y"))
Expand Down
Loading