Skip to content

Extend rolling_exp to support pd.Timedelta objects with window halflife #10237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
16 changes: 16 additions & 0 deletions doc/user-guide/computation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,22 @@ The ``rolling_exp`` method takes a ``window_type`` kwarg, which can be ``'alpha'
``'com'`` (for ``center-of-mass``), ``'span'``, and ``'halflife'``. The default is
``span``.

For datetime axes, ``rolling_exp`` can work with timedelta windows when using ``window_type='halflife'``.
This enables handling irregular time series by computing weights based on the actual time differences
between points, similar to pandas' ``ewm`` with ``times`` parameter:

.. code:: python

# Create a DataArray with datetime index
times = pd.date_range("2020-01-01", periods=5, freq="1D")
da = xr.DataArray([1, 2, 3, 4, 5], dims="time", coords={"time": times})

# Apply exponential moving average with 1-day halflife
da.rolling_exp(time=pd.Timedelta(days=1), window_type="halflife").mean()

Note that when using timedeltas with ``window_type='halflife'``, only the ``mean()`` operation is currently
supported, and it must be applied to a datetime coordinate.

Finally, the rolling object has a ``construct`` method which returns a
view of the original ``DataArray`` with the windowed dimension in
the last position.
Expand Down
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ New Features
- Improved compatibility with OPeNDAP DAP4 data model for backend engine ``pydap``. This
includes ``datatree`` support, and removing slashes from dimension names. By
`Miguel Jimenez-Urias <https://github.com/Mikejmnez>`_.
- Extended ``rolling_exp`` to support ``pd.Timedelta`` objects for the window size when using
``window_type="halflife"`` along datetime dimensions, similar to pandas' ``ewm``. This allows
expressions like ``da.rolling_exp(time=pd.Timedelta("1D"), window_type="halflife").mean()``.
By `Andrea Biasioli <https://github.com/abiasiol>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
127 changes: 119 additions & 8 deletions xarray/computation/rolling_exp.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Generic
from typing import Any, Generic, Literal

import numpy as np
import pandas as pd
from pandas.core.arrays.datetimelike import dtype_to_unit

from xarray.compat.pdcompat import count_not_none
from xarray.computation.apply_ufunc import apply_ufunc
from xarray.core.common import is_np_datetime_like
from xarray.core.options import _get_keep_attrs
from xarray.core.types import T_DataWithCoords
from xarray.core.utils import module_available
Expand Down Expand Up @@ -46,22 +49,119 @@ def _get_alpha(
raise ValueError("Must pass one of comass, span, halflife, or alpha")


def _raise_if_array(alpha: float | np.ndarray):
"""Check if alpha is a float, raise NotImplementedError if not.

If alpha is an array, it means window_type='halflife' with Timedelta window,
and the operation is applied on a datetime index. The 'mean' operation is the
only one supported for this type of operation.

Parameters
----------
alpha : float or np.ndarray
If array, only the 'mean' operation is supported.

Raises
------
NotImplementedError
If alpha is an array.
"""
if not isinstance(alpha, float):
msg = (
"Operation not supported for window_type='halflife' with 'Timedelta' window. "
"Only 'mean' operation is supported with those window parameters."
)
raise NotImplementedError(msg)


def _calculate_deltas(
times: np.ndarray,
halflife: pd.Timedelta,
):
"""
Return the diff of the times divided by the half-life. These values are used in
the calculation of the ewm mean.

Parameters
----------
times : np.ndarray, Series
Times corresponding to the observations. Must be monotonically increasing
and ``datetime64[ns]`` dtype.
halflife : float, str, timedelta, optional
Half-life specifying the decay

Returns
-------
np.ndarray
Diff of the times divided by the half-life
"""
unit = dtype_to_unit(times.dtype)
_times = np.asarray(times.view(np.int64), dtype=np.float64)
_halflife = float(pd.Timedelta(halflife).as_unit(unit)._value)
deltas = np.diff(_times) / _halflife
deltas = np.insert(deltas, 0, 1)
return deltas


def _verify_timedelta_requirements(
window_type: Literal["span", "com", "halflife", "alpha"], dim_type: np.dtype
):
"""
Check if the window type and dimension type are compatible.

This function is called when a window with data type 'Timedelta' is used,
and verifies that the window type is 'halflife' and the dimension type is
datetime64.

Parameters
----------
window_type : str
The type of the window.
dim_type : np.dtype
The type of the dimension.

Raises
------
ValueError
If the window type is not 'halflife' or the dimension type is not datetime64.
NotImplementedError
If the window type is 'halflife' and the dimension type is not datetime64.
"""
if window_type != "halflife":
raise ValueError(
"window with data type 'Timedelta' can only be used with window_type='halflife'"
)
if not is_np_datetime_like(dim_type):
raise NotImplementedError(
"window with data type 'Timedelta' must be used with a datetime64 coordinate"
)


class RollingExp(Generic[T_DataWithCoords]):
"""
Exponentially-weighted moving window object.
Similar to EWM in pandas

Similar to EWM in pandas. When using a Timedelta window with window_type='halflife',
the alpha values are computed based on the actual time differences between points,
allowing for irregular time series. This matches pandas' implementation in
pd.DataFrame.ewm(halflife=..., times=...).

Parameters
----------
obj : Dataset or DataArray
Object to window.
windows : mapping of hashable to int (or float for alpha type)
windows : mapping of hashable to int, float, or pd.Timedelta
A mapping from the name of the dimension to create the rolling
exponential window along (e.g. `time`) to the size of the moving window.
A pd.Timedelta can be provided for datetime dimensions only,
when using window_type='halflife'.
window_type : {"span", "com", "halflife", "alpha"}, default: "span"
The format of the previously supplied window. Each is a simple
numerical transformation of the others. Described in detail:
https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.ewm.html
When using a pd.Timedelta window, only 'halflife' is supported for window_type,
and it must be applied to a datetime coordinate. In this case, only the 'mean'
operation is supported.

Returns
-------
Expand All @@ -71,8 +171,8 @@ class RollingExp(Generic[T_DataWithCoords]):
def __init__(
self,
obj: T_DataWithCoords,
windows: Mapping[Any, int | float],
window_type: str = "span",
windows: Mapping[Any, int | float | pd.Timedelta],
window_type: Literal["span", "com", "halflife", "alpha"] = "span",
min_weight: float = 0.0,
):
if not module_available("numbagg"):
Expand All @@ -82,8 +182,16 @@ def __init__(

self.obj: T_DataWithCoords = obj
dim, window = next(iter(windows.items()))

if isinstance(window, pd.Timedelta):
_verify_timedelta_requirements(window_type, self.obj[dim].dtype)
deltas = _calculate_deltas(self.obj.get_index(dim), window)
# Equivalent to unweighted alpha=0.5 (like in pandas implementation)
self.alpha = 1 - (1 - 0.5) ** deltas
else:
self.alpha = _get_alpha(**{window_type: window})

self.dim = dim
self.alpha = _get_alpha(**{window_type: window})
self.min_weight = min_weight
# Don't pass min_weight=0 so we can support older versions of numbagg
kwargs = dict(alpha=self.alpha, axis=-1)
Expand Down Expand Up @@ -148,6 +256,7 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
array([1. , 1.33333333, 2.44444444, 2.81481481, 2.9382716 ])
Dimensions without coordinates: x
"""
_raise_if_array(self.alpha)

import numbagg

Expand Down Expand Up @@ -181,6 +290,7 @@ def std(self) -> T_DataWithCoords:
array([ nan, 0. , 0.67936622, 0.42966892, 0.25389527])
Dimensions without coordinates: x
"""
_raise_if_array(self.alpha)

import numbagg

Expand Down Expand Up @@ -211,6 +321,7 @@ def var(self) -> T_DataWithCoords:
array([ nan, 0. , 0.46153846, 0.18461538, 0.06446281])
Dimensions without coordinates: x
"""
_raise_if_array(self.alpha)
dim_order = self.obj.dims
import numbagg

Expand Down Expand Up @@ -239,7 +350,7 @@ def cov(self, other: T_DataWithCoords) -> T_DataWithCoords:
array([ nan, 0. , 1.38461538, 0.55384615, 0.19338843])
Dimensions without coordinates: x
"""

_raise_if_array(self.alpha)
dim_order = self.obj.dims
import numbagg

Expand Down Expand Up @@ -269,7 +380,7 @@ def corr(self, other: T_DataWithCoords) -> T_DataWithCoords:
array([ nan, nan, nan, 0.4330127 , 0.48038446])
Dimensions without coordinates: x
"""

_raise_if_array(self.alpha)
dim_order = self.obj.dims
import numbagg

Expand Down
92 changes: 90 additions & 2 deletions xarray/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,59 @@ def test_rolling_exp_runs(self, da, dim, window_type, window, func) -> None:
@pytest.mark.parametrize("dim", ["time", "x"])
@pytest.mark.parametrize(
"window_type, window",
[["span", 5], ["alpha", 0.5], ["com", 0.5], ["halflife", 5]],
[
["span", pd.Timedelta(days=5)],
["alpha", pd.Timedelta(days=5)],
["com", pd.Timedelta(days=5)],
["halflife", pd.Timedelta(days=5)],
],
)
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
@pytest.mark.parametrize("func", ["mean", "sum", "var", "std"])
def test_rolling_exp_runs_timedelta(
self, da, dim, window_type, window, func
) -> None:
"""Test that rolling_exp works with Timedelta windows
only for halflife window_type and datetime index, and only
mean operation is supported.
"""
da = da.where(da > 0.2)

# Only halflife window_type should work with Timedelta
if window_type != "halflife":
with pytest.raises(ValueError):
da.rolling_exp(window_type=window_type, **{dim: window})
return

# Timedelta only works with datetime index
if dim != "time":
with pytest.raises(
NotImplementedError,
):
da.rolling_exp(window_type=window_type, **{dim: window})
return

rolling_exp = da.rolling_exp(window_type=window_type, **{dim: window})

# Only mean is supported for Timedelta windows
if func == "mean":
result = rolling_exp.mean()
assert isinstance(result, DataArray)
else:
with pytest.raises(
NotImplementedError, match="Only 'mean' operation is supported"
):
getattr(rolling_exp, func)()

@pytest.mark.parametrize("dim", ["time", "x"])
@pytest.mark.parametrize(
"window_type, window",
[
["span", 5],
["alpha", 0.5],
["com", 0.5],
["halflife", 5],
],
)
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
def test_rolling_exp_mean_pandas(self, da, dim, window_type, window) -> None:
Expand All @@ -466,12 +518,48 @@ def test_rolling_exp_mean_pandas(self, da, dim, window_type, window) -> None:
assert pandas_array.index.name == "time"
if dim == "x":
pandas_array = pandas_array.T

expected = xr.DataArray(
pandas_array.ewm(**{window_type: window}).mean()
).transpose(*da.dims)

assert_allclose(expected.variable, result.variable)

@pytest.mark.parametrize(
"window",
[
pd.Timedelta(days=5),
pd.Timedelta(days=10),
pd.Timedelta(weeks=1),
],
)
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
def test_rolling_exp_mean_pandas_halflife(self, da, window) -> None:
dim = "time"
window_type = "halflife"

da = da.isel(a=0).where(lambda x: x > 0.2)

result = da.rolling_exp(window_type=window_type, **{dim: window}).mean()
assert isinstance(result, DataArray)

pandas_array = da.to_pandas()
assert pandas_array.index.name == "time"

expected = xr.DataArray(
pandas_array.ewm(**{window_type: window}, times=pandas_array.index).mean()
).transpose(*da.dims)

assert_allclose(expected.variable, result.variable)

# test with different time units (ms -> s)
da["time"] = da.get_index("time").astype("datetime64[s]")
result = da.rolling_exp(window_type=window_type, **{dim: window}).mean()

# pandas < 2.2.0 does not support non-ns resolution for ewm
# so let's not re-assign pandas_array
assert_allclose(expected.variable, result.variable)

@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
@pytest.mark.parametrize("func", ["mean", "sum"])
def test_rolling_exp_keep_attrs(self, da, func) -> None:
Expand Down Expand Up @@ -899,7 +987,7 @@ def test_rolling_exp_keep_attrs(self, ds) -> None:
# discard attrs
result = ds.rolling_exp(time=10).mean(keep_attrs=False)
assert result.attrs == {}
# TODO: from #8114 — this arguably should be empty, but `apply_ufunc` doesn't do
# TODO: from #8114 — this arguably should be empty, but `apply_ufunc` doesn't do
# that at the moment. We should change in `apply_func` rather than
# special-case it here.
#
Expand Down
Loading