diff --git a/weatherbench2/schema.py b/weatherbench2/schema.py index c037c1b..9a75b99 100644 --- a/weatherbench2/schema.py +++ b/weatherbench2/schema.py @@ -29,18 +29,18 @@ def apply_time_conventions( forecast = forecast.copy() if 'prediction_timedelta' in forecast.coords: forecast = forecast.rename({'prediction_timedelta': 'lead_time'}) - if by_init: - # Need to rename time dimension because different from time dimension in - # truth dataset - forecast = forecast.rename({'time': 'init_time'}) - valid_time = forecast.init_time + forecast.lead_time - forecast.coords['valid_time'] = valid_time - assert not hasattr( - forecast, 'time' - ), f'Forecast should not have time dimension at this point: {forecast}' - else: - init_time = forecast.time - forecast.lead_time - forecast.coords['init_time'] = init_time + if by_init: + # Need to rename time dimension because different from time dimension in + # truth dataset + forecast = forecast.rename({'time': 'init_time'}) + valid_time = forecast.init_time + forecast.lead_time + forecast.coords['valid_time'] = valid_time + assert not hasattr( + forecast, 'time' + ), f'Forecast should not have time dimension at this point: {forecast}' + else: + init_time = forecast.time - forecast.lead_time + forecast.coords['init_time'] = init_time return forecast