@@ -1653,6 +1653,66 @@ def test_set_dims_object_dtype(self):
1653
1653
expected = Variable (["x" ], exp_values )
1654
1654
assert_identical (actual , expected )
1655
1655
1656
+ def test_set_dims_without_broadcast (self ):
1657
+ class ArrayWithoutBroadcastTo (NDArrayMixin , indexing .ExplicitlyIndexed ):
1658
+ def __init__ (self , array ):
1659
+ self .array = array
1660
+
1661
+ # Broadcasting with __getitem__ is "easier" to implement
1662
+ # especially for dims of 1
1663
+ def __getitem__ (self , key ):
1664
+ return self .array [key ]
1665
+
1666
+ def __array_function__ (self , * args , ** kwargs ):
1667
+ raise NotImplementedError (
1668
+ "Not we don't want to use broadcast_to here "
1669
+ "https://github.com/pydata/xarray/issues/9462"
1670
+ )
1671
+
1672
+ arr = ArrayWithoutBroadcastTo (np .zeros ((3 , 4 )))
1673
+ # We should be able to add a new axis without broadcasting
1674
+ assert arr [np .newaxis , :, :].shape == (1 , 3 , 4 )
1675
+ with pytest .raises (NotImplementedError ):
1676
+ np .broadcast_to (arr , (1 , 3 , 4 ))
1677
+
1678
+ v = Variable (["x" , "y" ], arr )
1679
+ v_expanded = v .set_dims (["z" , "x" , "y" ])
1680
+ assert v_expanded .dims == ("z" , "x" , "y" )
1681
+ assert v_expanded .shape == (1 , 3 , 4 )
1682
+
1683
+ # Explicitly asking for a shape of 1 triggers a different
1684
+ # codepath in set_dims
1685
+ # https://github.com/pydata/xarray/issues/9462
1686
+ v_expanded = v .set_dims (["z" , "x" , "y" ], shape = (1 , 3 , 4 ))
1687
+ assert v_expanded .dims == ("z" , "x" , "y" )
1688
+ assert v_expanded .shape == (1 , 3 , 4 )
1689
+
1690
+ v_expanded = v .set_dims (["x" , "z" , "y" ], shape = (3 , 1 , 4 ))
1691
+ assert v_expanded .dims == ("x" , "z" , "y" )
1692
+ assert v_expanded .shape == (3 , 1 , 4 )
1693
+
1694
+ v_expanded = v .set_dims (["x" , "y" , "z" ], shape = (3 , 4 , 1 ))
1695
+ assert v_expanded .dims == ("x" , "y" , "z" )
1696
+ assert v_expanded .shape == (3 , 4 , 1 )
1697
+
1698
+ v_expanded = v .set_dims ({"z" : 1 , "x" : 3 , "y" : 4 })
1699
+ assert v_expanded .dims == ("z" , "x" , "y" )
1700
+ assert v_expanded .shape == (1 , 3 , 4 )
1701
+
1702
+ v_expanded = v .set_dims ({"x" : 3 , "z" : 1 , "y" : 4 })
1703
+ assert v_expanded .dims == ("x" , "z" , "y" )
1704
+ assert v_expanded .shape == (3 , 1 , 4 )
1705
+
1706
+ v_expanded = v .set_dims ({"x" : 3 , "y" : 4 , "z" : 1 })
1707
+ assert v_expanded .dims == ("x" , "y" , "z" )
1708
+ assert v_expanded .shape == (3 , 4 , 1 )
1709
+
1710
+ with pytest .raises (NotImplementedError ):
1711
+ v .set_dims ({"z" : 2 , "x" : 3 , "y" : 4 })
1712
+
1713
+ with pytest .raises (NotImplementedError ):
1714
+ v .set_dims (["z" , "x" , "y" ], shape = (2 , 3 , 4 ))
1715
+
1656
1716
def test_stack (self ):
1657
1717
v = Variable (["x" , "y" ], [[0 , 1 ], [2 , 3 ]], {"foo" : "bar" })
1658
1718
actual = v .stack (z = ("x" , "y" ))
0 commit comments