Skip to content

Commit 2b784f2

Browse files
p4perf4ce and headtr1ck authored
Consistent DatasetRolling.construct behavior (#7578)
* Removed `.isel` for consistent rolling behavior. `.isel` causes the behavior of `DatasetRolling.construct` to be inconsistent with `DataArrayRolling.construct` when `stride` > 1.
* new rolling construct strategy for coords
* add whats-new
* add new tests with different coords
* next try on aligning strided coords
* add peakmem test for rolling.construct
* increase asv benchmark rolling sizes
---------
Co-authored-by: Michael Niklas <[email protected]>
1 parent 04550e6 commit 2b784f2

File tree

4 files changed

+68
-14
lines changed

4 files changed

+68
-14
lines changed

asv_bench/benchmarks/rolling.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,10 +5,10 @@
55

66
from . import parameterized, randn, requires_dask
77

8-
nx = 300
8+
nx = 3000
99
long_nx = 30000
1010
ny = 200
11-
nt = 100
11+
nt = 1000
1212
window = 20
1313

1414
randn_xy = randn((nx, ny), frac_nan=0.1)
@@ -115,6 +115,11 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck):
115115
roll = self.ds.var3.rolling(t=100)
116116
getattr(roll, func)()
117117

118+
@parameterized(["stride"], ([None, 5, 50]))
119+
def peakmem_1drolling_construct(self, stride):
120+
self.ds.var2.rolling(t=100).construct("w", stride=stride)
121+
self.ds.var3.rolling(t=100).construct("w", stride=stride)
122+
118123

119124
class DatasetRollingMemory(RollingMemory):
120125
@parameterized(["func", "use_bottleneck"], (["sum", "max", "mean"], [True, False]))
@@ -128,3 +133,7 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck):
128133
with xr.set_options(use_bottleneck=use_bottleneck):
129134
roll = self.ds.rolling(t=100)
130135
getattr(roll, func)()
136+
137+
@parameterized(["stride"], ([None, 5, 50]))
138+
def peakmem_1drolling_construct(self, stride):
139+
self.ds.rolling(t=100).construct("w", stride=stride)

doc/whats-new.rst

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -74,9 +74,12 @@ Bug fixes
7474
of :py:meth:`DataArray.__setitem__` lose dimension names.
7575
(:issue:`7030`, :pull:`8067`) By `Darsh Ranjan <https://github.com/dranjan>`_.
7676
- Return ``float64`` in presence of ``NaT`` in :py:class:`~core.accessor_dt.DatetimeAccessor` and
77-
special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar()`
77+
special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar`
7878
(:issue:`7928`, :pull:`8084`).
7979
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
80+
- Fix :py:meth:`~core.rolling.DatasetRolling.construct` with stride on Datasets without indexes.
81+
(:issue:`7021`, :pull:`7578`).
82+
By `Amrest Chinkamol <https://github.com/p4perf4ce>`_ and `Michael Niklas <https://github.com/headtr1ck>`_.
8083
- Calling plot with kwargs ``col``, ``row`` or ``hue`` no longer squeezes dimensions passed via these arguments
8184
(:issue:`7552`, :pull:`8174`).
8285
By `Wiktor Kraśnicki <https://github.com/wkrasnicki>`_.

xarray/core/rolling.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -785,11 +785,14 @@ def construct(
785785
if not keep_attrs:
786786
dataset[key].attrs = {}
787787

788+
# Need to stride coords as well. TODO: is there a better way?
789+
coords = self.obj.isel(
790+
{d: slice(None, None, s) for d, s in zip(self.dim, strides)}
791+
).coords
792+
788793
attrs = self.obj.attrs if keep_attrs else {}
789794

790-
return Dataset(dataset, coords=self.obj.coords, attrs=attrs).isel(
791-
{d: slice(None, None, s) for d, s in zip(self.dim, strides)}
792-
)
795+
return Dataset(dataset, coords=coords, attrs=attrs)
793796

794797

795798
class Coarsen(CoarsenArithmetic, Generic[T_Xarray]):

xarray/tests/test_rolling.py

Lines changed: 47 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -175,7 +175,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None:
175175

176176
@pytest.mark.parametrize("center", (True, False))
177177
@pytest.mark.parametrize("window", (1, 2, 3, 4))
178-
def test_rolling_construct(self, center, window) -> None:
178+
def test_rolling_construct(self, center: bool, window: int) -> None:
179179
s = pd.Series(np.arange(10))
180180
da = DataArray.from_series(s)
181181

@@ -610,7 +610,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None:
610610

611611
@pytest.mark.parametrize("center", (True, False))
612612
@pytest.mark.parametrize("window", (1, 2, 3, 4))
613-
def test_rolling_construct(self, center, window) -> None:
613+
def test_rolling_construct(self, center: bool, window: int) -> None:
614614
df = pd.DataFrame(
615615
{
616616
"x": np.random.randn(20),
@@ -627,19 +627,58 @@ def test_rolling_construct(self, center, window) -> None:
627627
np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values)
628628
np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"])
629629

630-
# with stride
631-
ds_rolling_mean = ds_rolling.construct("window", stride=2).mean("window")
632-
np.testing.assert_allclose(
633-
df_rolling["x"][::2].values, ds_rolling_mean["x"].values
634-
)
635-
np.testing.assert_allclose(df_rolling.index[::2], ds_rolling_mean["index"])
636630
# with fill_value
637631
ds_rolling_mean = ds_rolling.construct("window", stride=2, fill_value=0.0).mean(
638632
"window"
639633
)
640634
assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim="vars").all()
641635
assert (ds_rolling_mean["x"] == 0.0).sum() >= 0
642636

637+
@pytest.mark.parametrize("center", (True, False))
638+
@pytest.mark.parametrize("window", (1, 2, 3, 4))
639+
def test_rolling_construct_stride(self, center: bool, window: int) -> None:
640+
df = pd.DataFrame(
641+
{
642+
"x": np.random.randn(20),
643+
"y": np.random.randn(20),
644+
"time": np.linspace(0, 1, 20),
645+
}
646+
)
647+
ds = Dataset.from_dataframe(df)
648+
df_rolling_mean = df.rolling(window, center=center, min_periods=1).mean()
649+
650+
# With an index (dimension coordinate)
651+
ds_rolling = ds.rolling(index=window, center=center)
652+
ds_rolling_mean = ds_rolling.construct("w", stride=2).mean("w")
653+
np.testing.assert_allclose(
654+
df_rolling_mean["x"][::2].values, ds_rolling_mean["x"].values
655+
)
656+
np.testing.assert_allclose(df_rolling_mean.index[::2], ds_rolling_mean["index"])
657+
658+
# Without index (https://github.com/pydata/xarray/issues/7021)
659+
ds2 = ds.drop_vars("index")
660+
ds2_rolling = ds2.rolling(index=window, center=center)
661+
ds2_rolling_mean = ds2_rolling.construct("w", stride=2).mean("w")
662+
np.testing.assert_allclose(
663+
df_rolling_mean["x"][::2].values, ds2_rolling_mean["x"].values
664+
)
665+
666+
# Mixed coordinates, indexes and 2D coordinates
667+
ds3 = xr.Dataset(
668+
{"x": ("t", range(20)), "x2": ("y", range(5))},
669+
{
670+
"t": range(20),
671+
"y": ("y", range(5)),
672+
"t2": ("t", range(20)),
673+
"y2": ("y", range(5)),
674+
"yt": (["t", "y"], np.ones((20, 5))),
675+
},
676+
)
677+
ds3_rolling = ds3.rolling(t=window, center=center)
678+
ds3_rolling_mean = ds3_rolling.construct("w", stride=2).mean("w")
679+
for coord in ds3.coords:
680+
assert coord in ds3_rolling_mean.coords
681+
643682
@pytest.mark.slow
644683
@pytest.mark.parametrize("ds", (1, 2), indirect=True)
645684
@pytest.mark.parametrize("center", (True, False))

0 commit comments

Comments
 (0)