diff --git a/weatherbench2/config.py b/weatherbench2/config.py index 13c0969..3be06e9 100644 --- a/weatherbench2/config.py +++ b/weatherbench2/config.py @@ -35,6 +35,8 @@ class Selection: levels: List of pressure levels. lat_slice: Latitude range in degrees. lon_slice: Longitude range in degrees. + aux_variables: Sequence of auxiliary forecast variables required for certain + evaluation metrics. """ variables: t.Sequence[str] @@ -46,6 +48,7 @@ class Selection: lon_slice: t.Optional[slice] = dataclasses.field( default_factory=lambda: slice(None, None) ) + aux_variables: t.Optional[t.Sequence[str]] = None @dataclasses.dataclass diff --git a/weatherbench2/evaluation.py b/weatherbench2/evaluation.py index a93907b..8fea878 100644 --- a/weatherbench2/evaluation.py +++ b/weatherbench2/evaluation.py @@ -140,9 +140,14 @@ def _impose_data_selection( selection: config.Selection, select_time: bool = True, time_dim: Optional[str] = None, + select_aux: bool = False, ) -> xr.Dataset: """Returns selection of dataset specified in Selection instance.""" - dataset = dataset[selection.variables].sel( + if select_aux and selection.aux_variables is not None: + sel_variables = set(selection.variables) | set(selection.aux_variables) + else: + sel_variables = selection.variables + dataset = dataset[sel_variables].sel( latitude=selection.lat_slice, longitude=selection.lon_slice, ) @@ -314,10 +319,12 @@ def open_forecast_and_truth_datasets( ) obs_all_times = _impose_data_selection( - obs, data_config.selection, select_time=False + obs, + data_config.selection, + select_time=False, ) forecast_all_times = _impose_data_selection( - forecast, data_config.selection, select_time=False + forecast, data_config.selection, select_time=False, select_aux=True ) if data_config.by_init: # Will select appropriate chunks later @@ -328,6 +335,7 @@ def open_forecast_and_truth_datasets( forecast, data_config.selection, time_dim='init_time' if data_config.by_init else 'time', + select_aux=True, ) # Determine ground truth dataset