@@ -175,7 +175,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None:
175
175
176
176
@pytest .mark .parametrize ("center" , (True , False ))
177
177
@pytest .mark .parametrize ("window" , (1 , 2 , 3 , 4 ))
178
- def test_rolling_construct (self , center , window ) -> None :
178
+ def test_rolling_construct (self , center : bool , window : int ) -> None :
179
179
s = pd .Series (np .arange (10 ))
180
180
da = DataArray .from_series (s )
181
181
@@ -610,7 +610,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None:
610
610
611
611
@pytest .mark .parametrize ("center" , (True , False ))
612
612
@pytest .mark .parametrize ("window" , (1 , 2 , 3 , 4 ))
613
- def test_rolling_construct (self , center , window ) -> None :
613
+ def test_rolling_construct (self , center : bool , window : int ) -> None :
614
614
df = pd .DataFrame (
615
615
{
616
616
"x" : np .random .randn (20 ),
@@ -627,19 +627,58 @@ def test_rolling_construct(self, center, window) -> None:
627
627
np .testing .assert_allclose (df_rolling ["x" ].values , ds_rolling_mean ["x" ].values )
628
628
np .testing .assert_allclose (df_rolling .index , ds_rolling_mean ["index" ])
629
629
630
- # with stride
631
- ds_rolling_mean = ds_rolling .construct ("window" , stride = 2 ).mean ("window" )
632
- np .testing .assert_allclose (
633
- df_rolling ["x" ][::2 ].values , ds_rolling_mean ["x" ].values
634
- )
635
- np .testing .assert_allclose (df_rolling .index [::2 ], ds_rolling_mean ["index" ])
636
630
# with fill_value
637
631
ds_rolling_mean = ds_rolling .construct ("window" , stride = 2 , fill_value = 0.0 ).mean (
638
632
"window"
639
633
)
640
634
assert (ds_rolling_mean .isnull ().sum () == 0 ).to_array (dim = "vars" ).all ()
641
635
assert (ds_rolling_mean ["x" ] == 0.0 ).sum () >= 0
642
636
637
+ @pytest .mark .parametrize ("center" , (True , False ))
638
+ @pytest .mark .parametrize ("window" , (1 , 2 , 3 , 4 ))
639
+ def test_rolling_construct_stride (self , center : bool , window : int ) -> None :
640
+ df = pd .DataFrame (
641
+ {
642
+ "x" : np .random .randn (20 ),
643
+ "y" : np .random .randn (20 ),
644
+ "time" : np .linspace (0 , 1 , 20 ),
645
+ }
646
+ )
647
+ ds = Dataset .from_dataframe (df )
648
+ df_rolling_mean = df .rolling (window , center = center , min_periods = 1 ).mean ()
649
+
650
+ # With an index (dimension coordinate)
651
+ ds_rolling = ds .rolling (index = window , center = center )
652
+ ds_rolling_mean = ds_rolling .construct ("w" , stride = 2 ).mean ("w" )
653
+ np .testing .assert_allclose (
654
+ df_rolling_mean ["x" ][::2 ].values , ds_rolling_mean ["x" ].values
655
+ )
656
+ np .testing .assert_allclose (df_rolling_mean .index [::2 ], ds_rolling_mean ["index" ])
657
+
658
+ # Without index (https://github.com/pydata/xarray/issues/7021)
659
+ ds2 = ds .drop_vars ("index" )
660
+ ds2_rolling = ds2 .rolling (index = window , center = center )
661
+ ds2_rolling_mean = ds2_rolling .construct ("w" , stride = 2 ).mean ("w" )
662
+ np .testing .assert_allclose (
663
+ df_rolling_mean ["x" ][::2 ].values , ds2_rolling_mean ["x" ].values
664
+ )
665
+
666
+ # Mixed coordinates, indexes and 2D coordinates
667
+ ds3 = xr .Dataset (
668
+ {"x" : ("t" , range (20 )), "x2" : ("y" , range (5 ))},
669
+ {
670
+ "t" : range (20 ),
671
+ "y" : ("y" , range (5 )),
672
+ "t2" : ("t" , range (20 )),
673
+ "y2" : ("y" , range (5 )),
674
+ "yt" : (["t" , "y" ], np .ones ((20 , 5 ))),
675
+ },
676
+ )
677
+ ds3_rolling = ds3 .rolling (t = window , center = center )
678
+ ds3_rolling_mean = ds3_rolling .construct ("w" , stride = 2 ).mean ("w" )
679
+ for coord in ds3 .coords :
680
+ assert coord in ds3_rolling_mean .coords
681
+
643
682
@pytest .mark .slow
644
683
@pytest .mark .parametrize ("ds" , (1 , 2 ), indirect = True )
645
684
@pytest .mark .parametrize ("center" , (True , False ))
0 commit comments