From 6fe1bc46f59d51d886ef93bf0ab72321fcfe42cc Mon Sep 17 00:00:00 2001 From: Mark Harfouche Date: Wed, 30 Apr 2025 12:07:53 -0400 Subject: [PATCH 1/4] Add a test to showcase the differences in the set_dims behavior xref: https://github.com/pydata/xarray/issues/9462 --- xarray/tests/test_variable.py | 60 +++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 619dc1561ef..d1afa10c5fa 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1653,6 +1653,66 @@ def test_set_dims_object_dtype(self): expected = Variable(["x"], exp_values) assert_identical(actual, expected) + def test_set_dims_without_broadcast(self): + class ArrayWithoutBroadcastTo(NDArrayMixin, indexing.ExplicitlyIndexed): + def __init__(self, array): + self.array = array + + # Broadcasting with __getitem__ is "easier" to implement + # especially for dims of 1 + def __getitem__(self, key): + return self.array[key] + + def __array_function__(self, *args, **kwargs): + raise NotImplementedError( + "Not we don't want to use broadcast_to here " + "https://github.com/pydata/xarray/issues/9462" + ) + + arr = ArrayWithoutBroadcastTo(np.zeros((3, 4))) + # We should be able to add a new axis without broadcasting + assert arr[np.newaxis, :, :].shape == (1, 3, 4) + with pytest.raises(NotImplementedError): + np.broadcast_to(arr, (1, 3, 4)) + + v = Variable(["x", "y"], arr) + v_expanded = v.set_dims(["z", "x", "y"]) + assert v_expanded.dims == ("z", "x", "y") + assert v_expanded.shape == (1, 3, 4) + + # Explicitly asking for a shape of 1 triggers a different + # codepath in set_dims + # https://github.com/pydata/xarray/issues/9462 + v_expanded = v.set_dims(["z", "x", "y"], shape=(1, 3, 4)) + assert v_expanded.dims == ("z", "x", "y") + assert v_expanded.shape == (1, 3, 4) + + v_expanded = v.set_dims(["x", "z", "y"], shape=(3, 1, 4)) + assert v_expanded.dims == ("x", "z", "y") + assert 
v_expanded.shape == (3, 1, 4) + + v_expanded = v.set_dims(["x", "y", "z"], shape=(3, 4, 1)) + assert v_expanded.dims == ("x", "y", "z") + assert v_expanded.shape == (3, 4, 1) + + v_expanded = v.set_dims({"z": 1, "x": 3, "y": 4}) + assert v_expanded.dims == ("z", "x", "y") + assert v_expanded.shape == (1, 3, 4) + + v_expanded = v.set_dims({"x": 3, "z": 1, "y": 4}) + assert v_expanded.dims == ("x", "z", "y") + assert v_expanded.shape == (3, 1, 4) + + v_expanded = v.set_dims({"x": 3, "y": 4, "z": 1}) + assert v_expanded.dims == ("x", "y", "z") + assert v_expanded.shape == (3, 4, 1) + + with pytest.raises(NotImplementedError): + v.set_dims({"z": 2, "x": 3, "y": 4}) + + with pytest.raises(NotImplementedError): + v.set_dims(["z", "x", "y"], shape=(2, 3, 4)) + def test_stack(self): v = Variable(["x", "y"], [[0, 1], [2, 3]], {"foo": "bar"}) actual = v.stack(z=("x", "y")) From fb67fc273a1fe6ec0ae90882daefbc7307311621 Mon Sep 17 00:00:00 2001 From: Mark Harfouche Date: Wed, 30 Apr 2025 12:08:18 -0400 Subject: [PATCH 2/4] Expand on the implementation of set_dims to make the trivial case easier --- xarray/core/variable.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index b8b33997780..7ca22cd98e5 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1355,7 +1355,7 @@ def set_dims(self, dim, shape=None): dim = [dim] if shape is None and is_dict_like(dim): - shape = dim.values() + shape = tuple(dim.values()) missing_dims = set(self.dims) - set(dim) if missing_dims: @@ -1371,13 +1371,18 @@ def set_dims(self, dim, shape=None): # don't use broadcast_to unless necessary so the result remains # writeable if possible expanded_data = self.data - elif shape is not None: + elif shape is None or all( + s == 1 for s, e in zip(shape, dim, strict=True) if e not in self_dims + ): + # "Trivial" broadcasting, i.e. 
simply inserting a new dimension + # This is typically easier for duck arrays to implement + # than the full "broadcast_to" semantics + indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,) + expanded_data = self.data[indexer] + else: # elif shape is not None: dims_map = dict(zip(dim, shape, strict=True)) tmp_shape = tuple(dims_map[d] for d in expanded_dims) expanded_data = duck_array_ops.broadcast_to(self._data, tmp_shape) - else: - indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,) - expanded_data = self.data[indexer] expanded_var = Variable( expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True From db27ee019e386121529b273ab8d8c81b2d4be1b9 Mon Sep 17 00:00:00 2001 From: Mark Harfouche Date: Wed, 30 Apr 2025 12:08:22 -0400 Subject: [PATCH 3/4] Add release note for https://github.com/pydata/xarray/pull/10277 --- doc/whats-new.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 76fb5d42aa9..f85378d7176 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -97,6 +97,11 @@ Internal Changes ~~~~~~~~~~~~~~~~ - Avoid stacking when grouping by a chunked array. This can be a large performance improvement. By `Deepak Cherian <https://github.com/dcherian>`_. +- The implementation of ``Variable.set_dims`` has changed to use array indexing syntax + instead of ``np.broadcast_to`` to perform dimension expansions where + all new dimensions have a size of 1. This should improve compatibility with + duck arrays that do not support broadcasting (:issue:`9462`, :pull:`10277`). + By `Mark Harfouche <https://github.com/hmaarrfk>`_. .. 
_whats-new.2025.03.1: From a4ee4e54f97bd4ae5645abfe8bacbfae4edff40c Mon Sep 17 00:00:00 2001 From: Mark Harfouche Date: Thu, 1 May 2025 10:04:39 -0400 Subject: [PATCH 4/4] Update test_variable.py --- xarray/tests/test_variable.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index d1afa10c5fa..1e7c32dec1e 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1680,6 +1680,14 @@ def __array_function__(self, *args, **kwargs): assert v_expanded.dims == ("z", "x", "y") assert v_expanded.shape == (1, 3, 4) + v_expanded = v.set_dims(["x", "z", "y"]) + assert v_expanded.dims == ("x", "z", "y") + assert v_expanded.shape == (3, 1, 4) + + v_expanded = v.set_dims(["x", "y", "z"]) + assert v_expanded.dims == ("x", "y", "z") + assert v_expanded.shape == (3, 4, 1) + # Explicitly asking for a shape of 1 triggers a different # codepath in set_dims # https://github.com/pydata/xarray/issues/9462