Skip to content

Commit

Permalink
[weatherbench2] Add auxiliary variables to config.Selection.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 597794742
  • Loading branch information
ilopezgp authored and Weatherbench2 authors committed Jan 12, 2024
1 parent d8b9b1a commit 3e4f392
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
3 changes: 3 additions & 0 deletions weatherbench2/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class Selection:
levels: List of pressure levels.
lat_slice: Latitude range in degrees.
lon_slice: Longitude range in degrees.
aux_variables: Sequence of auxiliary forecast variables required for certain
evaluation metrics.
"""

variables: t.Sequence[str]
Expand All @@ -46,6 +48,7 @@ class Selection:
lon_slice: t.Optional[slice] = dataclasses.field(
default_factory=lambda: slice(None, None)
)
aux_variables: t.Optional[t.Sequence[str]] = None


@dataclasses.dataclass
Expand Down
14 changes: 11 additions & 3 deletions weatherbench2/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,14 @@ def _impose_data_selection(
selection: config.Selection,
select_time: bool = True,
time_dim: Optional[str] = None,
select_aux: bool = False,
) -> xr.Dataset:
"""Returns selection of dataset specified in Selection instance."""
dataset = dataset[selection.variables].sel(
if select_aux and selection.aux_variables is not None:
sel_variables = set(selection.variables) | set(selection.aux_variables)
else:
sel_variables = selection.variables
dataset = dataset[sel_variables].sel(
latitude=selection.lat_slice,
longitude=selection.lon_slice,
)
Expand Down Expand Up @@ -314,10 +319,12 @@ def open_forecast_and_truth_datasets(
)

obs_all_times = _impose_data_selection(
obs, data_config.selection, select_time=False
obs,
data_config.selection,
select_time=False,
)
forecast_all_times = _impose_data_selection(
forecast, data_config.selection, select_time=False
forecast, data_config.selection, select_time=False, select_aux=True
)

if data_config.by_init: # Will select appropriate chunks later
Expand All @@ -328,6 +335,7 @@ def open_forecast_and_truth_datasets(
forecast,
data_config.selection,
time_dim='init_time' if data_config.by_init else 'time',
select_aux=True,
)

# Determine ground truth dataset
Expand Down

0 comments on commit 3e4f392

Please sign in to comment.