Skip to content

Commit fb67fc2

Browse files
committed
Expand on the implementation of set_dims to make the trivial case easier
1 parent 6fe1bc4 commit fb67fc2

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

xarray/core/variable.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1355,7 +1355,7 @@ def set_dims(self, dim, shape=None):
13551355
dim = [dim]
13561356

13571357
if shape is None and is_dict_like(dim):
1358-
shape = dim.values()
1358+
shape = tuple(dim.values())
13591359

13601360
missing_dims = set(self.dims) - set(dim)
13611361
if missing_dims:
@@ -1371,13 +1371,18 @@ def set_dims(self, dim, shape=None):
13711371
# don't use broadcast_to unless necessary so the result remains
13721372
# writeable if possible
13731373
expanded_data = self.data
1374-
elif shape is not None:
1374+
elif shape is None or all(
1375+
s == 1 for s, e in zip(shape, dim, strict=True) if e not in self_dims
1376+
):
1377+
# "Trivial" broadcasting, i.e. simply inserting a new dimension
1378+
# This is typically easier for duck arrays to implement
1379+
# than the full "broadcast_to" semantics
1380+
indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,)
1381+
expanded_data = self.data[indexer]
1382+
else: # elif shape is not None:
13751383
dims_map = dict(zip(dim, shape, strict=True))
13761384
tmp_shape = tuple(dims_map[d] for d in expanded_dims)
13771385
expanded_data = duck_array_ops.broadcast_to(self._data, tmp_shape)
1378-
else:
1379-
indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,)
1380-
expanded_data = self.data[indexer]
13811386

13821387
expanded_var = Variable(
13831388
expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True

0 commit comments

Comments
 (0)