Skip to content

Commit

Permalink
Merge branch 'main' into first-update
Browse files Browse the repository at this point in the history
  • Loading branch information
raspstephan committed Jan 17, 2024
2 parents 1e7126d + a5d30f9 commit 0ecdb29
Show file tree
Hide file tree
Showing 10 changed files with 263 additions and 74 deletions.
3 changes: 2 additions & 1 deletion scripts/compute_statistical_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def main(argv: list[str]) -> None:
)
# Rechunk in time
pcoll_time = pcoll_tmp | f'RechunkTime_{order}' >> xbeam.Rechunk(
space_reduce_template.sizes,
# Convert to string to satisfy pytype.
{str(k): v for k, v in space_reduce_template.sizes.items()},
reduce_working_chunks,
time_working_chunks,
itemsize=RECHUNK_ITEMSIZE.value,
Expand Down
65 changes: 50 additions & 15 deletions scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@
_DEFAULT_VARIABLES,
help='Comma delimited list of variables to select from weather.',
)
AUX_VARIABLES = flags.DEFINE_list(
'aux_variables',
None,
help='Comma delimited list of auxiliary variables for metric evaluation.',
)
DERIVED_VARIABLES = flags.DEFINE_list(
'derived_variables',
[],
Expand Down Expand Up @@ -229,9 +234,15 @@
)


def _wind_vector_rmse():
"""Defines Wind Vector RMSEs if U/V components are in variables."""
wind_vector_rmse = []
def _wind_vector_error(err_type: str):
"""Defines Wind Vector [R]MSEs if U/V components are in variables."""
if err_type == 'mse':
cls = metrics.WindVectorMSE
elif err_type == 'rmse':
cls = metrics.WindVectorRMSESqrtBeforeTimeAvg
else:
raise ValueError(f'Unrecognized {err_type=}')
wind_vector_error = []
available = set(VARIABLES.value).union(DERIVED_VARIABLES.value)
for u_name, v_name, vector_name in [
('u_component_of_wind', 'v_component_of_wind', 'wind_vector'),
Expand All @@ -248,20 +259,21 @@ def _wind_vector_rmse():
),
]:
if u_name in available and v_name in available:
wind_vector_rmse.append(
metrics.WindVectorRMSE(
wind_vector_error.append(
cls(
u_name=u_name,
v_name=v_name,
vector_name=vector_name,
)
)
return wind_vector_rmse
return wind_vector_error


def main(argv: list[str]) -> None:
"""Run all WB2 metrics."""
selection = config.Selection(
variables=VARIABLES.value,
aux_variables=AUX_VARIABLES.value,
levels=[int(level) for level in LEVELS.value],
time_slice=slice(TIME_START.value, TIME_STOP.value),
)
Expand Down Expand Up @@ -331,6 +343,12 @@ def main(argv: list[str]) -> None:
LandRegion(land_sea_mask=land_sea_mask),
]
),
'tropics_land': CombinedRegion(
regions=[
SliceRegion(lat_slice=slice(-20, 20)),
LandRegion(land_sea_mask=land_sea_mask),
]
),
}
predefined_regions = predefined_regions | land_regions
except KeyError:
Expand All @@ -349,12 +367,16 @@ def main(argv: list[str]) -> None:
climatology = evaluation.make_latitude_increasing(climatology)

deterministic_metrics = {
'rmse': metrics.RMSE(wind_vector_rmse=_wind_vector_rmse()),
'mse': metrics.MSE(),
'mse': metrics.MSE(wind_vector_mse=_wind_vector_error('mse')),
'acc': metrics.ACC(climatology=climatology),
'bias': metrics.Bias(),
'mae': metrics.MAE(),
}
rmse_metrics = {
'rmse_sqrt_before_time_avg': metrics.RMSESqrtBeforeTimeAvg(
wind_vector_rmse=_wind_vector_error('rmse')
),
}
spatial_metrics = {
'bias': metrics.SpatialBias(),
'mse': metrics.SpatialMSE(),
Expand Down Expand Up @@ -404,7 +426,7 @@ def main(argv: list[str]) -> None:
output_format='zarr',
),
'deterministic_temporal': config.Eval(
metrics=deterministic_metrics,
metrics=deterministic_metrics | rmse_metrics,
against_analysis=False,
regions=regions,
derived_variables=derived_variables,
Expand All @@ -427,15 +449,9 @@ def main(argv: list[str]) -> None:
ensemble_dim=ENSEMBLE_DIM.value
),
'crps_skill': metrics.CRPSSkill(ensemble_dim=ENSEMBLE_DIM.value),
'ensemble_mean_rmse': metrics.EnsembleMeanRMSE(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_mean_mse': metrics.EnsembleMeanMSE(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_stddev': metrics.EnsembleStddev(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_variance': metrics.EnsembleVariance(
ensemble_dim=ENSEMBLE_DIM.value
),
Expand All @@ -459,6 +475,16 @@ def main(argv: list[str]) -> None:
'energy_score_skill': metrics.EnergyScoreSkill(
ensemble_dim=ENSEMBLE_DIM.value
),
'ensemble_mean_rmse_sqrt_before_time_avg': (
metrics.EnsembleMeanRMSESqrtBeforeTimeAvg(
ensemble_dim=ENSEMBLE_DIM.value
)
),
'ensemble_stddev_sqrt_before_time_avg': (
metrics.EnsembleStddevSqrtBeforeTimeAvg(
ensemble_dim=ENSEMBLE_DIM.value
)
),
},
against_analysis=False,
derived_variables=derived_variables,
Expand Down Expand Up @@ -501,6 +527,15 @@ def main(argv: list[str]) -> None:
probabilistic_climatology_hour_interval=PROBABILISTIC_CLIMATOLOGY_HOUR_INTERVAL.value,
output_format='zarr',
),
'gaussian': config.Eval(
metrics={
'crps': metrics.GaussianCRPS(),
'ensemble_variance': metrics.GaussianVariance(),
},
against_analysis=False,
regions=regions,
derived_variables=derived_variables,
),
}
if not set(EVAL_CONFIGS.value.split(',')).issubset(eval_configs):
raise flags.UnrecognizedFlagError(
Expand Down
4 changes: 2 additions & 2 deletions scripts/evaluate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def _test(self, use_beam=True, input_chunks=None):
evaluate.main([])

for config_name in eval_configs:
expected_sizes_2d = {'metric': 5, 'lead_time': 4, 'region': 4}
expected_sizes_3d = {'metric': 5, 'lead_time': 4, 'region': 4, 'level': 3}
expected_sizes_2d = {'metric': 4, 'lead_time': 4, 'region': 4}
expected_sizes_3d = {'metric': 4, 'lead_time': 4, 'region': 4, 'level': 3}

with self.subTest(config_name):
results_path = os.path.join(output_dir, f'{config_name}.nc')
Expand Down
3 changes: 3 additions & 0 deletions weatherbench2/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class Selection:
levels: List of pressure levels.
lat_slice: Latitude range in degrees.
lon_slice: Longitude range in degrees.
aux_variables: Sequence of auxiliary forecast variables required for certain
evaluation metrics.
"""

variables: t.Sequence[str]
Expand All @@ -46,6 +48,7 @@ class Selection:
lon_slice: t.Optional[slice] = dataclasses.field(
default_factory=lambda: slice(None, None)
)
aux_variables: t.Optional[t.Sequence[str]] = None


@dataclasses.dataclass
Expand Down
14 changes: 11 additions & 3 deletions weatherbench2/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,14 @@ def _impose_data_selection(
selection: config.Selection,
select_time: bool = True,
time_dim: Optional[str] = None,
select_aux: bool = False,
) -> xr.Dataset:
"""Returns selection of dataset specified in Selection instance."""
dataset = dataset[selection.variables].sel(
if select_aux and selection.aux_variables is not None:
sel_variables = set(selection.variables) | set(selection.aux_variables)
else:
sel_variables = selection.variables
dataset = dataset[sel_variables].sel(
latitude=selection.lat_slice,
longitude=selection.lon_slice,
)
Expand Down Expand Up @@ -314,10 +319,12 @@ def open_forecast_and_truth_datasets(
)

obs_all_times = _impose_data_selection(
obs, data_config.selection, select_time=False
obs,
data_config.selection,
select_time=False,
)
forecast_all_times = _impose_data_selection(
forecast, data_config.selection, select_time=False
forecast, data_config.selection, select_time=False, select_aux=True
)

if data_config.by_init: # Will select appropriate chunks later
Expand All @@ -328,6 +335,7 @@ def open_forecast_and_truth_datasets(
forecast,
data_config.selection,
time_dim='init_time' if data_config.by_init else 'time',
select_aux=True,
)

# Determine ground truth dataset
Expand Down
6 changes: 3 additions & 3 deletions weatherbench2/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ def test_in_memory_and_beam_consistency(self):
eval_configs = {
'forecast_vs_era': config.Eval(
metrics={
'rmse': metrics.RMSE(),
'rmse': metrics.RMSESqrtBeforeTimeAvg(),
'acc': metrics.ACC(climatology=climatology),
},
against_analysis=False,
),
'forecast_vs_era_by_region': config.Eval(
metrics={'rmse': metrics.RMSE()},
metrics={'rmse': metrics.RMSESqrtBeforeTimeAvg()},
against_analysis=False,
regions=regions,
),
Expand All @@ -101,7 +101,7 @@ def test_in_memory_and_beam_consistency(self):
against_analysis=False,
),
'forecast_vs_era_temporal': config.Eval(
metrics={'rmse': metrics.RMSE()},
metrics={'rmse': metrics.RMSESqrtBeforeTimeAvg()},
against_analysis=False,
temporal_mean=False,
),
Expand Down
Loading

0 comments on commit 0ecdb29

Please sign in to comment.