[weatherbench2] Add EnsembleRPS metric.
PiperOrigin-RevId: 605658781
ilopezgp authored and Weatherbench2 authors committed Feb 9, 2024
1 parent 66d38da commit 2aa282a
Showing 2 changed files with 127 additions and 1 deletion.
69 changes: 68 additions & 1 deletion weatherbench2/metrics.py
@@ -958,7 +958,8 @@ class GaussianRPS(Metric):
"""Ranked probability score of a Gaussian forecast for a given quantization.
The ranked probability score (RPS) is computed based on the forecast and
observed cumulative distribution functions.
observed cumulative distribution functions. See `EnsembleRPS` for a discussion
of this metric.
References:
[Epstein, 1969] A Scoring System for Probability Forecasts of Ranked
@@ -1422,6 +1423,72 @@ def compute_chunk(
)


@dataclasses.dataclass
class EnsembleRPS(EnsembleMetric):
  """Ranked probability score of an ensemble forecast for a given quantization.

  The ranked probability score (RPS) is computed based on the forecast and
  observed cumulative distribution functions, coarsened to the level of the
  input thresholds.

  The thresholds define the limits of each considered interval, except that
  the first (resp. last) interval includes all values lower (resp. higher)
  than the first (resp. last) threshold. Three thresholds would define the
  following four intervals:

    <-- | --- | --- | -->

  As an example, if the thresholds are the climatological terciles
  [0.33, 0.66] together with the trivial top threshold 1.0, the observed event
  fell at climatological quantile 0.5, and the ensemble members fell at
  climatological quantiles [0.1, 0.5, 0.6, 0.8], then the observed CDF would
  be [0, 1, 1] and the forecast CDF would be [0.25, 0.75, 1]. Note that the
  score over the last interval need not be computed, since the quantized CDFs
  there are always one for both forecasts and observations.

  References:
    [Epstein, 1969] A Scoring System for Probability Forecasts of Ranked
      Categories,
      DOI: https://doi.org/10.1175/1520-0450(1969)008<0985:ASSFPF>2.0.CO;2
  """

  def __init__(
      self,
      threshold: Sequence[thresholds.Threshold],
      ensemble_dim: str = REALIZATION,
  ):
    """Initializes an EnsembleRPS.

    Args:
      threshold: A sequence of thresholds used to divide predictions and
        targets categorically.
      ensemble_dim: Dimension indexing ensemble member.
    """
    super().__init__(ensemble_dim=ensemble_dim)
    self.thresholds = threshold

  def compute_chunk(
      self,
      forecast: xr.Dataset,
      truth: xr.Dataset,
      region: t.Optional[Region] = None,
  ) -> xr.Dataset:
    """Spatially averaged RPS of the ensemble forecast."""
    rps_per_threshold = []
    threshold_list = [thresh.compute(truth) for thresh in self.thresholds]
    for threshold in threshold_list:
      # Quantized CDFs at this threshold: an indicator for the observation,
      # and the fraction of ensemble members falling below the threshold.
      truth_ecdf = xr.where(truth < threshold, 1.0, 0.0)
      forecast_ecdf = xr.where(forecast < threshold, 1.0, 0.0)
      ensemble_forecast_ecdf = forecast_ecdf.mean(
          self.ensemble_dim, skipna=False
      )
      rps_per_threshold.append((ensemble_forecast_ecdf - truth_ecdf) ** 2)
    return _spatial_average(sum(rps_per_threshold), region=region)
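
As a sanity check on the docstring's worked example, here is a minimal standalone sketch (plain NumPy, not part of the commit) that reproduces the quantized CDFs and the resulting score:

import numpy as np

# Climatological terciles plus the trivial top threshold.
thresholds_q = np.array([0.33, 0.66, 1.0])
truth_q = 0.5                               # observed climatological quantile
members_q = np.array([0.1, 0.5, 0.6, 0.8])  # ensemble member quantiles

truth_cdf = (truth_q < thresholds_q).astype(float)          # [0., 1., 1.]
forecast_cdf = (members_q[:, None] < thresholds_q).mean(0)  # [0.25, 0.75, 1.]

rps = np.sum((forecast_cdf - truth_cdf) ** 2)
print(rps)  # 0.25**2 + (-0.25)**2 + 0.0 = 0.125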


@dataclasses.dataclass
class RankHistogram(EnsembleMetric):
"""Histogram of truth's rank with respect to forecast ensemble members.
Expand Down
59 changes: 59 additions & 0 deletions weatherbench2/metrics_test.py
@@ -916,6 +916,65 @@ def test_ensemble_ignorance_score(self, error, expected):
)


class EnsembleRPSTest(parameterized.TestCase):

  @parameterized.named_parameters(
      dict(
          testcase_name='good model',
          error=0.02,
          expected=0.0,
      ),
      dict(
          testcase_name='poor model',
          error=-2.0,
          expected=2.0,
      ),
  )
  def test_ensemble_rps(self, error, expected):
    kwargs = {
        'variables_2d': ['2m_temperature'],
        'variables_3d': [],
        'time_start': '2022-01-01',
        'time_stop': '2022-01-02',
    }
    forecast = schema.mock_forecast_data(
        ensemble_size=4, lead_stop='1 day', **kwargs
    )
    truth = schema.mock_truth_data(**kwargs)
    # Build a mock climatology whose quantiles at 0.33, 0.66, and 1.0 sit at
    # the truth values shifted by 0.0, 1.0, and 2.0, respectively.
    q_1 = (
        truth.isel(time=0, drop=True)
        .expand_dims(dim={'dayofyear': 366, 'quantile': np.array([0.33])})
        .rename({'2m_temperature': '2m_temperature_quantile'})
    )
    q_2 = (
        (truth + 1.0)
        .isel(time=0, drop=True)
        .expand_dims(dim={'dayofyear': 366, 'quantile': np.array([0.66])})
        .rename({'2m_temperature': '2m_temperature_quantile'})
    )
    q_3 = (
        (truth + 2.0)
        .isel(time=0, drop=True)
        .expand_dims(dim={'dayofyear': 366, 'quantile': np.array([1.0])})
        .rename({'2m_temperature': '2m_temperature_quantile'})
    )
    climatology = xr.merge([q_1, q_2, q_3])

    # Place the truth between the second and third thresholds; the forecast
    # lands next to it for the good model and below all thresholds for the
    # poor one.
    truth = truth + 1.5
    forecast = forecast + 1.0 + error

    threshold_list = [
        thresholds.QuantileThreshold(climatology=climatology, quantile=q)
        for q in [0.33, 0.66, 1.0]
    ]

    result = metrics.EnsembleRPS(threshold_list).compute(forecast, truth)
    expected_arr = np.array([expected, expected])
    np.testing.assert_allclose(
        result['2m_temperature'].values, expected_arr, rtol=1e-4
    )
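
For intuition, the 'poor model' expectation can be checked by hand. A minimal sketch of the arithmetic, assuming the fields produced by `schema.mock_forecast_data` and `schema.mock_truth_data` are constant zero (an assumption the exact expected values suggest):

import numpy as np

# Thresholds computed from the mock climatology sit at 0.0, 1.0, and 2.0.
thresholds_at_point = np.array([0.0, 1.0, 2.0])
truth_value = 0.0 + 1.5            # truth shifted by 1.5
forecast_value = 0.0 + 1.0 - 2.0   # every member shares error = -2.0

truth_cdf = (truth_value < thresholds_at_point).astype(float)        # [0., 0., 1.]
forecast_cdf = (forecast_value < thresholds_at_point).astype(float)  # [1., 1., 1.]

rps = float(np.sum((forecast_cdf - truth_cdf) ** 2))
print(rps)  # 1.0 + 1.0 + 0.0 = 2.0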


class SEEPSTest(absltest.TestCase):

  def testExpectedValues(self):
