@@ -1355,7 +1355,7 @@ def set_dims(self, dim, shape=None):
1355
1355
dim = [dim ]
1356
1356
1357
1357
if shape is None and is_dict_like (dim ):
1358
- shape = dim .values ()
1358
+ shape = tuple ( dim .values () )
1359
1359
1360
1360
missing_dims = set (self .dims ) - set (dim )
1361
1361
if missing_dims :
@@ -1371,13 +1371,18 @@ def set_dims(self, dim, shape=None):
1371
1371
# don't use broadcast_to unless necessary so the result remains
1372
1372
# writeable if possible
1373
1373
expanded_data = self .data
1374
- elif shape is not None :
1374
+ elif shape is None or all (
1375
+ s == 1 for s , e in zip (shape , dim , strict = True ) if e not in self_dims
1376
+ ):
1377
+ # "Trivial" broadcasting, i.e. simply inserting a new dimension
1378
+ # This is typically easier for duck arrays to implement
1379
+ # than the full "broadcast_to" semantics
1380
+ indexer = (None ,) * (len (expanded_dims ) - self .ndim ) + (...,)
1381
+ expanded_data = self .data [indexer ]
1382
+ else : # elif shape is not None:
1375
1383
dims_map = dict (zip (dim , shape , strict = True ))
1376
1384
tmp_shape = tuple (dims_map [d ] for d in expanded_dims )
1377
1385
expanded_data = duck_array_ops .broadcast_to (self ._data , tmp_shape )
1378
- else :
1379
- indexer = (None ,) * (len (expanded_dims ) - self .ndim ) + (...,)
1380
- expanded_data = self .data [indexer ]
1381
1386
1382
1387
expanded_var = Variable (
1383
1388
expanded_dims , expanded_data , self ._attrs , self ._encoding , fastpath = True
0 commit comments