Skip to content

Commit

Permalink
Speed up CRPS evaluation by caching spread computation (it sorts larg…
Browse files Browse the repository at this point in the history
…e arrays). This CL adds `@utils.dataset_safe_lru_cache`, which is similar to `functools.lru_cache`, but handles `xarray.Dataset` in addition to `Hashable`.

Runtimes below were measured on a `240x121` grid with `chunks={'init_time': 1, 'lead_time': 8}` and `ensemble_size=50`.

[Verified](http://screen/AeCumqcecvLComL) that the first CRPS call takes ~55 seconds, and subsequent ones (for the same loop) are < 2sec.

Correctness was verified in unit tests and by comparing metrics [before](https://screen/Vrk5Hg9j273YutK) and [after](http://screen/8BuwEKfCHdTkJ4H).

PiperOrigin-RevId: 644185274
  • Loading branch information
langmore authored and Weatherbench2 authors committed Jun 18, 2024
1 parent c342117 commit 16e0131
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 6 deletions.
1 change: 1 addition & 0 deletions weatherbench2/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def _metric_and_region_loop(
) -> xr.Dataset:
"""Compute metric results looping over metrics and regions in eval config."""
# Compute derived variables
logging.info('Starting _metric_and_region_loop')
for name, dv in eval_config.derived_variables.items():
logging.info(f'Logging: derived_variable {name!r}: {dv}')
forecast[name] = dv.compute(forecast)
Expand Down
14 changes: 10 additions & 4 deletions weatherbench2/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import numpy as np
from scipy import stats
from weatherbench2 import thresholds
from weatherbench2 import utils
from weatherbench2.regions import Region
import xarray as xr

Expand Down Expand Up @@ -637,7 +638,7 @@ def compute_chunk(
) -> xr.Dataset:
"""CRPSSpread, averaged over space, for a time chunk of data."""
return _spatial_average(
_pointwise_crps_spread(forecast, truth, self.ensemble_dim),
_pointwise_crps_spread(forecast, self.ensemble_dim),
region=region,
)

Expand Down Expand Up @@ -688,7 +689,7 @@ def compute_chunk(
region: t.Optional[Region] = None,
) -> xr.Dataset:
"""CRPSSpread, averaged over space, for a time chunk of data."""
return _pointwise_crps_spread(forecast, truth, self.ensemble_dim)
return _pointwise_crps_spread(forecast, self.ensemble_dim)


@dataclasses.dataclass
Expand All @@ -705,11 +706,16 @@ def compute_chunk(
return _pointwise_crps_skill(forecast, truth, self.ensemble_dim)


@utils.dataset_safe_lru_cache(
# This is used in _metric_and_region_loop. The same dataset is used
# repeatedly for different metrics/regions, then the loop returns.
# Therefore, maxsize=1 is sufficient.
maxsize=1,
)
def _pointwise_crps_spread(
forecast: xr.Dataset, truth: xr.Dataset, ensemble_dim: str
forecast: xr.Dataset, ensemble_dim: str
) -> xr.Dataset:
"""CRPS spread at each point in truth, averaged over ensemble only."""
del truth # unused
n_ensemble = _get_n_ensemble(forecast, ensemble_dim)
if n_ensemble < 2:
return xr.zeros_like(forecast.isel({ensemble_dim: 0}))
Expand Down
56 changes: 56 additions & 0 deletions weatherbench2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Utility function for WeatherBench2."""
import functools
from typing import Callable, Union

import fsspec
Expand Down Expand Up @@ -292,3 +293,58 @@ def random_like(dataset: xr.Dataset, seed: int = 0) -> xr.Dataset:
return dataset.copy(
data={k: rs.normal(size=v.shape) for k, v in dataset.items()}
)


class _WrappedDataset:
  """Hashable wrapper that lets an `xarray.Dataset` be used as a cache key.

  Equality is delegated to `xarray.Dataset.equals`, and the hash is derived
  from the `repr` of each data variable's flattened data.

  The wrapper stores a reference to the dataset, not a copy.

  NOTE(review): the hash is deliberately cheap rather than exhaustive. Datasets
  that compare equal but differ in representation (e.g. variable order) may
  hash differently, which presumably only costs a cache miss -- confirm if
  that matters for a caller.
  """

  def __init__(self, value):
    """Wraps `value`.

    Args:
      value: The `xarray.Dataset` to wrap.

    Raises:
      ValueError: If `value` is not an `xarray.Dataset`.
    """
    if not isinstance(value, xr.Dataset):
      raise ValueError(f'_WrappedDataset cannot wrap type {type(value)}')
    self.value = value

  def __eq__(self, other):
    if not isinstance(other, _WrappedDataset):
      # Let Python try the reflected comparison / identity fallback instead of
      # unilaterally declaring inequality with foreign types.
      return NotImplemented
    if self.value is other.value:
      # Same object: skip the (potentially expensive) element-wise comparison.
      return True
    return self.value.equals(other.value)

  def __hash__(self):
    # Something that can be calculated quickly -- we won't have many collisions.
    # Hash collisions just mean that __eq__ needs to be checked.
    # https://stackoverflow.com/questions/16589791/most-efficient-property-to-hash-for-numpy-array
    return hash(
        tuple(
            (name, repr(var.data.ravel()))
            for name, var in self.value.data_vars.items()
        )
    )


def dataset_safe_lru_cache(maxsize=128):
  """An xarray.Dataset compatible version of functools.lru_cache.

  Positional and keyword arguments that are `xarray.Dataset` instances are
  wrapped in `_WrappedDataset` (which is hashable) before being handed to
  `functools.lru_cache`, and unwrapped again before the decorated function is
  invoked. All other arguments are passed through unchanged and must be
  hashable.

  The returned wrapper forwards `cache_info()` and `cache_clear()` from the
  underlying `functools.lru_cache`, so callers can inspect the cache or
  release the datasets it is holding on to.

  Args:
    maxsize: Forwarded to `functools.lru_cache`.

  Returns:
    A decorator that adds caching to the function it is applied to.
  """

  def _wrap(arg):
    # Make Datasets hashable; leave everything else untouched.
    return _WrappedDataset(arg) if isinstance(arg, xr.Dataset) else arg

  def _unwrap(arg):
    # Inverse of _wrap: recover the original Dataset for the real call.
    return arg.value if isinstance(arg, _WrappedDataset) else arg

  def decorator(func):  # pylint: disable=missing-docstring
    @functools.lru_cache(maxsize)
    def cached_func(*args, **kwargs):
      # Only reached on a cache miss; arguments arrive wrapped.
      return func(
          *(_unwrap(a) for a in args),
          **{k: _unwrap(v) for k, v in kwargs.items()},
      )

    @functools.wraps(func)
    def wrapper(*args, **kwargs):  # pylint: disable=missing-docstring
      return cached_func(
          *(_wrap(a) for a in args),
          **{k: _wrap(v) for k, v in kwargs.items()},
      )

    # Mirror the introspection/eviction hooks that functools.lru_cache offers.
    wrapper.cache_info = cached_func.cache_info
    wrapper.cache_clear = cached_func.cache_clear
    return wrapper

  return decorator
47 changes: 45 additions & 2 deletions weatherbench2/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.
# ==============================================================================
from absl.testing import absltest
import numpy as np
from weatherbench2 import schema
from weatherbench2 import utils
import xarray
import xarray as xr


class UtilsTest(absltest.TestCase):
Expand Down Expand Up @@ -43,7 +44,7 @@ def testMethodEquivalence(self):
hour_interval=24,
stat_fn='mean',
)
xarray.testing.assert_allclose(explicit, fast)
xr.testing.assert_allclose(explicit, fast)

def testProbabilisticClimatology(self):
truth = schema.mock_truth_data(
Expand All @@ -67,5 +68,47 @@ def testProbabilisticClimatology(self):
self.assertEqual(clim['2m_temperature'].sizes, expected_sizes)


class DatasetSafeLRUCacheTest(absltest.TestCase):
  """Tests for `utils.dataset_safe_lru_cache`."""

  def test_handles_non_hashable_args_and_kwargs(self):
    """Cached results are correct and the function runs once per input set."""

    def dataset(z) -> xr.Dataset:
      z = np.asarray(z)
      assert z.ndim == 1
      return xr.Dataset(
          data_vars={'temperature': (('level',), z)},
          coords={'level': np.arange(len(z))},
      )

    call_count = 0

    @utils.dataset_safe_lru_cache(maxsize=2)
    def func(x: xr.Dataset, y: xr.Dataset, b: float = 1):
      nonlocal call_count
      call_count += 1
      return (x + y * b).sum()

    first_x = dataset([1.0, 2.0, 3.0])
    # Use 3 sets of arrays so we are sure to cycle through the size 2 cache.
    cases = (
        ('First set of Datasets', first_x, first_x + 2, 1.3),
        (
            'Second set of Datasets',
            dataset([0.0, -2.0, 0.123]),
            dataset([10.0, -1.0, 3]),
            10.3,
        ),
        (
            'Third set of Datasets',
            dataset([0.0, -20.0]),
            dataset([10.0, -11.0]),
            -1234,
        ),
    )

    for sets_seen, (name, x, y, b) in enumerate(cases, start=1):
      with self.subTest(name):
        expected = (x + y * b).sum()
        for _ in range(4):
          # NOTE: assertEqual on Datasets is vacuous (`==` yields a Dataset,
          # which is truthy whenever it has data variables), so compare values
          # with xarray's own testing helper.
          xr.testing.assert_allclose(expected, func(x, y, b=b))
        # Only the first of the four identical calls should reach `func`.
        self.assertEqual(sets_seen, call_count)


if __name__ == '__main__':
absltest.main()

0 comments on commit 16e0131

Please sign in to comment.