From 8da460754507776f0822be8d9ac2c8fa913de579 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 17 Sep 2025 08:43:20 -0700 Subject: [PATCH 01/32] adding index exclusion list support in dali loaders, making index handling more failsafe --- makani/utils/dataloaders/dali_es_helper_2d.py | 32 ++++++------------- .../dataloaders/dali_es_helper_concat_2d.py | 1 + makani/utils/inference/inferencer.py | 1 + 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/makani/utils/dataloaders/dali_es_helper_2d.py b/makani/utils/dataloaders/dali_es_helper_2d.py index 5a8b252..7581349 100644 --- a/makani/utils/dataloaders/dali_es_helper_2d.py +++ b/makani/utils/dataloaders/dali_es_helper_2d.py @@ -183,7 +183,7 @@ def _generate_indexlist(self, timestamp_boundary_list): if timestamp_boundary_list: #compute list of allowed timestamps timestamp_boundary_list = [get_date_from_string(timestamp_string) for timestamp_string in timestamp_boundary_list] - + # now, based on dt, dh, n_history and n_future, we can build regions where no data is allowed timestamp_exclusion_list = get_date_ranges(timestamp_boundary_list, lookback_hours = dt_total * (self.n_future + 1), lookahead_hours = dt_total * self.n_history) @@ -227,28 +227,16 @@ def _get_stats_h5(self, enable_logging): if enable_logging: logging.info("Getting file stats from {}".format(self.files_paths[0])) # original image shape (before padding) - dset = _f[self.dataset_path] - self.img_shape = dset.shape[2:4] - self.total_channels = dset.shape[1] - self.n_samples_year.append(dset.shape[0]) - # read timestamps - if "timestamp" in dset.dims[0]: - self.timestamps.append(self.timezone_fn(dset.dims[0]["timestamp"][...])) - else: - timestamps = np.asarray([get_timestamp(self.years[0], hour=(idx * self.dhours)).timestamp() for idx in range(0, dset.shape[0], self.dhours)]) - self.timestamps.append(self.timezone_fn(timestamps)) + self.img_shape = _f[self.dataset_path].shape[2:4] + self.total_channels = _f[self.dataset_path].shape[1] + self.n_samples_year.append(_f[self.dataset_path].shape[0]) + self.timestamps.append(self.timezone_fn(_f[self.dataset_path].dims[0]["timestamp"][...])) # get all sample counts for idf, filename in enumerate(self.files_paths[1:], start=1): with fopen_handle(filename) as _f: - dset = _f[self.dataset_path] - self.n_samples_year.append(dset.shape[0]) - # read timestamps - if "timestamp" in dset.dims[0]: - self.timestamps.append(self.timezone_fn(dset.dims[0]["timestamp"][...])) - else: - timestamps = np.asarray([get_timestamp(self.years[idf], hour=(idx * self.dhours)).timestamp() for idx in range(0, dset.shape[0], self.dhours)]) - self.timestamps.append(self.timezone_fn(timestamps)) + self.n_samples_year.append(_f[self.dataset_path].shape[0]) + self.timestamps.append(self.timezone_fn(_f[self.dataset_path].dims[0]["timestamp"][...])) self.timestamps = np.concatenate(self.timestamps, axis=0) @@ -521,7 +509,7 @@ def _compute_zenith_angle(self, inp_times, tar_times): # nvtx range torch.cuda.nvtx.range_pop() - return cos_zenith_inp, cos_zenith_tar + return cos_zenith_inp, cos_zenith_tar def __getstate__(self): del self.aws_connector @@ -593,8 +581,8 @@ def __call__(self, sample_info): local_idx, year_idx = self._get_local_year_index_from_global_index(sample_idx) # if we are not at least self.dt*n_history timesteps into the prediction - local_idx = max(local_idx, self.dt * self.n_history) - local_idx = min(local_idx, self.n_samples_year[year_idx] - self.dt * (self.n_future + 1) - 1) + local_idx = min(local_idx, self.dt * self.n_history) 
+        local_idx = max(local_idx, self.n_samples_year[year_idx] - self.dt * (self.n_future + 1) - 1)
 
         if self.files[year_idx] is None:
diff --git a/makani/utils/dataloaders/dali_es_helper_concat_2d.py b/makani/utils/dataloaders/dali_es_helper_concat_2d.py
index 21f8b6f..f9af5a9 100644
--- a/makani/utils/dataloaders/dali_es_helper_concat_2d.py
+++ b/makani/utils/dataloaders/dali_es_helper_concat_2d.py
@@ -16,6 +16,7 @@
 import time
 import sys
 import os
+from numba.parfors.parfor import replace_returns
 import numpy as np
 import h5py
 import logging
diff --git a/makani/utils/inference/inferencer.py b/makani/utils/inference/inferencer.py
index 7f8999b..2027c6c 100644
--- a/makani/utils/inference/inferencer.py
+++ b/makani/utils/inference/inferencer.py
@@ -39,6 +39,7 @@
 
 # distributed computing stuff
 from makani.utils import comm
+from makani.utils import visualize
 from makani.utils.dataloaders.data_helpers import get_date_from_string
 
 # inference specific stuff

From 7bd3f42f60cb486688282a8e71cf31a00ffa19d2 Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Wed, 17 Sep 2025 09:24:10 -0700
Subject: [PATCH 02/32] some small fixes for non-combined dali loader and
 unannotated files

---
 makani/utils/dataloaders/dali_es_helper_2d.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/makani/utils/dataloaders/dali_es_helper_2d.py b/makani/utils/dataloaders/dali_es_helper_2d.py
index 7581349..b2a63a1 100644
--- a/makani/utils/dataloaders/dali_es_helper_2d.py
+++ b/makani/utils/dataloaders/dali_es_helper_2d.py
@@ -230,7 +230,11 @@ def _get_stats_h5(self, enable_logging):
             self.img_shape = _f[self.dataset_path].shape[2:4]
             self.total_channels = _f[self.dataset_path].shape[1]
             self.n_samples_year.append(_f[self.dataset_path].shape[0])
-            self.timestamps.append(self.timezone_fn(_f[self.dataset_path].dims[0]["timestamp"][...]))
+            if "timestamp" in _f[self.dataset_path].dims[0]:
+                self.timestamps.append(self.timezone_fn(_f[self.dataset_path].dims[0]["timestamp"][...]))
+            else:
+                timestamps = np.asarray([get_timestamp(self.years[0], hour=(idx * self.dhours)).timestamp() for idx in range(0, _f[self.dataset_path].shape[0], self.dhours)])
+                self.timestamps.append(self.timezone_fn(timestamps))
 
         # get all sample counts
         for idf, filename in enumerate(self.files_paths[1:], start=1):

From 0c875db0ed21d0afc088b8be5b232f9a4c005f22 Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Thu, 18 Sep 2025 03:53:49 -0700
Subject: [PATCH 03/32] distinguishing between read_shape and return_shape.
read_shape will only be used by the dataloader internally, return_shape can be queried by other modules to determine the size of the returned samples --- makani/utils/dataloaders/dali_es_helper_2d.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/makani/utils/dataloaders/dali_es_helper_2d.py b/makani/utils/dataloaders/dali_es_helper_2d.py index b2a63a1..354e86c 100644 --- a/makani/utils/dataloaders/dali_es_helper_2d.py +++ b/makani/utils/dataloaders/dali_es_helper_2d.py @@ -227,20 +227,26 @@ def _get_stats_h5(self, enable_logging): if enable_logging: logging.info("Getting file stats from {}".format(self.files_paths[0])) # original image shape (before padding) - self.img_shape = _f[self.dataset_path].shape[2:4] - self.total_channels = _f[self.dataset_path].shape[1] - self.n_samples_year.append(_f[self.dataset_path].shape[0]) - if "timestamp" in _f[self.dataset_path].dims[0]: - self.timestamps.append(self.timezone_fn(_f[self.dataset_path].dims[0]["timestamp"][...])) + dset = _f[self.dataset_path] + self.img_shape = dset.shape[2:4] + self.total_channels = dset.shape[1] + self.n_samples_year.append(dset.shape[0]) + if "timestamp" in dset.dims[0]: + self.timestamps.append(self.timezone_fn(dset.dims[0]["timestamp"][...])) else: - timestamps = np.asarray([get_timestamp(self.years[0], hour=(idx * self.dhours)).timestamp() for idx in range(0, _f[self.dataset_path].shape[0], self.dhours)]) + timestamps = np.asarray([get_timestamp(self.years[0], hour=(idx * self.dhours)).timestamp() for idx in range(0, dset.shape[0], self.dhours)]) self.timestamps.append(self.timezone_fn(timestamps)) # get all sample counts for idf, filename in enumerate(self.files_paths[1:], start=1): with fopen_handle(filename) as _f: - self.n_samples_year.append(_f[self.dataset_path].shape[0]) - self.timestamps.append(self.timezone_fn(_f[self.dataset_path].dims[0]["timestamp"][...])) + dset = _f[self.dataset_path] + if "timestamp" in dset.dims[0]: + self.n_samples_year.append(dset.shape[0]) + self.timestamps.append(self.timezone_fn(dset.dims[0]["timestamp"][...])) + else: + timestamps = np.asarray([get_timestamp(self.years[idf], hour=(idx * self.dhours)).timestamp() for idx in range(0, dset.shape[0], self.dhours)]) + self.timestamps.append(self.timezone_fn(timestamps)) self.timestamps = np.concatenate(self.timestamps, axis=0) From 3773f84a463d9ca5aa099509b74b4d5fb25f7154 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Thu, 18 Sep 2025 04:23:08 -0700 Subject: [PATCH 04/32] fixing a bug where samples per year was not computed correctly --- makani/utils/dataloaders/dali_es_helper_2d.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/makani/utils/dataloaders/dali_es_helper_2d.py b/makani/utils/dataloaders/dali_es_helper_2d.py index 354e86c..f629686 100644 --- a/makani/utils/dataloaders/dali_es_helper_2d.py +++ b/makani/utils/dataloaders/dali_es_helper_2d.py @@ -231,6 +231,7 @@ def _get_stats_h5(self, enable_logging): self.img_shape = dset.shape[2:4] self.total_channels = dset.shape[1] self.n_samples_year.append(dset.shape[0]) + # read timestamps if "timestamp" in dset.dims[0]: self.timestamps.append(self.timezone_fn(dset.dims[0]["timestamp"][...])) else: @@ -241,8 +242,9 @@ def _get_stats_h5(self, enable_logging): for idf, filename in enumerate(self.files_paths[1:], start=1): with fopen_handle(filename) as _f: dset = _f[self.dataset_path] + self.n_samples_year.append(dset.shape[0]) + # read timestamps if "timestamp" in dset.dims[0]: - 
self.n_samples_year.append(dset.shape[0]) self.timestamps.append(self.timezone_fn(dset.dims[0]["timestamp"][...])) else: timestamps = np.asarray([get_timestamp(self.years[idf], hour=(idx * self.dhours)).timestamp() for idx in range(0, dset.shape[0], self.dhours)]) @@ -459,7 +461,7 @@ def _initialize_dataset_properties(self, enable_logging, timestamp_boundary_list if enable_logging: logging.info("Average number of samples per year: {:.1f}".format(float(self.n_samples_total) / float(self.n_years))) logging.info( - "Found data at path {}. Number of examples: {} (distributed over {} files). Full image Shape: {} x {} x {}. Read Shape: {} x {} x {}".format( + "Found data at path {}. Number of examples: {} (distributed over {} number of files). Full image Shape: {} x {} x {}. Read Shape: {} x {} x {}".format( self.location, self.n_samples_available, len(self.files_paths), self.img_shape[0], self.img_shape[1], self.total_channels, self.read_shape[0], self.read_shape[1], self.n_in_channels ) ) From e2ec9e0021fd385c11475f6c86e17b46bbeb5a05 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Thu, 18 Sep 2025 04:54:27 -0700 Subject: [PATCH 05/32] making nicer prints --- makani/utils/dataloaders/dali_es_helper_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/makani/utils/dataloaders/dali_es_helper_2d.py b/makani/utils/dataloaders/dali_es_helper_2d.py index f629686..2752733 100644 --- a/makani/utils/dataloaders/dali_es_helper_2d.py +++ b/makani/utils/dataloaders/dali_es_helper_2d.py @@ -461,7 +461,7 @@ def _initialize_dataset_properties(self, enable_logging, timestamp_boundary_list if enable_logging: logging.info("Average number of samples per year: {:.1f}".format(float(self.n_samples_total) / float(self.n_years))) logging.info( - "Found data at path {}. Number of examples: {} (distributed over {} number of files). Full image Shape: {} x {} x {}. Read Shape: {} x {} x {}".format( + "Found data at path {}. Number of examples: {} (distributed over {} files). Full image Shape: {} x {} x {}. 
Read Shape: {} x {} x {}".format( self.location, self.n_samples_available, len(self.files_paths), self.img_shape[0], self.img_shape[1], self.total_channels, self.read_shape[0], self.read_shape[1], self.n_in_channels ) ) From 12093d75f30ead7085d855d0c51cdb8d7d66de96 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Thu, 18 Sep 2025 08:08:57 -0700 Subject: [PATCH 06/32] fixed dataloader indexing --- makani/models/noise.py | 20 ++++++++++++++----- makani/utils/dataloaders/dali_es_helper_2d.py | 4 ++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/makani/models/noise.py b/makani/models/noise.py index 56e3232..5ff3fa8 100644 --- a/makani/models/noise.py +++ b/makani/models/noise.py @@ -100,7 +100,7 @@ def reset(self, batch_size=None): # this routine generates a noise sample for a single time step and updates the state accordingly, by appending the last time step def update(self, replace_state=False, batch_size=None): - # Update should always create a new state, so + # Update should always create a new state, so # we don't need to check for replace_state # create single occurence with torch.no_grad(): @@ -366,13 +366,12 @@ def __init__( raise NotImplementedError(f"num_time_steps>1 learnable diffusion noise not supported") discount = [] - phi_flat = self.phi.reshape(-1) - for phi_tmp in phi_flat.tolist(): - phivec = np.power(phi_tmp, np.arange(0, self.num_time_steps)) + for phi in self.phi.reshape(-1).tolist(): + phivec = np.power(self.phi, np.arange(0, self.num_time_steps)) disc = torch.as_tensor(toep(phivec, np.zeros(self.num_time_steps))) disc = disc.to(dtype=torch.float32) discount.append(disc) - discount = torch.stack(discount, dim=0) + discount = torch.stack(discount) self.register_buffer("discount", discount, persistent=False) def is_stateful(self): @@ -381,6 +380,8 @@ def is_stateful(self): # this routine generates a noise sample for a single time step and updates the state accordingly, by appending the last time step def update(self, replace_state=False, batch_size=None): + print("Inside update, replace_state: ", replace_state, " batch_size: ", batch_size) + # create single occurence with torch.no_grad(): with torch.amp.autocast(device_type="cuda", enabled=False): @@ -408,6 +409,8 @@ def update(self, replace_state=False, batch_size=None): newstate = torch.cat([self.state[:, 1:, ...], newstate], dim=1) else: newstate = self.phi * self.state + eta_l + + print("update state, state shape: ", self.state.shape, " newstate shape: ", newstate.shape) else: newstate = eta_l # the very first element in the time history requires a different weighting to sample the stationary distribution @@ -416,6 +419,10 @@ def update(self, replace_state=False, batch_size=None): if self.num_time_steps > 1: newstate = torch.einsum("ctr,brclmu->btclmu", self.discount, newstate).contiguous() + print("replace state, discount shape: ", self.discount.shape) + + print("replace state, state shape: ", self.state.shape, " newstate shape: ", newstate.shape) + # update the state if newstate.shape == self.state.shape: self.state.copy_(newstate) @@ -426,6 +433,9 @@ def update(self, replace_state=False, batch_size=None): def forward(self, update_internal_state=False): + print("state shape: ", self.state.shape) + print("parameters: ", self.num_time_steps, self.num_channels, self.lmax_local, self.mmax_local) + # combine channels and time: cstate = torch.view_as_complex(self.state) batch_size = cstate.shape[0] diff --git a/makani/utils/dataloaders/dali_es_helper_2d.py 
b/makani/utils/dataloaders/dali_es_helper_2d.py index 2752733..bd6627d 100644 --- a/makani/utils/dataloaders/dali_es_helper_2d.py +++ b/makani/utils/dataloaders/dali_es_helper_2d.py @@ -593,8 +593,8 @@ def __call__(self, sample_info): local_idx, year_idx = self._get_local_year_index_from_global_index(sample_idx) # if we are not at least self.dt*n_history timesteps into the prediction - local_idx = min(local_idx, self.dt * self.n_history) - local_idx = max(local_idx, self.n_samples_year[year_idx] - self.dt * (self.n_future + 1) - 1) + local_idx = max(local_idx, self.dt * self.n_history) + local_idx = min(local_idx, self.n_samples_year[year_idx] - self.dt * (self.n_future + 1) - 1) if self.files[year_idx] is None: From cca3a5835c2d6e1808ebb6a6a0db91b583e4d163 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Thu, 18 Sep 2025 09:07:09 -0700 Subject: [PATCH 07/32] fixed history with noise --- makani/models/noise.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/makani/models/noise.py b/makani/models/noise.py index 5ff3fa8..8e094c5 100644 --- a/makani/models/noise.py +++ b/makani/models/noise.py @@ -366,12 +366,13 @@ def __init__( raise NotImplementedError(f"num_time_steps>1 learnable diffusion noise not supported") discount = [] - for phi in self.phi.reshape(-1).tolist(): - phivec = np.power(self.phi, np.arange(0, self.num_time_steps)) + phi_flat = self.phi.reshape(-1) + for phi_tmp in phi_flat.tolist(): + phivec = np.power(phi_tmp, np.arange(0, self.num_time_steps)) disc = torch.as_tensor(toep(phivec, np.zeros(self.num_time_steps))) disc = disc.to(dtype=torch.float32) discount.append(disc) - discount = torch.stack(discount) + discount = torch.stack(discount, dim=0) self.register_buffer("discount", discount, persistent=False) def is_stateful(self): @@ -409,8 +410,6 @@ def update(self, replace_state=False, batch_size=None): newstate = torch.cat([self.state[:, 1:, ...], newstate], dim=1) else: newstate = self.phi * self.state + eta_l - - print("update state, state shape: ", self.state.shape, " newstate shape: ", newstate.shape) else: newstate = eta_l # the very first element in the time history requires a different weighting to sample the stationary distribution @@ -419,10 +418,6 @@ def update(self, replace_state=False, batch_size=None): if self.num_time_steps > 1: newstate = torch.einsum("ctr,brclmu->btclmu", self.discount, newstate).contiguous() - print("replace state, discount shape: ", self.discount.shape) - - print("replace state, state shape: ", self.state.shape, " newstate shape: ", newstate.shape) - # update the state if newstate.shape == self.state.shape: self.state.copy_(newstate) From 555e9305bc66b86cda2718cc9160cc151febdcbf Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Fri, 19 Sep 2025 02:14:27 -0700 Subject: [PATCH 08/32] setting explicit noise seeds for ensemble serial settings. 
without this, noise between ensemble members is likely correlated --- makani/models/noise.py | 5 ----- makani/utils/inference/inferencer.py | 17 +++++++++-------- makani/utils/training/ensemble_trainer.py | 14 ++++++++------ 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/makani/models/noise.py b/makani/models/noise.py index 8e094c5..7b7d845 100644 --- a/makani/models/noise.py +++ b/makani/models/noise.py @@ -381,8 +381,6 @@ def is_stateful(self): # this routine generates a noise sample for a single time step and updates the state accordingly, by appending the last time step def update(self, replace_state=False, batch_size=None): - print("Inside update, replace_state: ", replace_state, " batch_size: ", batch_size) - # create single occurence with torch.no_grad(): with torch.amp.autocast(device_type="cuda", enabled=False): @@ -428,9 +426,6 @@ def update(self, replace_state=False, batch_size=None): def forward(self, update_internal_state=False): - print("state shape: ", self.state.shape) - print("parameters: ", self.num_time_steps, self.num_channels, self.lmax_local, self.mmax_local) - # combine channels and time: cstate = torch.view_as_complex(self.state) batch_size = cstate.shape[0] diff --git a/makani/utils/inference/inferencer.py b/makani/utils/inference/inferencer.py index 2027c6c..7edad4f 100644 --- a/makani/utils/inference/inferencer.py +++ b/makani/utils/inference/inferencer.py @@ -447,11 +447,13 @@ def inference_indexlist( return logs - def _initialize_noise_states(self): + def _initialize_noise_states(self, seed_offset=666): noise_states = [] - for _ in range(self.params.local_ensemble_size): - self.preprocessor.update_internal_state(replace_state=True) - noise_states.append(self.preprocessor.get_internal_state(tensor=True)) + for ide in range(self.params.local_ensemble_size): + member_seed = seed_offset + self.preprocessor.noise_base_seed * ide + self.preprocessor.input_noise.set_rng(seed=member_seed) + self.preprocessor.input_noise.update(replace_state=True) + noise_states.append(self.preprocessor.input_noise.get_tensor_state()) return noise_states def _inference_indexlist( @@ -516,7 +518,7 @@ def _inference_indexlist( climatology_iterator = iter(self.climatology_dataloader) # create loader for the full epoch - noise_states = [] + noise_states = self._initialize_noise_states() inptlist = None idt = 0 with torch.inference_mode(): @@ -573,7 +575,7 @@ def _inference_indexlist( self.preprocessor.update_internal_state(replace_state=True, batch_size=inp.shape[0]) # reset noise states and input list - noise_states = self._initialize_noise_states() + noise_states = self._initialize_noise_states(seed_offset=idt) inptlist = [inp.clone() for _ in range(self.params.local_ensemble_size)] if rollout_buffer is not None: @@ -603,9 +605,8 @@ def _inference_indexlist( # retrieve input inpt = inptlist[e] - # this is different, depending on local ensemble size + # restore noise state if (self.params.local_ensemble_size > 1): - # restore noise belonging to this ensemble member self.preprocessor.set_internal_state(noise_states[e]) # forward pass: never replace state since we do that manually diff --git a/makani/utils/training/ensemble_trainer.py b/makani/utils/training/ensemble_trainer.py index e6e7a86..fed1edc 100644 --- a/makani/utils/training/ensemble_trainer.py +++ b/makani/utils/training/ensemble_trainer.py @@ -613,11 +613,13 @@ def train_one_epoch(self, profiler=None): return train_time, total_data_gb, logs - def _initialize_noise_states(self): + def 
_initialize_noise_states(self, seed_offset=666): noise_states = [] - for _ in range(self.params.local_ensemble_size): - self.preprocessor.update_internal_state(replace_state=True) - noise_states.append(self.preprocessor.get_internal_state(tensor=True)) + for ide in range(self.params.local_ensemble_size): + member_seed = seed_offset + self.preprocessor.noise_base_seed * ide + self.preprocessor.input_noise.set_rng(seed=member_seed) + self.preprocessor.input_noise.update(replace_state=True) + noise_states.append(self.preprocessor.input_noise.get_tensor_state()) return noise_states def validate_one_epoch(self, epoch, profiler=None): @@ -665,7 +667,7 @@ def validate_one_epoch(self, epoch, profiler=None): # do autoregression for each ensemble member individually # do the rollout # initialize the noise states with random seeds: - noise_states = self._initialize_noise_states() + noise_states = self._initialize_noise_states(seed_offset=eval_steps) inptlist = [inp.clone() for _ in range(self.params.local_ensemble_size)] # loop over lead times @@ -683,7 +685,7 @@ def validate_one_epoch(self, epoch, profiler=None): # retrieve input inpt = inptlist[e] - # this is different, depending on local ensemble size + # restore noise state if self.params.local_ensemble_size > 1: # recover correct state self.preprocessor.set_internal_state(noise_states[e]) From 38780eefa38855b68f281f7c8ad7457ac866e5d4 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Fri, 19 Sep 2025 02:50:59 -0700 Subject: [PATCH 09/32] adding accessor functions for some RNG features when noise is enabled --- makani/utils/inference/inferencer.py | 7 +++---- makani/utils/training/ensemble_trainer.py | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/makani/utils/inference/inferencer.py b/makani/utils/inference/inferencer.py index 7edad4f..e598d8d 100644 --- a/makani/utils/inference/inferencer.py +++ b/makani/utils/inference/inferencer.py @@ -450,10 +450,9 @@ def inference_indexlist( def _initialize_noise_states(self, seed_offset=666): noise_states = [] for ide in range(self.params.local_ensemble_size): - member_seed = seed_offset + self.preprocessor.noise_base_seed * ide - self.preprocessor.input_noise.set_rng(seed=member_seed) - self.preprocessor.input_noise.update(replace_state=True) - noise_states.append(self.preprocessor.input_noise.get_tensor_state()) + member_seed = seed_offset + self.preprocessor.get_base_seed(default=333) * ide + self.preprocessor.set_rng(seed=member_seed, reset=True) + noise_states.append(self.preprocessor.get_internal_state(tensor=True)) return noise_states def _inference_indexlist( diff --git a/makani/utils/training/ensemble_trainer.py b/makani/utils/training/ensemble_trainer.py index fed1edc..cf3ad3c 100644 --- a/makani/utils/training/ensemble_trainer.py +++ b/makani/utils/training/ensemble_trainer.py @@ -616,10 +616,9 @@ def train_one_epoch(self, profiler=None): def _initialize_noise_states(self, seed_offset=666): noise_states = [] for ide in range(self.params.local_ensemble_size): - member_seed = seed_offset + self.preprocessor.noise_base_seed * ide - self.preprocessor.input_noise.set_rng(seed=member_seed) - self.preprocessor.input_noise.update(replace_state=True) - noise_states.append(self.preprocessor.input_noise.get_tensor_state()) + member_seed = seed_offset + self.preprocessor.get_base_seed(default=333) * ide + self.preprocessor.set_rng(seed=member_seed, reset=True) + noise_states.append(self.preprocessor.get_internal_state(tensor=True)) return noise_states def validate_one_epoch(self, 
epoch, profiler=None):

From b68733e31aeb73b29858ddfc06d0d20814d22de7 Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Fri, 19 Sep 2025 03:06:25 -0700
Subject: [PATCH 10/32] removing unnecessary imports

---
 makani/utils/dataloaders/dali_es_helper_concat_2d.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/makani/utils/dataloaders/dali_es_helper_concat_2d.py b/makani/utils/dataloaders/dali_es_helper_concat_2d.py
index f9af5a9..421d568 100644
--- a/makani/utils/dataloaders/dali_es_helper_concat_2d.py
+++ b/makani/utils/dataloaders/dali_es_helper_concat_2d.py
@@ -16,11 +16,11 @@
 import time
 import sys
 import os
-from numba.parfors.parfor import replace_returns
 import numpy as np
 import h5py
 import logging
-from itertools import groupby
+from itertools import groupby, accumulate
+from bisect import bisect_right
 
 # for nvtx annotation
 import torch
@@ -160,7 +160,7 @@ def _generate_indexlist(self, timestamp_boundary_list):
         if timestamp_boundary_list:
             #compute list of allowed timestamps
             timestamp_boundary_list = [get_date_from_string(timestamp_string) for timestamp_string in timestamp_boundary_list]
-            
+
             # now, based on dt, dh, n_history and n_future, we can build regions where no data is allowed
             timestamp_exclusion_list = get_date_ranges(timestamp_boundary_list, lookback_hours = dt_total * (self.n_future + 1), lookahead_hours = dt_total * self.n_history)

From 7656ad470e57f75716082498b3f10dcfff105187 Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Fri, 19 Sep 2025 03:07:43 -0700
Subject: [PATCH 11/32] removing more imports

---
 makani/utils/dataloaders/dali_es_helper_concat_2d.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/makani/utils/dataloaders/dali_es_helper_concat_2d.py b/makani/utils/dataloaders/dali_es_helper_concat_2d.py
index 421d568..c2e30a8 100644
--- a/makani/utils/dataloaders/dali_es_helper_concat_2d.py
+++ b/makani/utils/dataloaders/dali_es_helper_concat_2d.py
@@ -19,8 +19,7 @@
 import numpy as np
 import h5py
 import logging
-from itertools import groupby, accumulate
-from bisect import bisect_right
+from itertools import groupby
 
 # for nvtx annotation
 import torch

From 228e7d9a6e32a0c97d024de39ac82f04f33a0e93 Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Fri, 19 Sep 2025 03:09:58 -0700
Subject: [PATCH 12/32] removing more imports

---
 makani/utils/inference/inferencer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/makani/utils/inference/inferencer.py b/makani/utils/inference/inferencer.py
index e598d8d..78bb0c9 100644
--- a/makani/utils/inference/inferencer.py
+++ b/makani/utils/inference/inferencer.py
@@ -39,7 +39,6 @@
 
 # distributed computing stuff
 from makani.utils import comm
-from makani.utils import visualize
 from makani.utils.dataloaders.data_helpers import get_date_from_string
 
 # inference specific stuff

From 02b6addb83a0e04fdf0cba3a67ed63b5992993b7 Mon Sep 17 00:00:00 2001
From: Boris Bonev
Date: Fri, 19 Sep 2025 04:46:26 -0700
Subject: [PATCH 13/32] adapting changes from internal branch

---
 makani/models/common/__init__.py         |   1 +
 makani/models/common/pos_embedding.py    |  99 ++++++++
 makani/models/networks/fourcastnet3.py   |  18 +-
 makani/utils/dataloaders/data_helpers.py |  11 +
 makani/utils/features.py                 |   2 +-
 makani/utils/loss.py                     |  60 +++--
 makani/utils/losses/__init__.py          |   2 +-
 makani/utils/losses/base_loss.py         | 122 ++++++++-
 makani/utils/losses/crps_loss.py         | 307 +++++++++++++++++++++--
 makani/utils/metric.py                   |  16 +-
 10 files changed, 574 insertions(+), 64 deletions(-)
 create mode 100644
makani/models/common/pos_embedding.py diff --git a/makani/models/common/__init__.py b/makani/models/common/__init__.py index 0341114..7c14ff0 100644 --- a/makani/models/common/__init__.py +++ b/makani/models/common/__init__.py @@ -18,3 +18,4 @@ from .fft import RealFFT1, InverseRealFFT1, RealFFT2, InverseRealFFT2, RealFFT3, InverseRealFFT3 from .layer_norm import GeometricInstanceNormS2 from .spectral_convolution import SpectralConv, SpectralAttention +from .pos_embedding import LearnablePositionEmbedding diff --git a/makani/models/common/pos_embedding.py b/makani/models/common/pos_embedding.py new file mode 100644 index 0000000..0f65da5 --- /dev/null +++ b/makani/models/common/pos_embedding.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import torch +import torch.nn as nn + +from makani.utils import comm +from physicsnemo.distributed.utils import compute_split_shapes + +class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta): + """ + Abstract base class for position embeddings. + + This class defines the interface for position embedding modules + that add positional information to input tensors. + + Parameters + ---------- + img_shape : tuple, optional + Image shape (height, width), by default (480, 960) + grid : str, optional + Grid type, by default "equiangular" + num_chans : int, optional + Number of channels, by default 1 + """ + + def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1): + super().__init__() + + self.img_shape = img_shape + self.num_chans = num_chans + + def forward(self): + + return self.position_embeddings + +class LearnablePositionEmbedding(PositionEmbedding): + """ + Learnable position embeddings for spherical transformers. + + This module provides learnable position embeddings that can be either + latitude-only or full latitude-longitude embeddings. 
+ + Parameters + ---------- + img_shape : tuple, optional + Image shape (height, width), by default (480, 960) + grid : str, optional + Grid type, by default "equiangular" + num_chans : int, optional + Number of channels, by default 1 + embed_type : str, optional + Embedding type ("lat" or "latlon"), by default "lat" + """ + + def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"): + super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans) + + # if distributed, make sure to split correctly across ranks: + # in case of model parallelism, we need to make sure that we use the correct shapes per rank + # for h + if comm.get_size("h") > 1: + self.local_shape_h = compute_split_shapes(img_shape[0], comm.get_size("h"))[comm.get_rank("h")] + else: + self.local_shape_h = img_shape[0] + + # for w + if comm.get_size("w") > 1: + self.local_shape_w = compute_split_shapes(img_shape[1], comm.get_size("w"))[comm.get_rank("w")] + else: + self.local_shape_w = img_shape[1] + + if embed_type == "latlon": + self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.local_shape_h, self.local_shape_w)) + self.position_embeddings.is_shared_mp = [] + self.position_embeddings.sharded_dims_mp = [None, None, "h", "w"] + elif embed_type == "lat": + self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.local_shape_h, 1)) + self.position_embeddings.is_shared_mp = ["w"] + self.position_embeddings.sharded_dims_mp = [None, None, "h", None] + else: + raise ValueError(f"Unknown learnable position embedding type {embed_type}") + + def forward(self): + return self.position_embeddings.expand(-1,-1,self.local_shape_h, self.local_shape_w) diff --git a/makani/models/networks/fourcastnet3.py b/makani/models/networks/fourcastnet3.py index 9244e14..3764978 100644 --- a/makani/models/networks/fourcastnet3.py +++ b/makani/models/networks/fourcastnet3.py @@ -24,7 +24,7 @@ from itertools import groupby # helpers -from makani.models.common import DropPath, LayerScale, MLP, EncoderDecoder, SpectralConv +from makani.models.common import DropPath, LayerScale, MLP, EncoderDecoder, SpectralConv, LearnablePositionEmbedding from makani.utils.features import get_water_channels, get_channel_groups # get spectral transforms and spherical convolutions from torch_harmonics @@ -409,6 +409,7 @@ def __init__( atmo_embed_dim=8, surf_embed_dim=8, aux_embed_dim=8, + pos_embed_dim=0, num_layers=4, num_groups=1, use_mlp=True, @@ -437,6 +438,7 @@ def __init__( self.atmo_embed_dim = atmo_embed_dim self.surf_embed_dim = surf_embed_dim self.aux_embed_dim = aux_embed_dim + self.pos_embed_dim = pos_embed_dim self.big_skip = big_skip self.checkpointing_level = checkpointing_level @@ -561,6 +563,10 @@ def __init__( use_mlp=encoder_mlp, ) + # position embedding + if self.pos_embed_dim > 0: + self.pos_embed = LearnablePositionEmbedding(img_shape=(self.h, self.w), grid=sht_grid_type, num_chans=self.pos_embed_dim, embed_type="lat") + # dropout self.pos_drop = nn.Dropout(p=pos_drop_rate) if pos_drop_rate > 0.0 else nn.Identity() dpr = [x.item() for x in torch.linspace(0, path_drop_rate, num_layers)] @@ -583,7 +589,7 @@ def __init__( block = NeuralOperatorBlock( self.sht, self.isht, - self.total_embed_dim + (self.n_aux_chans > 0) * self.aux_embed_dim, + self.total_embed_dim + (self.n_aux_chans > 0) * self.aux_embed_dim + self.pos_embed_dim, self.total_embed_dim, conv_type=conv_type, mlp_ratio=mlp_ratio, @@ -726,6 +732,7 @@ def _precompute_channel_groups( self, channel_names=[], 
aux_channel_names=[], + n_history=0, ): """ group the channels appropriately into atmospheric pressure levels and surface variables @@ -783,6 +790,13 @@ def encode_auxiliary_channels(self, x): else: x_aux = None + if hasattr(self, "pos_embed"): + x_pos = self.pos_embed() + if x_aux is not None: + x_aux = torch.cat([x_aux, x_pos], dim=-3) + else: + x_aux = x_pos + return x_aux def decode(self, x): diff --git a/makani/utils/dataloaders/data_helpers.py b/makani/utils/dataloaders/data_helpers.py index 753e398..a52f684 100644 --- a/makani/utils/dataloaders/data_helpers.py +++ b/makani/utils/dataloaders/data_helpers.py @@ -58,6 +58,17 @@ def get_data_normalization(params): return bias, scale +def get_time_diff_stds(params): + + time_diff_stds = None + + if hasattr(params, "time_diff_stds_path"): + time_diff_stds = np.load(params.time_diff_stds_path) + else: + raise ValueError(f"time_diff_std_path not defined.") + + return time_diff_stds + def get_climatology(params): """ diff --git a/makani/utils/features.py b/makani/utils/features.py index 9b17750..658af01 100644 --- a/makani/utils/features.py +++ b/makani/utils/features.py @@ -88,7 +88,7 @@ def get_wind_channels(channel_names): wind_chans = [] for c, ch in enumerate(channel_names): - if ch[0] == "u": + if ch[0] == "u" and ("v" + ch[1:]) in channel_names: vc = channel_names.index("v" + ch[1:]) wind_chans = wind_chans + [c, vc] diff --git a/makani/utils/loss.py b/makani/utils/loss.py index 887c19f..677081c 100644 --- a/makani/utils/loss.py +++ b/makani/utils/loss.py @@ -24,16 +24,17 @@ from makani.utils import comm from makani.utils.grids import GridQuadrature -from makani.utils.dataloaders.data_helpers import get_data_normalization +from makani.utils.dataloaders.data_helpers import get_data_normalization, get_time_diff_stds from physicsnemo.distributed.utils import compute_split_shapes from physicsnemo.distributed.mappings import gather_from_parallel_region, reduce_from_parallel_region import torch_harmonics as harmonics from torch_harmonics.quadrature import clenshaw_curtiss_weights, legendre_gauss_weights -from .losses import LossType, GeometricLpLoss, SpectralH1Loss, SpectralAMSELoss, HydrostaticBalanceLoss -from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleNLLLoss, EnsembleMMDLoss -from .losses import DriftRegularization +from .losses import LossType, GeometricLpLoss, SpectralH1Loss, SpectralAMSELoss +from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleVortDivCRPSLoss +from .losses import EnsembleNLLLoss, EnsembleMMDLoss +from .losses import DriftRegularization, HydrostaticBalanceLoss class LossHandler(nn.Module): @@ -57,11 +58,12 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps # check whether dynamic loss weighting is required self.uncertainty_weighting = params.get("uncertainty_weighting", False) + self.balanced_weighting = params.get("balanced_weighting", False) self.randomized_loss_weights = params.get("randomized_loss_weights", False) self.random_slice_loss = params.get("random_slice_loss", False) # whether to keep running stats - self.track_running_stats = track_running_stats or self.uncertainty_weighting + self.track_running_stats = track_running_stats or self.uncertainty_weighting or self.balanced_weighting self.eps = eps n_channels = len(params.channel_names) @@ -127,32 +129,23 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps else: channel_weight_type = loss["channel_weights"] + # check if time difference weighting is 
required
+        if loss.get("temp_diff_normalization", False):
+            time_diff_scale = get_time_diff_stds(params).flatten()
+            time_diff_scale = torch.clamp(torch.from_numpy(time_diff_scale[params.out_channels]), min=1e-4)
+            time_diff_scale = scale.flatten() / time_diff_scale
+        else:
+            time_diff_scale = None
+
+        # get channel weights either directly or through the compute routine
         if isinstance(channel_weight_type, List):
-            chw = torch.tensor(channel_weight_type, dtype=torch.float32).reshape(1, -1)
+            chw = torch.tensor(channel_weight_type, dtype=torch.float32).reshape(1, -1)
+            if time_diff_scale is not None:
+                chw = chw * time_diff_scale
             assert chw.shape[1] == loss_fn.n_channels
         else:
-            chw = loss_fn.compute_channel_weighting(channel_weight_type)
-
-        # the option to normalize outputs with stds of the time difference rather than th
-        if ("temp_diff_normalization" in loss.keys()) and loss["temp_diff_normalization"]:
-
-            # extract relevant stds
-            time_diff_stds = torch.from_numpy(np.load(params.time_diff_stds_path)).reshape(1, -1)[:, params.out_channels]
-            # the time differences are computed between two consecutive datapoints,
-            # so we need to account for the number of timesteps used in the prediction
-            # this is now commebnted out as we expect the stats to be computed with the correct dt
-            # time_diff_stds *= np.sqrt(params.dt)
-
-            # to avoid division by very small numbers, we clamp the time differences from below
-            time_diff_stds = torch.clamp(time_diff_stds, min=1e-4)
-
-            time_var_weights = scale.reshape(1, -1) / time_diff_stds
-
-            if hasattr(loss_fn, "squared") and loss_fn.squared:
-                time_var_weights = time_var_weights**2
-
-            chw = chw * time_var_weights
+            chw = loss_fn.compute_channel_weighting(channel_weight_type, time_diff_scale=time_diff_scale)
 
+        # reshape channel weights for proper broadcasting
         chw = chw.reshape(1, -1)
 
         # check for a relative weight that weights the loss relative to other losses
@@ -226,11 +219,11 @@ def _parse_loss_type(self, loss_type: str):
                     p_max = int(x.replace("p_max=", ""))
             loss_handle = partial(HydrostaticBalanceLoss, p_min=p_min, p_max=p_max, use_moist_air_formula=use_moist_air_formula)
         elif "ensemble_crps" in loss_type:
-            loss_handle = partial(EnsembleCRPSLoss, crps_type="cdf")
+            loss_handle = partial(EnsembleCRPSLoss)
         elif "ensemble_spectral_crps" in loss_type:
-            loss_handle = partial(EnsembleSpectralCRPSLoss, crps_type="cdf")
-        elif "gauss_crps" in loss_type:
-            loss_handle = partial(EnsembleCRPSLoss, crps_type="gauss")
+            loss_handle = partial(EnsembleSpectralCRPSLoss)
+        elif "ensemble_vort_div_crps" in loss_type:
+            loss_handle = partial(EnsembleVortDivCRPSLoss)
         elif "ensemble_nll" in loss_type:
            loss_handle = EnsembleNLLLoss
         elif "ensemble_mmd" in loss_type:
@@ -348,6 +341,11 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens
         if self.uncertainty_weighting and self.training:
             var, _ = self.get_running_stats()
             chw = chw / (torch.sqrt(2 * var) + self.eps)
+        elif self.balanced_weighting and self.training:
+            _, mean = self.get_running_stats()
+            if self.num_batches_tracked.item() <= 100:
+                mean = torch.ones_like(mean)
+            chw = chw / mean
 
         if self.randomized_loss_weights:
             rmask = torch.zeros_like(chw)
diff --git a/makani/utils/losses/__init__.py b/makani/utils/losses/__init__.py
index 48d63f3..a4aef77 100644
--- a/makani/utils/losses/__init__.py
+++ b/makani/utils/losses/__init__.py
@@ -18,7 +18,7 @@
 from .lp_loss import GeometricLpLoss, SpectralL2Loss
 from .amse_loss import SpectralAMSELoss
 from .hydrostatic_loss import HydrostaticBalanceLoss
-from .crps_loss import EnsembleCRPSLoss, 
EnsembleSpectralCRPSLoss +from .crps_loss import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleVortDivCRPSLoss from .mmd_loss import EnsembleMMDLoss from .likelihood_loss import EnsembleNLLLoss from .drift_regularization import DriftRegularization diff --git a/makani/utils/losses/base_loss.py b/makani/utils/losses/base_loss.py index b71006c..6614efe 100644 --- a/makani/utils/losses/base_loss.py +++ b/makani/utils/losses/base_loss.py @@ -17,6 +17,7 @@ from dataclasses import dataclass from abc import ABCMeta, abstractmethod +import math import torch import torch.nn as nn @@ -26,9 +27,10 @@ from makani.utils.grids import grid_to_quadrature_rule, GridQuadrature from makani.utils import comm +from makani.utils.features import get_wind_channels -def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_type: str) -> torch.Tensor: +def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: """ auxiliary routine for predetermining channel weighting """ @@ -53,6 +55,32 @@ def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_t else: channel_weights[c] = 0.01 + elif channel_weight_type == "new auto": + + for c, chn in enumerate(channel_names): + if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv"]: + channel_weights[c] = 0.1 + elif chn in ["t2m", "2d"]: + channel_weights[c] = 2.0 + elif chn[0] in ["z", "u", "v", "t", "r", "q"]: + pressure_level = float(chn[1:]) + channel_weights[c] = max(0.2, 0.001 * pressure_level) + else: + channel_weights[c] = 0.01 + + elif channel_weight_type == "new auto 2": + + for c, chn in enumerate(channel_names): + if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv"]: + channel_weights[c] = 0.1 + elif chn in ["t2m", "2d"]: + channel_weights[c] = 2.0 + elif chn[0] in ["z", "u", "v", "t", "r", "q"]: + pressure_level = float(chn[1:]) + channel_weights[c] = max(0.3, 0.001 * pressure_level) + else: + channel_weights[c] = 0.01 + elif channel_weight_type == "custom": weight_dict = { @@ -216,6 +244,10 @@ def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_t # normalize channel_weights = channel_weights / torch.sum(channel_weights) + # get the time differences and weigh them additionally + if time_diff_scale is not None: + channel_weights = channel_weights * time_diff_scale + return channel_weights @@ -250,9 +282,8 @@ def __init__( self.pole_mask = pole_mask self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + # get the quadrature rule for the corresponding grid quadrature_rule = grid_to_quadrature_rule(grid_type) - - # get the quadrature self.quadrature = GridQuadrature( quadrature_rule, img_shape=self.img_shape, @@ -272,8 +303,8 @@ def n_channels(self): return len(self.channel_names) @torch.compiler.disable(recursive=False) - def compute_channel_weighting(self, channel_weight_type: str) -> torch.Tensor: - return _compute_channel_weighting_helper(self.channel_names, channel_weight_type) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: + return _compute_channel_weighting_helper(self.channel_names, channel_weight_type, time_diff_scale=time_diff_scale) @abstractmethod def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None) -> torch.Tensor: @@ -302,6 +333,7 @@ def __init__( self.channel_names = channel_names self.spatial_distributed = 
comm.is_distributed("spatial") and spatial_distributed
 
+        # SHT for the computation of SH coefficients
         if self.spatial_distributed and (comm.get_size("spatial") > 1):
             if not thd.is_initialized():
                 polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h")
@@ -320,8 +352,84 @@ def n_channels(self):
         return len(self.n_channels)
 
     @torch.compiler.disable(recursive=False)
-    def compute_channel_weighting(self, channel_weight_type: str) -> torch.Tensor:
-        return _compute_channel_weighting_helper(self.channel_names, channel_weight_type)
+    def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor:
+        return _compute_channel_weighting_helper(self.channel_names, channel_weight_type, time_diff_scale=time_diff_scale)
+
+    @abstractmethod
+    def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None) -> torch.Tensor:
+        pass
+
+
+class VortDivBaseLoss(nn.Module, metaclass=ABCMeta):
+    """
+    Base loss class for losses computed on the vorticity and divergence of the wind channels
+    """
+
+    def __init__(
+        self,
+        img_shape: Tuple[int, int],
+        crop_shape: Tuple[int, int],
+        crop_offset: Tuple[int, int],
+        channel_names: List[str],
+        grid_type: str,
+        pole_mask: int,
+        spatial_distributed: Optional[bool] = False,
+    ):
+        super().__init__()
+
+        self.img_shape = img_shape
+        self.crop_shape = crop_shape
+        self.crop_offset = crop_offset
+        self.channel_names = channel_names
+        self.pole_mask = pole_mask
+        self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed
+
+        # get the wind channels
+        wind_chans = get_wind_channels(self.channel_names)
+        self.register_buffer("wind_chans", torch.LongTensor(wind_chans))
+
+        if self.spatial_distributed and (comm.get_size("spatial") > 1):
+            if not thd.is_initialized():
+                polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h")
+                azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w")
+                thd.init(polar_group, azimuth_group)
+            self.vsht = thd.DistributedRealVectorSHT(*img_shape, grid=grid_type)
+            self.isht = thd.DistributedInverseRealVectorSHT(nlat=self.vsht.nlat, nlon=self.vsht.nlon, grid=grid_type)
+        else:
+            self.vsht = th.RealVectorSHT(*img_shape, grid=grid_type)
+            self.isht = th.InverseRealVectorSHT(nlat=self.vsht.nlat, nlon=self.vsht.nlon, grid=grid_type)
+
+        # get the quadrature rule for the corresponding grid
+        quadrature_rule = grid_to_quadrature_rule(grid_type)
+        self.quadrature = GridQuadrature(
+            quadrature_rule,
+            img_shape=self.img_shape,
+            crop_shape=self.crop_shape,
+            crop_offset=self.crop_offset,
+            normalize=True,
+            pole_mask=self.pole_mask,
+            distributed=self.spatial_distributed,
+        )
+
+    @property
+    def type(self):
+        return LossType.Deterministic
+
+    @property
+    def n_channels(self):
+        return len(self.channel_names)
+
+    @torch.compiler.disable(recursive=False)
+    def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor:
+
+        chw = _compute_channel_weighting_helper(self.channel_names, channel_weight_type, time_diff_scale=time_diff_scale)
+        chw = chw[self.wind_chans.to(chw.device)]
+
+        # average u and v component weightings to weight vort and div equally
+        chw[1::2] = (chw[1::2] + chw[0::2]) / 2
+        chw[0::2] = chw[1::2]
+
+        return chw
 
     @abstractmethod
     def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None) -> torch.Tensor:
         pass
diff --git a/makani/utils/losses/crps_loss.py b/makani/utils/losses/crps_loss.py
index 90c6224..8119e43 100644
--- 
a/makani/utils/losses/crps_loss.py
+++ b/makani/utils/losses/crps_loss.py
@@ -22,7 +22,7 @@
 import torch.nn as nn
 from torch import amp
 
-from makani.utils.losses.base_loss import GeometricBaseLoss, SpectralBaseLoss, LossType
+from makani.utils.losses.base_loss import GeometricBaseLoss, SpectralBaseLoss, VortDivBaseLoss, LossType
 from makani.utils import comm
 
 # distributed stuff
@@ -49,6 +49,9 @@ def _crps_ensemble_kernel(observation: torch.Tensor, forecasts: torch.Tensor, we
     CRPS ensemble score from integrating the PDF piecewise
     compare https://github.com/properscoring/properscoring/blob/master/properscoring/_gufuncs.py#L7
     disabling torch compile for the moment due to very long startup times when training large ensembles with ensemble parallelism
+
+    forecasts: [ensemble, ...], observation: [...], weights: [ensemble, ...]
+    Assumes forecasts are sorted along ensemble dimension 0.
     """
 
     # beware: forecasts are assumed sorted in sorted order
@@ -110,7 +113,7 @@ def _crps_skillspread_kernel(observation: torch.Tensor, forecasts: torch.Tensor,
     """
-    alternative CRPS variant that uses spread and skill
+    fair CRPS variant that uses spread and skill. Assumes pre-sorted ensemble
     """
 
     observation = observation.unsqueeze(0)
@@ -142,6 +145,106 @@ def _crps_skillspread_kernel(observation: torch.Tensor, forecasts: torch.Tensor,
     return crps
 
 
+def _crps_probability_weighted_moment_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
+    """
+    CRPS estimator based on the probability weighted moments. Assumes forecasts are sorted along ensemble dimension 0. See [1].
+
+    [1] Michael Zamo, Philippe Naveau. Estimation of the Continuous Ranked Probability Score with Limited Information and Applications to Ensemble Weather Forecasts. Mathematical Geosciences. Volume 50 pp. 209-234. 2018.
+    """
+
+    observation = observation.unsqueeze(0)
+
+    # get nanmask
+    nanmasks = torch.logical_or(torch.isnan(forecasts), torch.isnan(weights))
+    nanmask = torch.sum(nanmasks, dim=0).bool()
+
+    # compute total weights
+    nweights = torch.where(nanmasks, 0.0, weights)
+    total_weight = torch.sum(nweights, dim=0, keepdim=True)
+
+    # ensemble size
+    num_ensemble = forecasts.shape[0]
+
+    # get the ranks for the pwm computation
+    rank = torch.arange(num_ensemble, device=forecasts.device).reshape((num_ensemble,) + (1,) * (forecasts.dim() - 1))
+
+    # compute the probability weighted moments beta0 and beta1 as well as the skill term
+    beta0 = forecasts.mean(dim=0)
+    beta1 = (rank * forecasts).sum(dim=0) / float(num_ensemble * (num_ensemble - 1))
+    eskill = (observation - forecasts).abs().mean(dim=0)
+
+    # crps = torch.where(nanmasks.sum(dim=0) != 0, torch.nan, eskill - 0.5 * espread)
+    crps = eskill + beta0 - 2 * beta1
+
+    # set to nan for first forecasts nan
+    crps = torch.where(nanmask, torch.nan, crps)
+
+    return crps
+
+
+
+def _crps_independent_skillspread_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
+    """
+    CRPS which uses separate samples for the estimation of spread and skill. 
Only one sample is used for the estimation of the skill.
+    """
+
+    observation = observation.unsqueeze(0)
+
+    # get nanmask
+    nanmasks = torch.logical_or(torch.isnan(forecasts), torch.isnan(weights))
+    nanmask = torch.sum(nanmasks, dim=0).bool()
+
+    # compute total weights
+    nweights = torch.where(nanmasks, 0.0, weights)
+    total_weight = torch.sum(nweights, dim=0, keepdim=True)
+
+    # ensemble size
+    num_ensemble = forecasts.shape[0]
+
+    # use broadcasting semantics to compute spread and skill
+    espread = (forecasts[1:].unsqueeze(1) - forecasts[1:].unsqueeze(0)).abs().sum(dim=(0,1)) / float((num_ensemble - 1)*(num_ensemble - 2))
+    eskill = (observation - forecasts[0:1]).abs().mean(dim=0)
+
+    # crps = torch.where(nanmasks.sum(dim=0) != 0, torch.nan, eskill - 0.5 * espread)
+    crps = eskill - 0.5 * espread
+
+    # set to nan for first forecasts nan
+    crps = torch.where(nanmask, torch.nan, crps)
+
+    return crps
+
+
+def _crps_naive_skillspread_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weights: torch.Tensor, alpha: float) -> torch.Tensor:
+    """
+    alternative fair CRPS variant that uses spread and skill. Uses naive computation which is O(N^2) in the number of ensemble members. Useful for complex-valued inputs, where the sorting-based kernels do not apply.
+    """
+
+    observation = observation.unsqueeze(0)
+
+    # get nanmask
+    nanmasks = torch.logical_or(torch.isnan(forecasts), torch.isnan(weights))
+    nanmask = torch.sum(nanmasks, dim=0).bool()
+
+    # compute total weights
+    nweights = torch.where(nanmasks, 0.0, weights)
+    total_weight = torch.sum(nweights, dim=0, keepdim=True)
+
+    # ensemble size
+    num_ensemble = forecasts.shape[0]
+
+    # use broadcasting semantics to compute spread and skill
+    espread = (forecasts.unsqueeze(1) - forecasts).abs().sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1))
+    eskill = (observation - forecasts).abs().mean(dim=0)
+
+    # crps = torch.where(nanmasks.sum(dim=0) != 0, torch.nan, eskill - 0.5 * espread)
+    crps = eskill - 0.5 * espread
+
+    # set to nan for first forecasts nan
+    crps = torch.where(nanmask, torch.nan, crps)
+
+    return crps
+
+
 # @torch.compile
 def _crps_gauss_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weights: torch.Tensor, eps: float) -> torch.Tensor:
     """
@@ -278,6 +381,17 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w
 
                 # compute score
                 crps = _crps_ensemble_kernel(observations, forecasts, ensemble_weights)
+            elif self.crps_type == "probability weighted moment":
+                # now, E dimension is local and spatial dim is split further
+                # we need to sort the forecasts now
+                forecasts, idx = torch.sort(forecasts, dim=0)
+                if self.ensemble_weights is not None:
+                    ensemble_weights = self.ensemble_weights[idx]
+                else:
+                    ensemble_weights = torch.ones_like(forecasts, device=forecasts.device)
+
+                # compute score
+                crps = _crps_probability_weighted_moment_kernel(observations, forecasts, ensemble_weights)
             elif self.crps_type == "skillspread":
                 if self.ensemble_weights is not None:
                     raise NotImplementedError("currently only constant ensemble weights are supported")
@@ -405,13 +519,8 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spectral_
             forecasts = torch.abs(forecasts).to(dtype)
             observations = torch.abs(observations).to(dtype)
         else:
-            forecasts = torch.view_as_real(forecasts).to(dtype)
-            observations = torch.view_as_real(observations).to(dtype)
-
-            # merge complex dimension after channel dimension and flatten
-            # this needs to be undone at the end
-            forecasts = torch.movedim(forecasts, 
5, 3).flatten(2, 3) - observations = torch.movedim(observations, 4, 2).flatten(1, 2) + # since the other kernels require sorting, this approach only works with the naive CRPS kernel + assert self.crps_type == "skillspread" # we assume the following shapes: # forecasts: batch, ensemble, channels, mmax, lmax @@ -460,6 +569,17 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spectral_ # compute score crps = _crps_ensemble_kernel(observations, forecasts, ensemble_weights) + elif self.crps_type == "probability weighted moment": + # now, E dimension is local and spatial dim is split further + # we need to sort the forecasts now + forecasts, idx = torch.sort(forecasts, dim=0) + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_probability_weighted_moment_kernel(observations, forecasts, ensemble_weights) elif self.crps_type == "skillspread": if self.ensemble_weights is not None: raise NotImplementedError("currently only constant ensemble weights are supported") @@ -467,7 +587,10 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spectral_ ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) # compute score - crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) + if self.absolute: + crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) + else: + crps = _crps_naive_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) elif self.crps_type == "gauss": if self.ensemble_weights is not None: ensemble_weights = self.ensemble_weights[idx] @@ -488,9 +611,165 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spectral_ if self.ensemble_distributed: crps = reduce_from_parallel_region(crps, "ensemble") - # finally undo the folding of the complex dimension into the channel dimension - if not self.absolute: - crps = crps.reshape(B, -1, 2).sum(dim=-1) - # the resulting tensor should have dimension B, C, which is what we return return crps + +class EnsembleVortDivCRPSLoss(VortDivBaseLoss): + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + pole_mask: int, + crps_type: str = "skillspread", + spatial_distributed: Optional[bool] = False, + ensemble_distributed: Optional[bool] = False, + ensemble_weights: Optional[torch.Tensor] = None, + alpha: Optional[float] = 1.0, + eps: Optional[float] = 1.0e-5, + **kwargs, + ): + + super().__init__( + img_shape=img_shape, + crop_shape=crop_shape, + crop_offset=crop_offset, + channel_names=channel_names, + grid_type=grid_type, + pole_mask=pole_mask, + spatial_distributed=spatial_distributed, + ) + + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + self.crps_type = crps_type + self.alpha = alpha + self.eps = eps + + if (self.crps_type != "skillspread") and (self.alpha < 1.0): + raise NotImplementedError("The alpha parameter (almost fair CRPS factor) is only supported for the skillspread kernel.") + + # we also need a variant of the weights split in ensemble direction: + quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) + if self.ensemble_distributed: + quad_weight_split = 
split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")]
+ quad_weight_split = quad_weight_split.contiguous()
+ self.register_buffer("quad_weight_split", quad_weight_split, persistent=False)
+
+ if ensemble_weights is not None:
+ self.register_buffer("ensemble_weights", ensemble_weights, persistent=False)
+ else:
+ self.ensemble_weights = ensemble_weights
+
+ @property
+ def type(self):
+ return LossType.Probabilistic
+
+ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+ # sanity checks
+ if forecasts.dim() != 5:
+ raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.")
+
+ # we assume that spatial_weights have NO ensemble dim
+ if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()):
+ spdim = spatial_weights.dim()
+ odim = observations.dim()
+ raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).")
+
+ # we assume the following shapes:
+ # forecasts: batch, ensemble, channels, lat, lon
+ # observations: batch, channels, lat, lon
+ B, E, _, H, W = forecasts.shape
+ C = self.wind_chans.shape[0]
+
+ # extract wind channels
+ forecasts = forecasts[..., self.wind_chans, :, :].reshape(B, E, C//2, 2, H, W)
+ observations = observations[..., self.wind_chans, :, :].reshape(B, C//2, 2, H, W)
+
+ # get the data type before stripping amp types
+ dtype = forecasts.dtype
+
+ # before anything else compute the transform
+ # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat the real and imaginary parts as separate variables
+ with amp.autocast(device_type="cuda", enabled=False):
+ forecasts = self.isht(self.vsht(forecasts.float()))
+ observations = self.isht(self.vsht(observations.float()))
+
+ # restore the original wind channel layout
+ forecasts = forecasts.reshape(B, E, C, H, W)
+ observations = observations.reshape(B, C, H, W)
+
+ # if the ensemble dimension has size one, computing the score is quick:
+ if (not self.ensemble_distributed) and (E == 1):
+ # in this case, CRPS is straightforward
+ crps = torch.abs(observations - forecasts.squeeze(1)).reshape(B, C, H * W)
+ else:
+ # transpose forecasts: ensemble, batch, channels, lat, lon
+ forecasts = torch.moveaxis(forecasts, 1, 0)
+
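For orientation, here is a self-contained toy emulation of the data movement in the following lines. This is an editor's sketch, not the patch's code: the sizes are made up and the ensemble communicator is mimicked with plain chunk/cat instead of distributed_transpose. Each rank starts with a subset of members over the full grid and ends up with all E members on a spatial shard, so the CRPS kernels can be evaluated pointwise over the complete ensemble.

import torch

E, B, C, S, R = 4, 1, 2, 8, 2                 # members, batch, channels, space, ranks (assumed toy sizes)
members = torch.randn(E, B, C, S)
# before the exchange: rank r holds E // R members over the full spatial dim
before = list(members.chunk(R, dim=0))
# all-to-all: afterwards rank r holds all E members on spatial shard r
after = [torch.cat([shard.chunk(R, dim=-1)[r] for shard in before], dim=0) for r in range(R)]
assert after[0].shape == (E, B, C, S // R)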
+ # now we need a distributed transpose of the forecasts along the ensemble direction; ideally we split the spatial dims
+ forecasts = forecasts.reshape(E, B, C, H * W)
+ if self.ensemble_distributed:
+ ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))]
+ forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble")
+ # observations do not need a transpose, just a split
+ observations = observations.reshape(B, C, H * W)
+ if self.ensemble_distributed:
+ observations = scatter_to_parallel_region(observations, -1, "ensemble")
+ if spatial_weights is not None:
+ spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1)
+ spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble")
+
+ # run appropriate crps kernel to compute it pointwise
+ if self.crps_type == "cdf":
+ # now, E dimension is local and spatial dim is split further
+ # we need to sort the forecasts now
+ forecasts, idx = torch.sort(forecasts, dim=0)
+ if self.ensemble_weights is not None:
+ ensemble_weights = self.ensemble_weights[idx]
+ else:
+ ensemble_weights = torch.ones_like(forecasts, device=forecasts.device)
+
+ # compute score
+ crps = _crps_ensemble_kernel(observations, forecasts, ensemble_weights)
+ elif self.crps_type == "skillspread":
+ if self.ensemble_weights is not None:
+ raise NotImplementedError("currently only constant ensemble weights are supported")
+ else:
+ ensemble_weights = torch.ones_like(forecasts, device=forecasts.device)
+
+ # compute score
+ crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha)
+ elif self.crps_type == "gauss":
+ if self.ensemble_weights is not None:
+ ensemble_weights = self.ensemble_weights[idx]
+ else:
+ ensemble_weights = torch.ones_like(forecasts, device=forecasts.device)
+
+ # compute score
+ crps = _crps_gauss_kernel(observations, forecasts, ensemble_weights, self.eps)
+ else:
+ raise ValueError(f"Unknown CRPS crps_type {self.crps_type}")
+
+ # perform ensemble and spatial average of crps score
+ if spatial_weights is not None:
+ crps = torch.sum(crps * self.quad_weight_split * spatial_weights_split, dim=-1)
+ else:
+ crps = torch.sum(crps * self.quad_weight_split, dim=-1)
+
+ # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well
+ if self.ensemble_distributed:
+ crps = reduce_from_parallel_region(crps, "ensemble")
+
+ # we need to do the spatial averaging manually since
+ # we are not calling the quadrature forward function
+ if self.spatial_distributed:
+ crps = reduce_from_parallel_region(crps, "spatial")
+
+ # the resulting tensor should have dimension B, C, which is what we return
+ return crps
\ No newline at end of file
diff --git a/makani/utils/metric.py b/makani/utils/metric.py
index 8e06260..4edba2a 100644
--- a/makani/utils/metric.py
+++ b/makani/utils/metric.py
@@ -73,14 +73,14 @@ def __init__(self, metric_name, metric_channels, metric_handle, channel_names, n
 # CPU buffers
 pin_memory = self.device.type == "cuda"
-
+
 if self.aux_shape_finalized is None:
 data_shape_finalized = (self.num_rollout_steps, self.num_channels)
 integral_shape = (self.num_channels)
 else:
 data_shape_finalized = (self.num_rollout_steps, self.num_channels, *self.aux_shape_finalized)
 integral_shape = (self.num_channels, *self.aux_shape_finalized)
-
+
 self.rollout_curve_cpu = torch.zeros(data_shape_finalized, dtype=torch.float32, device="cpu", pin_memory=pin_memory)
 
 if self.integrate:
@@ -213,12 +213,12 @@ def __init__(
 climatology,
 num_rollout_steps,
 device,
- l1_var_names=["u10m", "t2m", "u500", "z500",
"q500", "sp"], - rmse_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], - acc_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], - crps_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], - spread_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], - ssr_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], + l1_var_names=["u10m", "t2m", "u500", "z500", "q500", "q50", "sp"], + rmse_var_names=["u10m", "t2m", "u500", "z500", "q500", "q50", "sp"], + acc_var_names=["u10m", "t2m", "u500", "z500", "q500", "q50", "sp"], + crps_var_names=["u10m", "t2m", "u500", "z500", "q500", "q50", "sp"], + spread_var_names=["u10m", "t2m", "u500", "z500", "q500", "q50", "sp"], + ssr_var_names=["u10m", "t2m", "u500", "z500", "q500", "q50", "sp"], rh_var_names=[], wb2_compatible=False, ): From ae27deb3b5e4272a6fdd52da6f0a4c356fc6cd71 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Thu, 30 Oct 2025 03:58:19 -0700 Subject: [PATCH 14/32] some changes to the model and losses --- makani/models/networks/fourcastnet3.py | 252 ++++++++++++++++------ makani/models/networks/pangu.py | 38 ++-- makani/models/networks/pangu_onnx.py | 14 +- makani/models/stepper.py | 2 +- makani/utils/driver.py | 4 +- makani/utils/features.py | 14 +- makani/utils/loss.py | 15 +- makani/utils/losses/__init__.py | 1 + makani/utils/losses/base_loss.py | 6 +- makani/utils/losses/crps_loss.py | 180 ++++++++++++++-- makani/utils/metric.py | 12 +- makani/utils/training/ensemble_trainer.py | 14 +- 12 files changed, 404 insertions(+), 148 deletions(-) diff --git a/makani/models/networks/fourcastnet3.py b/makani/models/networks/fourcastnet3.py index 3764978..4eed2a6 100644 --- a/makani/models/networks/fourcastnet3.py +++ b/makani/models/networks/fourcastnet3.py @@ -32,7 +32,6 @@ import torch_harmonics.distributed as thd # get pre-formulated layers -#from makani.models.common import GeometricInstanceNormS2 from makani.mpu.layers import DistributedMLP, DistributedEncoderDecoder # more distributed stuff @@ -57,6 +56,47 @@ def _soft_clamp(x: torch.Tensor, offset: float = 0.0): y = torch.where(x >= 0.5, x - 0.25, y) return y +# helper module to handle imputation of SST +# class MLPImputer(nn.Module): +# def __init__( +# self, +# inp_chans = 2, +# out_chans = 2, +# mlp_ratio = 2.0, +# activation_function=nn.GELU, +# ): + +# self.mlp = EncoderDecoder( +# num_layers=1, +# input_dim=inp_chans, +# output_dim=out_chans, +# hidden_dim=int(mlp_ratio * out_chans), +# act_layer=activation_function, +# input_format="nchw", +# ) + +# def forward(self, inp, out): +# return torch.where(torch.isnan(out), self.mlp(inp), out) + +class ConstantImputation(nn.Module): + def __init__( + self, + inp_chans = 2, + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(inp_chans, 1, 1)) + + if comm.get_size("spatial") > 1: + self.weight.is_shared_mp = ["spatial"] + self.weight.sharded_dims_mp = [None, None, None] + + def forward(self, x, mask = None): + if mask is None: + mask = torch.isnan(x) + else: + mask = torch.logical_or(mask, torch.isnan(x)) + return torch.where(mask, self.weight, x) class DiscreteContinuousEncoder(nn.Module): def __init__( @@ -408,7 +448,8 @@ def __init__( n_history=0, atmo_embed_dim=8, surf_embed_dim=8, - aux_embed_dim=8, + dyn_aux_embed_dim=8, + stat_aux_embed_dim=8, pos_embed_dim=0, num_layers=4, num_groups=1, @@ -425,6 +466,7 @@ def __init__( sfno_block_frequency=2, big_skip=False, clamp_water=False, + encoder_bias=False, bias=False, checkpointing_level=0, freeze_encoder=False, @@ -437,13 +479,12 @@ def __init__( 
self.out_shape = out_shape self.atmo_embed_dim = atmo_embed_dim self.surf_embed_dim = surf_embed_dim - self.aux_embed_dim = aux_embed_dim + self.dyn_aux_embed_dim = dyn_aux_embed_dim + self.stat_aux_embed_dim = stat_aux_embed_dim self.pos_embed_dim = pos_embed_dim self.big_skip = big_skip self.checkpointing_level = checkpointing_level - - # currently doesn't support neither history nor future: - assert n_history == 0 + self.n_history = n_history # compute the downscaled image size self.h = int(self.inp_shape[0] // scale_factor) @@ -453,11 +494,12 @@ def __init__( self._init_spectral_transforms(model_grid_type, sht_grid_type, hard_thresholding_fraction, max_modes) # compute static permutations to extract - self._precompute_channel_groups(channel_names, aux_channel_names) + self._precompute_channel_groups(channel_names, aux_channel_names, n_history) # compute the total number of internal groups self.n_out_chans = self.n_atmo_groups * self.n_atmo_chans + self.n_surf_chans self.total_embed_dim = self.n_atmo_groups * self.atmo_embed_dim + self.surf_embed_dim + self.total_aux_embed_dim = (self.n_dyn_aux_chans > 0) * self.dyn_aux_embed_dim + (self.n_stat_aux_chans > 0) * self.stat_aux_embed_dim + self.pos_embed_dim # convert kernel shape to tuple kernel_shape = tuple(kernel_shape) @@ -473,11 +515,10 @@ def __init__( raise ValueError(f"Unknown activation function {activation_function}") # encoder for the atmospheric channels - # TODO: add the groups self.atmo_encoder = DiscreteContinuousEncoder( inp_shape=inp_shape, out_shape=(self.h, self.w), - inp_chans=self.n_atmo_chans, + inp_chans=self.n_atmo_chans * (self.n_history + 1), out_chans=self.atmo_embed_dim, grid_in=model_grid_type, grid_out=sht_grid_type, @@ -486,16 +527,16 @@ def __init__( basis_norm_mode=filter_basis_norm_mode, activation_function=activation_function, groups=math.gcd(self.n_atmo_chans, self.atmo_embed_dim), - bias=bias, + bias=encoder_bias, use_mlp=encoder_mlp, ) - # encoder for the auxiliary channels + # encoder for the surface channels if self.n_surf_chans > 0: self.surf_encoder = DiscreteContinuousEncoder( inp_shape=inp_shape, out_shape=(self.h, self.w), - inp_chans=self.n_surf_chans, + inp_chans=self.n_surf_chans * (self.n_history + 1), out_chans=self.surf_embed_dim, grid_in=model_grid_type, grid_out=sht_grid_type, @@ -504,10 +545,52 @@ def __init__( basis_norm_mode=filter_basis_norm_mode, activation_function=activation_function, groups=math.gcd(self.n_surf_chans, self.surf_embed_dim), - bias=bias, + bias=encoder_bias, + use_mlp=encoder_mlp, + ) + + if self.sst_channels_in.shape[0] > 0: + self.sst_imputation = ConstantImputation( + inp_chans=self.sst_channels_in.shape[0], + ) + + # encoder for the auxiliary channels + if self.n_dyn_aux_chans > 0: + self.dyn_aux_encoder = DiscreteContinuousEncoder( + inp_shape=inp_shape, + out_shape=(self.h, self.w), + inp_chans=self.n_dyn_aux_chans * (self.n_history + 1), + out_chans=self.dyn_aux_embed_dim, + grid_in=model_grid_type, + grid_out=sht_grid_type, + kernel_shape=kernel_shape, + basis_type=filter_basis_type, + basis_norm_mode=filter_basis_norm_mode, + activation_function=activation_function, + groups=math.gcd(self.n_dyn_aux_chans, self.dyn_aux_embed_dim), + bias=encoder_bias, use_mlp=encoder_mlp, ) + # encoder for the auxiliary channels + if self.n_stat_aux_chans > 0: + self.stat_aux_encoder = DiscreteContinuousEncoder( + inp_shape=inp_shape, + out_shape=(self.h, self.w), + inp_chans=self.n_stat_aux_chans, + out_chans=self.stat_aux_embed_dim, + grid_in=model_grid_type, + 
grid_out=sht_grid_type, + kernel_shape=kernel_shape, + basis_type=filter_basis_type, + basis_norm_mode=filter_basis_norm_mode, + activation_function=activation_function, + groups=math.gcd(self.n_stat_aux_chans, self.stat_aux_embed_dim), + bias=encoder_bias, + use_mlp=encoder_mlp, + ) + + # decoder for the atmospheric variables self.atmo_decoder = DiscreteContinuousDecoder( inp_shape=(self.h, self.w), @@ -521,7 +604,7 @@ def __init__( basis_norm_mode=filter_basis_norm_mode, activation_function=activation_function, groups=math.gcd(self.n_atmo_chans, self.atmo_embed_dim), - bias=bias, + bias=encoder_bias, use_mlp=encoder_mlp, upsample_sht=upsample_sht, ) @@ -540,29 +623,11 @@ def __init__( basis_norm_mode=filter_basis_norm_mode, activation_function=activation_function, groups=math.gcd(self.n_surf_chans, self.surf_embed_dim), - bias=bias, + bias=encoder_bias, use_mlp=encoder_mlp, upsample_sht=upsample_sht, ) - # encoder for the auxiliary channels - if self.n_aux_chans > 0: - self.aux_encoder = DiscreteContinuousEncoder( - inp_shape=inp_shape, - out_shape=(self.h, self.w), - inp_chans=self.n_aux_chans, - out_chans=self.aux_embed_dim, - grid_in=model_grid_type, - grid_out=sht_grid_type, - kernel_shape=kernel_shape, - basis_type=filter_basis_type, - basis_norm_mode=filter_basis_norm_mode, - activation_function=activation_function, - groups=math.gcd(self.n_aux_chans, self.aux_embed_dim), - bias=bias, - use_mlp=encoder_mlp, - ) - # position embedding if self.pos_embed_dim > 0: self.pos_embed = LearnablePositionEmbedding(img_shape=(self.h, self.w), grid=sht_grid_type, num_chans=self.pos_embed_dim, embed_type="lat") @@ -572,7 +637,7 @@ def __init__( dpr = [x.item() for x in torch.linspace(0, path_drop_rate, num_layers)] # get the handle for the normalization layer - norm_layer = self._get_norm_layer_handle(self.h, self.w, self.total_embed_dim, normalization_layer=normalization_layer, sht_grid_type=sht_grid_type) + norm_layer = self._get_norm_layer_handle(self.h, self.w, self.total_embed_dim + self.total_aux_embed_dim, normalization_layer=normalization_layer, sht_grid_type=sht_grid_type) # Internal NO blocks self.blocks = nn.ModuleList([]) @@ -589,7 +654,7 @@ def __init__( block = NeuralOperatorBlock( self.sht, self.isht, - self.total_embed_dim + (self.n_aux_chans > 0) * self.aux_embed_dim + self.pos_embed_dim, + self.total_embed_dim + self.total_aux_embed_dim, self.total_embed_dim, conv_type=conv_type, mlp_ratio=mlp_ratio, @@ -609,17 +674,6 @@ def __init__( self.blocks.append(block) - # residual prediction - if self.big_skip: - self.residual_transform = nn.Conv2d(self.n_out_chans, self.n_out_chans, 1, bias=False) - self.residual_transform.weight.is_shared_mp = ["spatial"] - self.residual_transform.weight.sharded_dims_mp = [None, None, None, None] - if self.residual_transform.bias is not None: - self.residual_transform.bias.is_shared_mp = ["spatial"] - self.residual_transform.bias.sharded_dims_mp = [None] - scale = math.sqrt(0.5 / self.n_out_chans) - nn.init.normal_(self.residual_transform.weight, mean=0.0, std=scale) - # controlled output normalization of q and tcwv if clamp_water: water_chans = get_water_channels(channel_names) @@ -738,22 +792,51 @@ def _precompute_channel_groups( group the channels appropriately into atmospheric pressure levels and surface variables """ - atmo_chans, surf_chans, aux_chans, pressure_lvls = get_channel_groups(channel_names, aux_channel_names) + atmo_chans, surf_chans, dyn_aux_chans, stat_aux_chans, pressure_lvls = get_channel_groups(channel_names, aux_channel_names) 
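Before the channel bookkeeping below, a quick standalone sketch of the history-offset convention it implements (an editor's illustration with made-up channel counts, not part of the patch): when n_history past steps are stacked along the channel dimension, the dynamic channel with base index c at history step ih lands at c + ih * n_dyn_chans.

n_history = 1
n_dyn_chans = 5                              # hypothetical: 4 atmospheric + 1 surface dynamic channels
atmo_chans, surf_chans = [0, 1, 2, 3], [4]
atmo_chans_in = atmo_chans + [c + n_dyn_chans for c in atmo_chans]   # [0, 1, 2, 3, 5, 6, 7, 8]
surf_chans_in = surf_chans + [c + n_dyn_chans for c in surf_chans]   # [4, 9]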
+ sst_chans = [channel_names.index("sst")] if "sst" in channel_names else []
+ lsml_chans = [len(channel_names) + aux_channel_names.index("xlsml")] if "xlsml" in aux_channel_names else []
 
 # compute how many channel groups will be kept internally
 self.n_atmo_groups = len(pressure_lvls)
 self.n_atmo_chans = len(atmo_chans) // self.n_atmo_groups
+ self.n_surf_chans = len(surf_chans)
+ self.n_dyn_aux_chans = len(dyn_aux_chans)
+ self.n_stat_aux_chans = len(stat_aux_chans)
+ self.n_aux_chans = self.n_dyn_aux_chans + self.n_stat_aux_chans
 
 # make sure they are divisible. Attention! This does not guarantee that the grouping is correct
 if len(atmo_chans) % self.n_atmo_groups:
 raise ValueError(f"Expected number of atmospheric variables to be divisible by number of atmospheric groups but got {len(atmo_chans)} and {self.n_atmo_groups}")
 
- self.register_buffer("atmo_channels", torch.LongTensor(atmo_chans), persistent=False)
- self.register_buffer("surf_channels", torch.LongTensor(surf_chans), persistent=False)
- self.register_buffer("aux_channels", torch.LongTensor(aux_chans), persistent=False)
-
- self.n_surf_chans = self.surf_channels.shape[0]
- self.n_aux_chans = self.aux_channels.shape[0]
+ # if history is included, adapt the channel lists to include the offsets
+ n_dyn_chans = len(atmo_chans) + len(surf_chans) + len(dyn_aux_chans)
+ atmo_chans_in = atmo_chans.copy()
+ surf_chans_in = surf_chans.copy()
+ sst_chans_in = sst_chans.copy()
+ dyn_aux_chans_base = dyn_aux_chans.copy()
+ for ih in range(1, n_history+1):
+ atmo_chans_in += [(c + ih*n_dyn_chans) for c in atmo_chans]
+ surf_chans_in += [(c + ih*n_dyn_chans) for c in surf_chans]
+ sst_chans_in += [(c + ih*n_dyn_chans) for c in sst_chans]
+ dyn_aux_chans += [(c + ih*n_dyn_chans) for c in dyn_aux_chans_base]
+ # account for the history offset in the static aux channels
+ stat_aux_chans = [c + n_history*n_dyn_chans for c in stat_aux_chans]
+ lsml_chans = [c + n_history*n_dyn_chans for c in lsml_chans]
+
+ self.register_buffer("atmo_channels_in", torch.LongTensor(atmo_chans_in), persistent=False)
+ self.register_buffer("atmo_channels_out", torch.LongTensor(atmo_chans), persistent=False)
+ self.register_buffer("surf_channels_in", torch.LongTensor(surf_chans_in), persistent=False)
+ self.register_buffer("surf_channels_out", torch.LongTensor(surf_chans), persistent=False)
+ self.register_buffer("sst_channels_in", torch.LongTensor(sst_chans_in), persistent=False)
+ self.register_buffer("sst_channels_out", torch.LongTensor(sst_chans), persistent=False)
+ self.register_buffer("dyn_aux_channels", torch.LongTensor(dyn_aux_chans), persistent=False)
+ self.register_buffer("stat_aux_channels", torch.LongTensor(stat_aux_chans), persistent=False)
+ self.register_buffer("land_mask_channels", torch.LongTensor(lsml_chans), persistent=False)
+ self.register_buffer("pred_channels", torch.LongTensor(surf_chans + atmo_chans), persistent=False)
+
+ # print(f"in atmo: {self.atmo_channels_in}")
+ # print(f"out atmo: {self.atmo_channels_out}")
+ # print(f"q50: {channel_names.index("q50")}")
 
 return
 
@@ -763,13 +846,24 @@ def encode(self, x):
 """
 batchdims = x.shape[:-3]
 
- # for atmospheric channels the same encoder is applied to each atmospheric level
- x_atmo = x[..., self.atmo_channels, :, :].contiguous().reshape(-1, self.n_atmo_chans, *x.shape[-2:])
+ if hasattr(self, "sst_imputation"):
+ if self.land_mask_channels.nelement() > 0:
+ mask = x[..., self.land_mask_channels, :, :]
+ else:
+ mask = None
+ x[..., self.sst_channels_in, :, :] = self.sst_imputation(x[..., self.sst_channels_in, :, :], mask=mask)
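The imputation applied here is easy to picture in isolation. The following is a hedged standalone sketch (an editor's illustration with a toy tensor; the fixed constant stands in for the learned self.weight) of what ConstantImputation does to masked gridpoints such as SST over land:

import torch

x = torch.tensor([[1.0, float("nan")], [float("nan"), 4.0]]).reshape(1, 1, 2, 2)
weight = torch.full((1, 1, 1), 0.5)               # plays the role of the learned per-channel constant
filled = torch.where(torch.isnan(x), weight, x)   # nan gridpoints replaced, valid ones untouched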
+
+ # for atmospheric channels the same encoder is applied to each atmospheric level and takes the entire history into account
+ x_atmo = x[..., self.atmo_channels_in, :, :].reshape(-1, 1 + self.n_history, self.n_atmo_groups, self.n_atmo_chans, *x.shape[-2:])
+ # move the history backwards and fold it into the channel dimension
+ x_atmo = x_atmo.permute(0,2,3,1,4,5).reshape(-1, self.n_atmo_chans * (1 + self.n_history), *x.shape[-2:]).contiguous()
 x_out = self.atmo_encoder(x_atmo)
 x_out = x_out.reshape(*batchdims, self.n_atmo_groups * self.atmo_embed_dim, *x_out.shape[-2:])
 
 if hasattr(self, "surf_encoder"):
- x_surf = x[..., self.surf_channels, :, :].contiguous()
+ x_surf = x[..., self.surf_channels_in, :, :].reshape(-1, 1 + self.n_history, self.n_surf_chans, *x.shape[-2:])
+ # move the history backwards and fold it into the channel dimension
+ x_surf = x_surf.transpose(-3,-4).reshape(-1, self.n_surf_chans * (1 + self.n_history), *x.shape[-2:]).contiguous()
 x_surf = self.surf_encoder(x_surf)
 x_out = torch.cat((x_out, x_surf), dim=-3)
 
@@ -783,19 +877,29 @@ def encode_auxiliary_channels(self, x):
 """
 batchdims = x.shape[:-3]
 
- if hasattr(self, "aux_encoder"):
- x_aux = x[..., self.aux_channels, :, :]
- x_aux = self.aux_encoder(x_aux)
- x_aux = x_aux.reshape(*batchdims, self.aux_embed_dim, *x_aux.shape[-2:])
- else:
- x_aux = None
+ aux_tensors = []
+
+ if hasattr(self, "dyn_aux_encoder"):
+ x_aux = x[..., self.dyn_aux_channels, :, :].reshape(-1, 1 + self.n_history, self.n_dyn_aux_chans, *x.shape[-2:])
+ x_aux = x_aux.transpose(-3,-4).reshape(-1, self.n_dyn_aux_chans * (1 + self.n_history), *x.shape[-2:]).contiguous()
+ x_aux = self.dyn_aux_encoder(x_aux)
+ x_aux = x_aux.reshape(*batchdims, self.dyn_aux_embed_dim, *x_aux.shape[-2:])
+ aux_tensors.append(x_aux)
+
+ if hasattr(self, "stat_aux_encoder"):
+ x_aux = x[..., self.stat_aux_channels, :, :].contiguous()
+ x_aux = self.stat_aux_encoder(x_aux)
+ x_aux = x_aux.reshape(*batchdims, self.stat_aux_embed_dim, *x_aux.shape[-2:])
+ aux_tensors.append(x_aux)
 
 if hasattr(self, "pos_embed"):
 x_pos = self.pos_embed()
- if x_aux is not None:
- x_aux = torch.cat([x_aux, x_pos], dim=-3)
- else:
- x_aux = x_pos
+ aux_tensors.append(x_pos)
+
+ if len(aux_tensors) > 0:
+ x_aux = torch.cat(aux_tensors, dim=-3)
+ else:
+ x_aux = None
 
 return x_aux
 
@@ -809,12 +913,13 @@ def decode(self, x):
 x_atmo = x[..., : (self.n_atmo_groups * self.atmo_embed_dim), :, :].reshape(-1, self.atmo_embed_dim, *x.shape[-2:])
 x_atmo = self.atmo_decoder(x_atmo)
 x_out = torch.zeros(*batchdims, self.n_out_chans, *x_atmo.shape[-2:], dtype=x.dtype, device=x.device)
- x_out[..., self.atmo_channels, :, :] = x_atmo.reshape(*batchdims, -1, *x_atmo.shape[-2:])
+ x_out[..., self.atmo_channels_out, :, :] = x_atmo.reshape(*batchdims, -1, *x_atmo.shape[-2:])
+
 if hasattr(self, "surf_decoder"):
 x_surf = x[..., -self.surf_embed_dim :, :, :]
 x_surf = self.surf_decoder(x_surf)
- x_out[..., self.surf_channels, :, :] = x_surf.reshape(*batchdims, -1, *x_surf.shape[-2:])
+ x_out[..., self.surf_channels_out, :, :] = x_surf.reshape(*batchdims, -1, *x_surf.shape[-2:])
 
 return x_out
 
@@ -850,7 +955,7 @@ def forward(self, x):
 
 # save big skip
 if self.big_skip:
- residual = x[..., : self.n_out_chans, :, :].contiguous()
+ residual = x[..., self.pred_channels, :, :].contiguous()
 
 # extract embeddings for the auxiliary channels
 x_aux = self.encode_auxiliary_channels(x)
@@ -864,6 +969,13 @@ def forward(self, x):
 # run the processor
 x = self.processor_blocks(x, x_aux)
 
+ # for debugging, print the activations
+ atmo_activations = x[..., : (self.n_atmo_groups * self.atmo_embed_dim), :, :].reshape(-1, self.n_atmo_groups, self.atmo_embed_dim, *x.shape[-2:])
+ s, m = torch.std_mean(atmo_activations, dim=(0, -1, -2))
+ print(f"group 0, stds: {s[0]} means: {m[0]}")
+ print(f"group -1, stds: {s[-1]} means: {m[-1]}")
+
+
 # run the decoder
 if self.checkpointing_level >= 1:
 x = checkpoint(self.decode, x, use_reentrant=False)
@@ -871,7 +983,7 @@ def forward(self, x):
 x = self.decode(x)
 
 if self.big_skip:
- x = x + self.residual_transform(residual)
+ x[..., self.pred_channels, :, :] = x[..., self.pred_channels, :, :] + residual
 
 # apply output transform
 x = self.clamp_water_channels(x)
diff --git a/makani/models/networks/pangu.py b/makani/models/networks/pangu.py
index b5b8e55..ad6b352 100644
--- a/makani/models/networks/pangu.py
+++ b/makani/models/networks/pangu.py
@@ -452,7 +452,7 @@ def forward(self, x: torch.Tensor, mask=None):
 x: input features with shape of (B * num_lon, num_pl*num_lat, N, C)
 mask: (0/-inf) mask with shape of (num_lon, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon)
 """
-
+
 B_, nW_, N, C = x.shape
 qkv = (
 self.qkv(x)
@@ -478,7 +478,7 @@ def forward(self, x: torch.Tensor, mask=None):
 attn = self.attn_drop_fn(attn)
 
 x = self.apply_attention(attn, v, B_, nW_, N, C)
-
+
 else:
 if mask is not None:
 bias = mask.unsqueeze(1).unsqueeze(0) + earth_position_bias.unsqueeze(0).unsqueeze(0)
@@ -486,10 +486,10 @@ def forward(self, x: torch.Tensor, mask=None):
 #bias = bias.squeeze(2)
 else:
 bias = earth_position_bias.unsqueeze(0)
-
+
 # extract batch size for q,k,v
 nLon = self.num_lon
- q = q.view(B_ // nLon, nLon, q.shape[1], q.shape[2], q.shape[3], q.shape[4])
+ q = q.view(B_ // nLon, nLon, q.shape[1], q.shape[2], q.shape[3], q.shape[4])
 k = k.view(B_ // nLon, nLon, k.shape[1], k.shape[2], k.shape[3], k.shape[4])
 v = v.view(B_ // nLon, nLon, v.shape[1], v.shape[2], v.shape[3], v.shape[4])
 ####
@@ -736,7 +736,7 @@ class Pangu(nn.Module):
 - https://arxiv.org/abs/2211.02556
 """
 
- def __init__(self,
+ def __init__(self,
 inp_shape=(721,1440),
 out_shape=(721,1440),
 grid_in="equiangular",
@@ -773,14 +773,14 @@ def __init__(self,
 self.checkpointing_level = checkpointing_level
 
 drop_path = np.linspace(0, drop_path_rate, 8).tolist()
-
+
 # Add static channels to surface
 self.num_aux = len(self.aux_channel_names)
 N_total_surface = self.num_aux + self.num_surface
 
 # compute static permutations to extract
 self._precompute_channel_groups(self.channel_names, self.aux_channel_names)
-
+
 # Patch embeddings are 2D or 3D convolutions, mapping the data to the required patches
 self.patchembed2d = PatchEmbed2D(
 img_size=self.inp_shape,
@@ -791,7 +791,7 @@ def __init__(self,
 flatten=False,
 norm_layer=None,
 )
-
+
 self.patchembed3d = PatchEmbed3D(
 img_size=(num_levels, self.inp_shape[0], self.inp_shape[1]),
 patch_size=patch_size,
@@ -870,7 +870,7 @@ def __init__(self,
 self.patchrecovery3d = PatchRecovery3D(
 (num_levels, self.inp_shape[0], self.inp_shape[1]), patch_size, 2 * embed_dim, num_atmospheric
 )
-
+
 def _precompute_channel_groups(
 self,
 channel_names=[],
@@ -901,7 +901,7 @@ def _precompute_channel_groups(
 
 def prepare_input(self, input):
 """
 Prepares the input tensor for the Pangu model by splitting it into surface & static variables and atmospheric,
 and reshaping the atmospheric variables into the required format.
""" @@ -932,23 +932,23 @@ def prepare_output(self, output_surface, output_atmospheric): level_dict = {level: [idx for idx, value in enumerate(self.channel_names) if value[1:] == level] for level in levels} reordered_ids = [idx for level in levels for idx in level_dict[level]] check_reorder = [f'{level}_{idx}' for level in levels for idx in level_dict[level]] - + # Flatten & reorder the output atmospheric to original order (doublechecked that this is working correctly!) flattened_atmospheric = output_atmospheric.reshape(output_atmospheric.shape[0], -1, output_atmospheric.shape[3], output_atmospheric.shape[4]) reordered_atmospheric = torch.cat([torch.zeros_like(output_surface), torch.zeros_like(flattened_atmospheric)], dim=1) for i in range(len(reordered_ids)): reordered_atmospheric[:, reordered_ids[i], :, :] = flattened_atmospheric[:, i, :, :] - + # Append the surface output, this has not been reordered. if output_surface is not None: - _, surf_chans, _, _ = features.get_channel_groups(self.channel_names, self.aux_channel_names) + _, surf_chans, _, _, _ = features.get_channel_groups(self.channel_names, self.aux_channel_names) reordered_atmospheric[:, surf_chans, :, :] = output_surface output = reordered_atmospheric else: output = reordered_atmospheric return output - + def forward(self, input): # Prep the input by splitting into surface and atmospheric variables @@ -959,7 +959,7 @@ def forward(self, input): surface = checkpoint(self.patchembed2d, surface_aux, use_reentrant=False) atmospheric = checkpoint(self.patchembed3d, atmospheric, use_reentrant=False) else: - surface = self.patchembed2d(surface_aux) + surface = self.patchembed2d(surface_aux) atmospheric = self.patchembed3d(atmospheric) if surface.shape[1] == 0: @@ -1011,11 +1011,5 @@ def forward(self, input): output_atmospheric = self.patchrecovery3d(output_atmospheric) output = self.prepare_output(output_surface, output_atmospheric) - - return output - - - - - + return output diff --git a/makani/models/networks/pangu_onnx.py b/makani/models/networks/pangu_onnx.py index 0805bad..bf3a006 100644 --- a/makani/models/networks/pangu_onnx.py +++ b/makani/models/networks/pangu_onnx.py @@ -38,7 +38,7 @@ class PanguOnnx(OnnxWrapper): channel_order_PL: List containing the names of the pressure levels with the ordering that the ONNX model expects onnx_file: Path to the ONNX file containing the model ''' - def __init__(self, + def __init__(self, channel_names=[], aux_channel_names=[], onnx_file=None, @@ -58,7 +58,7 @@ def _precompute_channel_groups( group the channels appropriately into atmospheric pressure levels and surface variables """ - atmo_chans, surf_chans, _, pressure_lvls = get_channel_groups(channel_names, aux_channel_names) + atmo_chans, surf_chans, _, _, pressure_lvls = get_channel_groups(channel_names, aux_channel_names) # compute how many channel groups will be kept internally self.n_atmo_groups = len(pressure_lvls) @@ -78,12 +78,12 @@ def prepare_input(self, input): B,V,Lat,Long=input.shape if B>1: - raise NotImplementedError("Not implemented yet for batch size greater than 1") + raise NotImplementedError("Not implemented yet for batch size greater than 1") input=input.squeeze(0) surface_aux_inp=input[self.surf_channels] atmospheric_inp=input[self.atmo_channels].reshape(self.n_atmo_groups,self.n_atmo_chans,Lat,Long).transpose(1,0) - + return surface_aux_inp, atmospheric_inp def prepare_output(self, output_surface, output_atmospheric): @@ -99,9 +99,9 @@ def prepare_output(self, output_surface, output_atmospheric): return 
output.unsqueeze(0) - + def forward(self, input): - + surface, atmospheric = self.prepare_input(input) @@ -109,5 +109,5 @@ def forward(self, input): output = self.prepare_output(output_surface, output) - + return output diff --git a/makani/models/stepper.py b/makani/models/stepper.py index f04590b..565402a 100644 --- a/makani/models/stepper.py +++ b/makani/models/stepper.py @@ -156,7 +156,7 @@ def _forward_eval(self, inp, update_state=True, replace_state=True): def forward(self, inp, update_state=True, replace_state=True): # decide which routine to call if self.training: - y = self._forward_train(inp, update_state=True, replace_state=replace_state) + y = self._forward_train(inp, update_state=update_state, replace_state=replace_state) else: y = self._forward_eval(inp, update_state=update_state, replace_state=replace_state) diff --git a/makani/utils/driver.py b/makani/utils/driver.py index 6e507c8..766df5a 100644 --- a/makani/utils/driver.py +++ b/makani/utils/driver.py @@ -632,11 +632,11 @@ def get_optimizer(self, model, params): if params.optimizer_type == "Adam": if self.log_to_screen: self.logger.info("using Adam optimizer") - optimizer = optim.Adam(all_parameters, betas=betas, lr=params.get("lr", 1e-3), weight_decay=params.get("weight_decay", 0), foreach=True) + optimizer = optim.Adam(all_parameters, lr=params.get("lr", 1e-3), betas=betas, eps=params.get("optimizer_eps", 1e-8), weight_decay=params.get("weight_decay", 0), foreach=True) elif params.optimizer_type == "AdamW": if self.log_to_screen: self.logger.info("using AdamW optimizer") - optimizer = optim.AdamW(all_parameters, betas=betas, lr=params.get("lr", 1e-3), weight_decay=params.get("weight_decay", 0), foreach=True) + optimizer = optim.AdamW(all_parameters, lr=params.get("lr", 1e-3), betas=betas, eps=params.get("optimizer_eps", 1e-8), weight_decay=params.get("weight_decay", 0), foreach=True) elif params.optimizer_type == "SGD": if self.log_to_screen: self.logger.info("using SGD optimizer") diff --git a/makani/utils/features.py b/makani/utils/features.py index 658af01..cab61c3 100644 --- a/makani/utils/features.py +++ b/makani/utils/features.py @@ -97,13 +97,15 @@ def get_wind_channels(channel_names): def get_channel_groups(channel_names, aux_channel_names=[]): """ - Helper routine to extract indices of atmospheric, surface and auxiliary variables and group them into their respective groups + Helper routine to extract indices of atmospheric, surface and auxiliary variables and group them into their respective groups. + The resulting numbering does NOT respect history. 
""" atmo_groups = OrderedDict() atmo_chans = [] surf_chans = [] - aux_chans = [] + dyn_aux_chans = [] + stat_aux_chans = [] # parse channel names and group variables by pressure level/surface variables for idx, chn in enumerate(channel_names): @@ -127,6 +129,10 @@ def get_channel_groups(channel_names, aux_channel_names=[]): atmo_chans += idx # append the auxiliary variable to the surface channels - aux_chans = [idx + len(channel_names) for idx in range(len(aux_channel_names))] + for idx, chn in enumerate(aux_channel_names): + if chn in ["xoro", "xlsml", "xlsms"]: + stat_aux_chans.append(idx + len(channel_names)) + else: + dyn_aux_chans.append(idx + len(channel_names)) - return atmo_chans, surf_chans, aux_chans, atmo_groups.keys() + return atmo_chans, surf_chans, dyn_aux_chans, stat_aux_chans, atmo_groups.keys() diff --git a/makani/utils/loss.py b/makani/utils/loss.py index 677081c..a3776f1 100644 --- a/makani/utils/loss.py +++ b/makani/utils/loss.py @@ -32,7 +32,7 @@ from torch_harmonics.quadrature import clenshaw_curtiss_weights, legendre_gauss_weights from .losses import LossType, GeometricLpLoss, SpectralH1Loss, SpectralAMSELoss -from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleVortDivCRPSLoss +from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleVortDivCRPSLoss, EnergyScoreLoss from .losses import EnsembleNLLLoss, EnsembleMMDLoss from .losses import DriftRegularization, HydrostaticBalanceLoss @@ -119,8 +119,6 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps ) # append to dict and compile before: - # TODO: fix the compile issue - # self.loss_fn[loss_type] = torch.compile(loss_fn) self.loss_fn.append(loss_fn) # determine channel weighting @@ -140,7 +138,8 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps # get channel weights either directly or through the compute routine if isinstance(channel_weight_type, List): chw = torch.tensor(channel_weight_type, dtype=torch.float32) - chw = chw * time_diff_scale + if time_diff_scale is not None: + chw = chw * time_diff_scale assert chw.shape[1] == loss_fn.n_channels else: chw = loss_fn.compute_channel_weighting(channel_weight_type, time_diff_scale=time_diff_scale) @@ -228,6 +227,8 @@ def _parse_loss_type(self, loss_type: str): loss_handle = EnsembleNLLLoss elif "ensemble_mmd" in loss_type: loss_handle = EnsembleMMDLoss + elif "energy_score" in loss_type: + loss_handle = partial(EnergyScoreLoss) elif "drift_regularization" in loss_type: loss_handle = DriftRegularization else: @@ -333,6 +334,8 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens loss_vals.append(lfn(prd, tar, wgt)) all_losses = torch.cat(loss_vals, dim=-1) + # print(all_losses) + if self.training and self.track_running_stats: self._update_running_stats(all_losses.clone()) @@ -340,12 +343,14 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens chw = self.channel_weights if self.uncertainty_weighting and self.training: var, _ = self.get_running_stats() + if self.num_batches_tracked.item() <= 100: + var = torch.ones_like(var) chw = chw / (torch.sqrt(2 * var) + self.eps) elif self.balanced_weighting and self.training: _, mean = self.get_running_stats() if self.num_batches_tracked.item() <= 100: mean = torch.ones_like(mean) - chw = chw / mean + chw = chw / (mean + self.eps) if self.randomized_loss_weights: rmask = torch.zeros_like(chw) diff --git a/makani/utils/losses/__init__.py b/makani/utils/losses/__init__.py 
index a4aef77..fff3e7a 100644
--- a/makani/utils/losses/__init__.py
+++ b/makani/utils/losses/__init__.py
@@ -19,6 +19,7 @@
 from .amse_loss import SpectralAMSELoss
 from .hydrostatic_loss import HydrostaticBalanceLoss
 from .crps_loss import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleVortDivCRPSLoss
+from .crps_loss import EnergyScoreLoss
 from .mmd_loss import EnsembleMMDLoss
 from .likelihood_loss import EnsembleNLLLoss
 from .drift_regularization import DriftRegularization
diff --git a/makani/utils/losses/base_loss.py b/makani/utils/losses/base_loss.py
index 6614efe..5eb8a3c 100644
--- a/makani/utils/losses/base_loss.py
+++ b/makani/utils/losses/base_loss.py
@@ -45,7 +45,7 @@ def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_t
 
 elif channel_weight_type == "auto":
 for c, chn in enumerate(channel_names):
- if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv"]:
+ if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv", "sst"]:
 channel_weights[c] = 0.1
 elif chn in ["t2m", "2d"]:
 channel_weights[c] = 1.0
@@ -58,7 +58,7 @@ def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_t
 
 elif channel_weight_type == "new auto":
 for c, chn in enumerate(channel_names):
- if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv"]:
+ if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv", "sst"]:
 channel_weights[c] = 0.1
 elif chn in ["t2m", "2d"]:
 channel_weights[c] = 2.0
@@ -71,7 +71,7 @@ def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_t
 
 elif channel_weight_type == "new auto 2":
 for c, chn in enumerate(channel_names):
- if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv"]:
+ if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv", "sst"]:
 channel_weights[c] = 0.1
 elif chn in ["t2m", "2d"]:
 channel_weights[c] = 2.0
diff --git a/makani/utils/losses/crps_loss.py b/makani/utils/losses/crps_loss.py
index 8119e43..3ef5f55 100644
--- a/makani/utils/losses/crps_loss.py
+++ b/makani/utils/losses/crps_loss.py
@@ -118,9 +118,8 @@ def _crps_skillspread_kernel(observation: torch.Tensor, forecasts: torch.Tensor,
 
 observation = observation.unsqueeze(0)
 
- # get nanmask
- nanmasks = torch.logical_or(torch.isnan(forecasts), torch.isnan(weights))
- nanmask = torch.sum(nanmasks, dim=0).bool()
+ # get nanmask from the observations
+ nanmasks = torch.logical_or(torch.isnan(observation), torch.isnan(weights))
 
 # compute total weights
 nweights = torch.where(nanmasks, 0.0, weights)
@@ -136,11 +135,7 @@ def _crps_skillspread_kernel(observation: torch.Tensor, forecasts: torch.Tensor,
 espread = 2 * torch.mean((2 * rank - num_ensemble - 1) * forecasts, dim=0) * (float(num_ensemble) - 1.0 + alpha) / float(num_ensemble * (num_ensemble - 1))
 eskill = (observation - forecasts).abs().mean(dim=0)
 
- # crps = torch.where(nanmasks.sum(dim=0) != 0, torch.nan, eskill - 0.5 * espread)
- crps = eskill - 0.5 * espread
-
- # set to nan wherever the forecasts or weights are nan
- crps = torch.where(nanmask, torch.nan, crps)
+ crps = torch.where(nanmasks.sum(dim=0) != 0, 0.0, eskill - 0.5 * espread)
 
 return crps
 
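To make the skill/spread form of these kernels concrete, here is a tiny standalone numerical sketch (an editor's illustration; the ensemble size and random values are made up): with alpha = 1 the prefactor of the spread term reduces to 1/(N*(N-1)), which recovers the fair CRPS estimator.

import torch

torch.manual_seed(0)
N = 8
forecasts = torch.randn(N)                  # toy ensemble at a single gridpoint
observation = torch.randn(1)

eskill = (observation - forecasts).abs().mean()
espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().sum() / (N * (N - 1))
crps_fair = eskill - 0.5 * espread          # fair CRPS (alpha = 1)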
@@ -154,9 +149,8 @@ def _crps_probability_weighted_moment_kernel(observation: torch.Tensor, forecast
 
 observation = observation.unsqueeze(0)
 
- # get nanmask
- nanmasks = torch.logical_or(torch.isnan(forecasts), torch.isnan(weights))
- nanmask = torch.sum(nanmasks, dim=0).bool()
+ # get nanmask from the observations
+ nanmasks = torch.logical_or(torch.isnan(observation), torch.isnan(weights))
 
 # compute total weights
 nweights = torch.where(nanmasks, 0.0, weights)
@@ -177,7 +171,7 @@ def _crps_probability_weighted_moment_kernel(observation: torch.Tensor, forecast
 crps = eskill + beta0 - 2 * beta1
 
 # zero out the score wherever the observation or weights are nan
- crps = torch.where(nanmask, torch.nan, crps)
+ crps = torch.where(nanmasks.sum(dim=0) != 0, 0.0, crps)
 
 return crps
 
@@ -190,9 +184,8 @@ def _crps_independent_skillspread_kernel(observation: torch.Tensor, forecasts: t
 
 observation = observation.unsqueeze(0)
 
- # get nanmask
- nanmasks = torch.logical_or(torch.isnan(forecasts), torch.isnan(weights))
- nanmask = torch.sum(nanmasks, dim=0).bool()
+ # get nanmask from the observations
+ nanmasks = torch.logical_or(torch.isnan(observation), torch.isnan(weights))
 
 # compute total weights
 nweights = torch.where(nanmasks, 0.0, weights)
@@ -209,7 +202,7 @@ def _crps_independent_skillspread_kernel(observation: torch.Tensor, forecasts: t
 crps = eskill - 0.5 * espread
 
 # zero out the score wherever the observation or weights are nan
- crps = torch.where(nanmask, torch.nan, crps)
+ crps = torch.where(nanmasks.sum(dim=0) != 0, 0.0, crps)
 
 return crps
 
@@ -221,9 +214,8 @@ def _crps_naive_skillspread_kernel(observation: torch.Tensor, forecasts: torch.T
 
 observation = observation.unsqueeze(0)
 
- # get nanmask
- nanmasks = torch.logical_or(torch.isnan(forecasts), torch.isnan(weights))
- nanmask = torch.sum(nanmasks, dim=0).bool()
+ # get nanmask from the observations
+ nanmasks = torch.logical_or(torch.isnan(observation), torch.isnan(weights))
 
 # compute total weights
 nweights = torch.where(nanmasks, 0.0, weights)
@@ -233,14 +225,14 @@ def _crps_naive_skillspread_kernel(observation: torch.Tensor, forecasts: torch.T
 num_ensemble = forecasts.shape[0]
 
 # use broadcasting semantics to compute spread and skill
- espread = (forecasts.unsqueeze(1) - forecasts).abs().sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1))
+ espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1))
 eskill = (observation - forecasts).abs().mean(dim=0)
 
 # crps = torch.where(nanmasks.sum(dim=0) != 0, torch.nan, eskill - 0.5 * espread)
 crps = eskill - 0.5 * espread
 
- # set to nan wherever the forecasts or weights are nan
- crps = torch.where(nanmask, torch.nan, crps)
+ crps = torch.where(nanmasks.sum(dim=0) != 0, 0.0, crps)
 
 return crps
 
@@ -772,4 +764,148 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w
 crps = reduce_from_parallel_region(crps, "spatial")
 
 # the resulting tensor should have dimension B, C, which is what we return
- return crps
\ No newline at end of file
+ return crps
+
+class EnergyScoreLoss(GeometricBaseLoss):
+
+ def __init__(
+ self,
+ img_shape: Tuple[int, int],
+ crop_shape: Tuple[int, int],
+ crop_offset: Tuple[int, int],
+ channel_names: List[str],
+ grid_type: str,
+ pole_mask: int,
+ spatial_distributed: Optional[bool] = False,
+ ensemble_distributed: Optional[bool] = False,
+ ensemble_weights: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = 1.0,
+ beta: Optional[float] = 1.0,
+ eps: Optional[float] = 1.0e-5,
+ **kwargs,
+ ):
+
+ super().__init__(
+ img_shape=img_shape,
+ crop_shape=crop_shape,
+ crop_offset=crop_offset,
+ channel_names=channel_names,
+ grid_type=grid_type,
+ pole_mask=pole_mask,
+ spatial_distributed=spatial_distributed,
+ )
+
+ self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed
+
self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + self.alpha = alpha + self.beta = beta + self.eps = eps + + # we also need a variant of the weights split in ensemble direction: + quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) + if self.ensemble_distributed: + quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")] + quad_weight_split = quad_weight_split.contiguous() + self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + + if ensemble_weights is not None: + self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) + else: + self.ensemble_weights = ensemble_weights + + @property + def type(self): + return LossType.Probabilistic + + @property + def n_channels(self): + return 1 + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor: + return torch.ones(1) + + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + + # sanity checks + if forecasts.dim() != 5: + raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + + # we assume that spatial_weights have NO ensemble dim + if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): + spdim = spatial_weights.dim() + odim = observations.dim() + raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).") + + # we assume the following shapes: + # forecasts: batch, ensemble, channels, lat, lon + # observations: batch, channels, lat, lon + B, E, C, H, W = forecasts.shape + + # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. 
+ # ideally we split spatial dims
+ forecasts = torch.moveaxis(forecasts, 1, 0)
+ forecasts = forecasts.reshape(E, B, C, H * W)
+ if self.ensemble_distributed:
+ ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))]
+ forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble")
+
+ # observations do not need a transpose, just a split
+ observations = observations.reshape(1, B, C, H * W)
+ if self.ensemble_distributed:
+ observations = scatter_to_parallel_region(observations, -1, "ensemble")
+
+ # for correct spatial reduction we need to do the same with spatial weights
+ if spatial_weights is not None:
+ spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1)
+ spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble")
+
+ if self.ensemble_weights is not None:
+ raise NotImplementedError("currently only constant ensemble weights are supported")
+ else:
+ ensemble_weights = torch.ones_like(forecasts, device=forecasts.device)
+
+ # ensemble size
+ num_ensemble = forecasts.shape[0]
+
+ # get nanmask from the observations
+ nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(ensemble_weights))
+
+ # use broadcasting semantics to compute spread and skill and sum over channels (vector norm)
+ espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().pow(self.beta)
+ eskill = (observations - forecasts).abs().pow(self.beta)
+
+ # perform masking before any reduction
+ espread = torch.where(nanmasks.sum(dim=0) != 0, 0.0, espread)
+ eskill = torch.where(nanmasks.sum(dim=0) != 0, 0.0, eskill)
+
+ # do the channel reduction (nan entries were zeroed out above)
+ # if channel weights are required they should be added here to the reduction
+ espread = espread.sum(dim=-2, keepdim=True)
+ eskill = eskill.sum(dim=-2, keepdim=True)
+
+ # do the spatial reduction
+ if spatial_weights is not None:
+ espread = torch.sum(espread * self.quad_weight_split * spatial_weights_split, dim=-1)
+ eskill = torch.sum(eskill * self.quad_weight_split * spatial_weights_split, dim=-1)
+ else:
+ espread = torch.sum(espread * self.quad_weight_split, dim=-1)
+ eskill = torch.sum(eskill * self.quad_weight_split, dim=-1)
+
+ # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well
+ if self.ensemble_distributed:
+ espread = reduce_from_parallel_region(espread, "ensemble")
+ eskill = reduce_from_parallel_region(eskill, "ensemble")
+
+ # we need to do the spatial averaging manually since
+ # we are not calling the quadrature forward function
+ if self.spatial_distributed:
+ espread = reduce_from_parallel_region(espread, "spatial")
+ eskill = reduce_from_parallel_region(eskill, "spatial")
+
+ # now we have reduced everything and need to sum appropriately
+ espread = espread.sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1))
+ eskill = eskill.sum(dim=0) / float(num_ensemble)
+
+ # the resulting tensor should have dimension B, C, which is what we return
+ return eskill - 0.5 * espread
\ No newline at end of file
diff --git a/makani/utils/metric.py b/makani/utils/metric.py
index 4edba2a..fd8118e 100644
--- a/makani/utils/metric.py
+++ b/makani/utils/metric.py
@@ -213,12 +213,12 @@ def __init__(
 climatology,
 num_rollout_steps,
 device,
- l1_var_names=["u10m", "t2m", "u500", "z500", "q500", "q50", "sp"],
- rmse_var_names=["u10m", "t2m", "u500", "z500", "q500", "q50", "sp"],
- acc_var_names=["u10m", "t2m", "u500", "z500", "q500",
"q50", "sp"], - crps_var_names=["u10m", "t2m", "u500", "z500", "q500", "q50", "sp"], - spread_var_names=["u10m", "t2m", "u500", "z500", "q500", "q50", "sp"], - ssr_var_names=["u10m", "t2m", "u500", "z500", "q500", "q50", "sp"], + l1_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + rmse_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + acc_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + crps_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + spread_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + ssr_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], rh_var_names=[], wb2_compatible=False, ): diff --git a/makani/utils/training/ensemble_trainer.py b/makani/utils/training/ensemble_trainer.py index cf3ad3c..4934f8a 100644 --- a/makani/utils/training/ensemble_trainer.py +++ b/makani/utils/training/ensemble_trainer.py @@ -182,11 +182,11 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = self.loss_obj = self.loss_obj.to(self.device) self.timers["loss handler init"] = timer.time - # channel weights: - if self.log_to_screen: - chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist() - chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)} - self.logger.info(f"Channel weights: {chw_output}") + # # channel weights: + # if self.log_to_screen: + # chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist() + # chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)} + # self.logger.info(f"Channel weights: {chw_output}") # optimizer and scheduler setup # model @@ -248,7 +248,9 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = # visualization wrapper: with Timer() as timer: - plot_list = [{"name": "windspeed_uv10", "functor": "lambda x: np.sqrt(np.square(x[0, ...]) + np.square(x[1, ...]))", "diverging": False}] + plot_channel = "q50" + plot_index = self.params.channel_names.index(plot_channel) + plot_list = [{"name": plot_channel, "functor": f"lambda x: x[{plot_index}, ...]", "diverging": False}] out_bias, out_scale = self.train_dataloader.get_output_normalization() self.visualizer = visualize.VisualizationWrapper( self.params.log_to_wandb, From 9bc000c172955dbd3be8932953c4274179fa079d Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Thu, 30 Oct 2025 08:36:50 -0700 Subject: [PATCH 15/32] restoring logic from main regarding noise states --- makani/utils/training/ensemble_trainer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/makani/utils/training/ensemble_trainer.py b/makani/utils/training/ensemble_trainer.py index 4934f8a..752b566 100644 --- a/makani/utils/training/ensemble_trainer.py +++ b/makani/utils/training/ensemble_trainer.py @@ -615,11 +615,10 @@ def train_one_epoch(self, profiler=None): return train_time, total_data_gb, logs - def _initialize_noise_states(self, seed_offset=666): + def _initialize_noise_states(self): noise_states = [] - for ide in range(self.params.local_ensemble_size): - member_seed = seed_offset + self.preprocessor.get_base_seed(default=333) * ide - self.preprocessor.set_rng(seed=member_seed, reset=True) + for _ in range(self.params.local_ensemble_size): + self.preprocessor.update_internal_state(replace_state=True) noise_states.append(self.preprocessor.get_internal_state(tensor=True)) return noise_states @@ -668,7 +667,7 @@ def validate_one_epoch(self, epoch, 
profiler=None): # do autoregression for each ensemble member individually # do the rollout # initialize the noise states with random seeds: - noise_states = self._initialize_noise_states(seed_offset=eval_steps) + noise_states = self._initialize_noise_states() inptlist = [inp.clone() for _ in range(self.params.local_ensemble_size)] # loop over lead times @@ -686,7 +685,7 @@ def validate_one_epoch(self, epoch, profiler=None): # retrieve input inpt = inptlist[e] - # restore noise state + # this is different, depending on local ensemble size if self.params.local_ensemble_size > 1: # recover correct state self.preprocessor.set_internal_state(noise_states[e]) From ff71af3fd5faf630b1ab4bbba0be635b496910df Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Sat, 1 Nov 2025 02:45:59 -0700 Subject: [PATCH 16/32] fixing model package --- makani/convert_checkpoint.py | 2 +- makani/models/model_package.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/makani/convert_checkpoint.py b/makani/convert_checkpoint.py index be4f95c..19fed0a 100644 --- a/makani/convert_checkpoint.py +++ b/makani/convert_checkpoint.py @@ -100,7 +100,7 @@ def consolidate_checkpoints(input_path, output_path, checkpoint_version=0): print(checkpoint_paths) # load the static data necessary for instantiating the preprocessor (required due to the way the registry works) - LocalPackage._load_static_data(input_path, params) + LocalPackage._load_static_data(LocalPackage(input_path), params) # get the model multistep = params.n_future > 0 diff --git a/makani/models/model_package.py b/makani/models/model_package.py index b674133..94ac71d 100644 --- a/makani/models/model_package.py +++ b/makani/models/model_package.py @@ -38,7 +38,7 @@ class LocalPackage: """ - Implements the earth2mip/modulus Package interface. + Implements the modulus Package interface. 
""" # These define the model package in terms of where makani expects the files to be located @@ -49,7 +49,7 @@ class LocalPackage: MEANS_FILE = "global_means.npy" STDS_FILE = "global_stds.npy" OROGRAPHY_FILE = "orography.nc" - LANDMASK_FILE = "land_mask.nc" + LANDMASK_FILE = "land_sea_mask.nc" SOILTYPE_FILE = "soil_type.nc" def __init__(self, root): @@ -148,11 +148,11 @@ def timestep(self): def update_state(self, replace_state=True): self.model.preprocessor.update_internal_state(replace_state=replace_state) return - + def set_rng(self, reset=True, seed=333): self.model.preprocessor.set_rng(reset=reset, seed=seed) return - + def forward(self, x, time, normalized_data=True, replace_state=None): if not normalized_data: x = (x - self.in_bias) / self.in_scale From 45a4c2e6f1bacd0f1f944b1571211ab5a667280e Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Wed, 12 Nov 2025 04:56:32 -0800 Subject: [PATCH 17/32] some changes to deterministic trainer --- makani/utils/training/deterministic_trainer.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/makani/utils/training/deterministic_trainer.py b/makani/utils/training/deterministic_trainer.py index 1ed27d1..b3e149d 100644 --- a/makani/utils/training/deterministic_trainer.py +++ b/makani/utils/training/deterministic_trainer.py @@ -179,11 +179,11 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = self.loss_obj = self.loss_obj.to(self.device) self.timers["loss handler init"] = timer.time - # channel weights: - if self.log_to_screen: - chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist() - chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)} - self.logger.info(f"Channel weights: {chw_output}") + # # channel weights: + # if self.log_to_screen: + # chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist() + # chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)} + # self.logger.info(f"Channel weights: {chw_output}") # optimizer and scheduler setup with Timer() as timer: @@ -243,7 +243,10 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = # visualization wrapper: with Timer() as timer: - plot_list = [{"name": "windspeed_uv10", "functor": "lambda x: np.sqrt(np.square(x[0, ...]) + np.square(x[1, ...]))", "diverging": False}] + plot_channel = "q50" + plot_index = self.params.channel_names.index(plot_channel) + print(self.params.channel_names) + plot_list = [{"name": plot_channel, "functor": f"lambda x: x[{plot_index}, ...]", "diverging": False}] out_bias, out_scale = self.train_dataloader.get_output_normalization() self.visualizer = visualize.VisualizationWrapper( self.params.log_to_wandb, From 25862ae7ba1595cbc8af2c586988e89e040fe3b4 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Mon, 24 Nov 2025 01:32:19 -0800 Subject: [PATCH 18/32] removing the residual training optiona and instead extending loss handler to be able to handle tendencies --- config/afnonet.yaml | 3 - config/debug.yaml | 3 - config/fourcastnet3.yaml | 3 - config/pangu.yaml | 7 +- config/sfnonet.yaml | 3 - makani/models/common/layers.py | 2 +- makani/models/noise.py | 11 +- makani/models/preprocessor.py | 22 +-- makani/models/stepper.py | 10 -- makani/mpu/layers.py | 128 ++++++++++++++++++ makani/utils/loss.py | 65 ++++++++- makani/utils/losses/amse_loss.py | 4 +- makani/utils/losses/h1_loss.py | 2 +- .../utils/training/deterministic_trainer.py | 10 +- makani/utils/training/ensemble_trainer.py | 8 +- 
makani/utils/training/stochastic_trainer.py | 4 +- tests/distributed/distributed_helpers.py | 5 +- tests/testutils.py | 1 - 18 files changed, 220 insertions(+), 71 deletions(-) diff --git a/config/afnonet.yaml b/config/afnonet.yaml index 6eab5f8..02a1d88 100644 --- a/config/afnonet.yaml +++ b/config/afnonet.yaml @@ -57,9 +57,6 @@ full_field: &BASE_CONFIG nested_skip_fno: !!bool True # whether to nest the inner skip connection or have it be sequential, inside the AFNO block verbose: False - #options default, residual - target: "default" - channel_names: ["u10m", "v10m", "t2m", "sp", "msl", "t850", "u1000", "v1000", "z1000", "u850", "v850", "z850", "u500", "v500", "z500", "t500", "z50", "r500", "r850", "tcwv", "u100m", "v100m", "u250", "v250", "z250", "t250", "u100", "v100", "z100", "t100", "u900", "v900", "z900", "t900"] normalization: "zscore" #options zscore or minmax hard_thresholding_fraction: 1.0 diff --git a/config/debug.yaml b/config/debug.yaml index 8474473..cbd7b19 100644 --- a/config/debug.yaml +++ b/config/debug.yaml @@ -90,9 +90,6 @@ base_config: &BASE_CONFIG add_noise: !!bool False noise_std: 0. - target: "default" # options default, residual - normalize_residual: false - # define channels to be read from data channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] # normalization mode zscore but for q diff --git a/config/fourcastnet3.yaml b/config/fourcastnet3.yaml index d27c629..21925c9 100644 --- a/config/fourcastnet3.yaml +++ b/config/fourcastnet3.yaml @@ -87,9 +87,6 @@ base_config: &BASE_CONFIG add_noise: !!bool False noise_std: 0. - target: "default" # options default, residual - normalize_residual: false - # define channels to be read from data. sp has been removed here channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] # normalization mode zscore but for q diff --git a/config/pangu.yaml b/config/pangu.yaml index c2852a1..7f49193 100644 --- a/config/pangu.yaml +++ b/config/pangu.yaml @@ -78,9 +78,6 @@ base_config: &BASE_CONFIG add_noise: !!bool False noise_std: 0. 
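The `target: "default"` option removed from these configs (afnonet, debug, and fourcastnet3 above; pangu and sfnonet below) is superseded by a per-loss `tendency` flag that the extended loss handler reads from each loss specification and uses to shift predictions and targets into tendency space before evaluation (see the makani/utils/loss.py hunks later in this patch). A minimal sketch of a loss spec using the new flag, assuming the dict-style entries read by LossHandler; the full config schema is not shown in this patch:

    losses = [
        {"type": "geometric l2"},                    # loss on the full fields, as before
        {"type": "geometric l2", "tendency": True},  # same loss on prd - inp and tar - inp
    ]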
- target: "default" # options default, residual - normalize_residual: !!bool False - # define channels to be read from data channel_names: ["u10m", "v10m", "t2m", "msl", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] normalization: "zscore" # options zscore or minmax or none @@ -131,10 +128,10 @@ base_onnx: &BASE_ONNX # ONNX wrapper related overwrite nettype: "/makani/makani/makani/models/networks/pangu_onnx.py:PanguOnnx" onnx_file: '/model/pangu_weather_6.onnx' - + amp_mode: "none" disable_ddp: True - + # Set Pangu ONNX channel order channel_names: ["msl", "u10m", "v10m", "t2m", "z1000", "z925", "z850", "z700", "z600", "z500", "z400", "z300", "z250", "z200", "z150", "z100", "z50", "q1000", "q925", "q850", "q700", "q600", "q500", "q400", "q300", "q250", "q200", "q150", "q100", "q50", "t1000", "t925", "t850", "t700", "t600", "t500", "t400", "t300", "t250", "t200", "t150", "t100", "t50", "u1000", "u925", "u850", "u700", "u600", "u500", "u400", "u300", "u250", "u200", "u150", "u100", "u50", "v1000", "v925", "v850", "v700", "v600", "v500", "v400", "v300", "v250", "v200", "v150", "v100", "v50"] # Remove input/output normalization diff --git a/config/sfnonet.yaml b/config/sfnonet.yaml index 271a0ca..b70b1d4 100644 --- a/config/sfnonet.yaml +++ b/config/sfnonet.yaml @@ -86,9 +86,6 @@ base_config: &BASE_CONFIG add_noise: !!bool False noise_std: 0. 
- target: "default" # options default, residual - normalize_residual: !!bool False - # define channels to be read from data channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] normalization: "zscore" # options zscore or minmax or none diff --git a/makani/models/common/layers.py b/makani/models/common/layers.py index 49c5910..46e394a 100644 --- a/makani/models/common/layers.py +++ b/makani/models/common/layers.py @@ -605,5 +605,5 @@ def forward(self, x): x = self.norm(x) x = self.linear(x) - + return x diff --git a/makani/models/noise.py b/makani/models/noise.py index 7b7d845..a1f3d52 100644 --- a/makani/models/noise.py +++ b/makani/models/noise.py @@ -193,21 +193,28 @@ def __init__( alpha = float(alpha) # Compute ls, angular power spectrum and sigma_l: - ls = torch.arange(self.lmax) + ls = torch.arange(self.lmax).reshape(-1 ,1) + ms = torch.arange(self.mmax) power_spectrum = torch.pow(2 * ls + 1, -alpha) norm_factor = torch.sum((2 * ls + 1) * power_spectrum / 4.0 / math.pi) sigma_l = sigma * torch.sqrt(power_spectrum / norm_factor) + sigma_l = torch.where(ls <= ms, sigma_l, 0.0) # the new shape is B, T, C, L, M - sigma_l = sigma_l.reshape((1, 1, 1, self.lmax, 1)).to(dtype=torch.float32) + sigma_l = sigma_l.reshape((1, 1, 1, self.lmax, self.mmax)).to(dtype=torch.float32) # split tensor if comm.get_size("h") > 1: sigma_l = split_tensor_along_dim(sigma_l, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")] + # split tensor + if comm.get_size("w") > 1: + sigma_l = split_tensor_along_dim(sigma_l, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")] + # register buffer if learnable: self.register_parameter("sigma_l", nn.Parameter(sigma_l)) + self.sigma_l.sharded_dims_mp = [None, None, None, "h", "w"] else: self.register_buffer("sigma_l", sigma_l, persistent=False) diff --git a/makani/models/preprocessor.py b/makani/models/preprocessor.py index 59dbc2f..071de92 100644 --- a/makani/models/preprocessor.py +++ b/makani/models/preprocessor.py @@ -53,13 +53,7 @@ def __init__(self, params): self.history_eps = 1e-6 # residual normalization - self.learn_residual = params.target == "residual" - if self.learn_residual and (params.normalize_residual): - with torch.no_grad(): - residual_scale = torch.as_tensor(np.load(params.time_diff_stds_path)).to(torch.float32) - self.register_buffer("residual_scale", residual_scale, persistent=False) - else: - self.residual_scale = None + self.residual_scale = None # image shape self.img_shape = [params.img_shape_x, params.img_shape_y] @@ -178,20 +172,6 @@ def expand_history(self, x, nhist): x = torch.reshape(x, (b_, nhist, ct_ // nhist, h_, w_)) return x - def add_residual(self, x, dx): - if self.learn_residual: - if self.residual_scale is not None: - dx = dx * self.residual_scale - - # add residual: deal with history - x = self.expand_history(x, nhist=self.n_history + 1) - x[:, -1, ...] = x[:, -1, ...] 
+ dx - x = self.flatten_history(x) - else: - x = dx - - return x - def add_static_features(self, x): if self.do_add_static_features: # we need to replicate the grid for each batch: diff --git a/makani/models/stepper.py b/makani/models/stepper.py index 565402a..f7edfea 100644 --- a/makani/models/stepper.py +++ b/makani/models/stepper.py @@ -49,9 +49,6 @@ def forward(self, inp, update_state=True, replace_state=True): # undo normalization y = self.preprocessor.history_denormalize(yn, target=True) - # add residual (for residual learning, no-op for direct learning - y = self.preprocessor.add_residual(inp, y) - return y @@ -60,7 +57,6 @@ def __init__(self, params, model_handle): super().__init__() self.preprocessor = Preprocessor2D(params) self.model = model_handle() - self.residual_mode = True if (params.target == "target") else False self.push_forward_mode = params.get("multistep_push_forward", False) # collect parameters for history @@ -102,9 +98,6 @@ def _forward_train(self, inp, update_state=True, replace_state=True): # will have been updated later: pred = self.preprocessor.history_denormalize(predn, target=True) - # add residual (for residual learning, no-op for direct learning - pred = self.preprocessor.add_residual(inpt, pred) - # append output result.append(pred) @@ -148,9 +141,6 @@ def _forward_eval(self, inp, update_state=True, replace_state=True): # because otherwise normalization stats are already outdated y = self.preprocessor.history_denormalize(yn, target=True) - # add residual (for residual learning, no-op for direct learning - y = self.preprocessor.add_residual(inp, y) - return y def forward(self, inp, update_state=True, replace_state=True): diff --git a/makani/mpu/layers.py b/makani/mpu/layers.py index 18f470b..dd67a96 100644 --- a/makani/mpu/layers.py +++ b/makani/mpu/layers.py @@ -266,6 +266,134 @@ def forward(self, x): else: return self.fwd(x) +# Stochastic MLP needs comm datastructure +class StochasticMLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + output_bias=True, + input_format="nchw", + drop_rate=0.0, + drop_type="iid", + checkpointing=False, + gain=1.0, + seed=333, + **kwargs, + ): + super().__init__() + + self.checkpointing = checkpointing + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + # generator objects: + self.set_rng(seed=seed) + + # First fully connected layer + if input_format == "nchw": + self.fc1_weight_std = nn.Parameter(torch.zeros(hidden_features, in_features, 1, 1)) + self.fc1_weight_mean = nn.Parameter(torch.zeros(hidden_features, in_features, 1, 1)) + self.fc1_bias = nn.Parameter(torch.zeros(hidden_features)) + else: + raise NotImplementedError(f"Error, input format {input_format} not supported.") + + # sharing settings + self.fc1_weight_std.is_shared_mp = ["spatial"] + self.fc1_weight_mean.is_shared_mp = ["spatial"] + self.fc1_bias.is_shared_mp = ["spatial"] + + # initialize the weights correctly + scale = math.sqrt(1.0 / in_features) + nn.init.normal_(self.fc1_weight_std, mean=0.0, std=scale) + nn.init.normal_(self.fc1_weight_mean, mean=0.0, std=scale) + + # activation + self.act = act_layer() + + # sanity checks + if (input_format == "traditional") and (drop_type == "features"): + raise NotImplementedError(f"Error, traditional input format and feature dropout cannot be selected simultaneously") + + # output layer + if input_format == "nchw": + self.fc2_weight_std = nn.Parameter(torch.zeros(out_features, hidden_features, 1, 
1))
+            self.fc2_weight_mean = nn.Parameter(torch.zeros(out_features, hidden_features, 1, 1))
+            self.fc2_bias = nn.Parameter(torch.zeros(out_features)) if output_bias else None
+        else:
+            raise NotImplementedError(f"Error, input format {input_format} not supported.")
+
+        # sharing settings
+        self.fc2_weight_std.is_shared_mp = ["spatial"]
+        self.fc2_weight_mean.is_shared_mp = ["spatial"]
+        if self.fc2_bias is not None:
+            self.fc2_bias.is_shared_mp = ["spatial"]
+
+        # gain factor for the output determines the scaling of the output init
+        scale = math.sqrt(gain / hidden_features / 2)
+        nn.init.normal_(self.fc2_weight_std, mean=0.0, std=scale)
+        nn.init.normal_(self.fc2_weight_mean, mean=0.0, std=scale)
+        if self.fc2_bias is not None:
+            nn.init.constant_(self.fc2_bias, 0.0)
+
+        if drop_rate > 0.0:
+            if drop_type == "iid":
+                self.drop = nn.Dropout(drop_rate)
+            elif drop_type == "features":
+                self.drop = nn.Dropout2d(drop_rate)
+            else:
+                raise NotImplementedError(f"Error, drop_type {drop_type} not supported")
+        else:
+            self.drop = nn.Identity()
+
+    @torch.compiler.disable(recursive=False)
+    def set_rng(self, seed=333):
+        self.rng_cpu = torch.Generator(device=torch.device("cpu"))
+        self.rng_cpu.manual_seed(seed)
+        if torch.cuda.is_available():
+            self.rng_gpu = torch.Generator(device=torch.device(f"cuda:{comm.get_local_rank()}"))
+            self.rng_gpu.manual_seed(seed)
+
+    @torch.compiler.disable(recursive=False)
+    def checkpoint_forward(self, x):
+        return checkpoint(self.fwd, x, use_reentrant=False)
+
+    def fwd(self, x):
+
+        # generate weight1
+        weight1 = torch.empty_like(self.fc1_weight_mean)
+        weight1.normal_(mean=0.0, std=1.0, generator=self.rng_gpu if weight1.is_cuda else self.rng_cpu)
+        weight1 = self.fc1_weight_std * weight1 + self.fc1_weight_mean
+
+        # fully connected 1
+        x = nn.functional.conv2d(x, weight1, bias=self.fc1_bias)
+
+        # activation
+        x = self.act(x)
+
+        # dropout
+        x = self.drop(x)
+
+        # generate weight2
+        weight2 = torch.empty_like(self.fc2_weight_mean)
+        weight2.normal_(mean=0.0, std=1.0, generator=self.rng_gpu if weight2.is_cuda else self.rng_cpu)
+        weight2 = self.fc2_weight_std * weight2 + self.fc2_weight_mean
+
+        # fully connected 2
+        x = nn.functional.conv2d(x, weight2, bias=self.fc2_bias)
+
+        # dropout
+        x = self.drop(x)
+
+        return x
+
+    def forward(self, x):
+        if self.checkpointing:
+            return self.checkpoint_forward(x)
+        else:
+            return self.fwd(x)

 class DistributedPatchEmbed(nn.Module):
     def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768, input_is_matmul_parallel=False, output_is_matmul_parallel=True):
diff --git a/makani/utils/loss.py b/makani/utils/loss.py
index a3776f1..7263f81 100644
--- a/makani/utils/loss.py
+++ b/makani/utils/loss.py
@@ -48,6 +48,7 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps
         self.rank = comm.get_rank("matmul")

         self.n_future = params.n_future
+        self.n_history = params.n_history

         self.spatial_distributed = comm.is_distributed("spatial") and (comm.get_size("spatial") > 1)
         self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)
@@ -90,12 +91,16 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps

         # create module list
         self.loss_fn = nn.ModuleList([])
+        self.loss_requires_input = []  # track which losses need input state
         channel_weights = []
         for loss in losses:
             loss_type = loss["type"]

+            # check if this is a tendency loss (from explicit field, not string parsing)
+            requires_input = loss.get("tendency", False)
+
             # get
pole mask if it was specified pole_mask = loss.get("pole_mask", 0) @@ -120,6 +125,7 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps # append to dict and compile before: self.loss_fn.append(loss_fn) + self.loss_requires_input.append(requires_input) # determine channel weighting if "channel_weights" not in loss.keys(): @@ -287,7 +293,23 @@ def reset_running_stats(self): self.running_var.fill_(1) self.num_batches_tracked.zero_() - def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None): + def _extract_input_state(self, inp: torch.Tensor) -> torch.Tensor: + """ + Extract last timestep from flattened history input. + + Args: + inp: Input tensor with shape (B, (n_history+1)*C, H, W) + + Returns: + Last timestep with shape (B, C, H, W) + """ + # inp shape: (B, (n_history+1)*C, H, W) + # we want: (B, C, H, W) - the last timestep + n_channels_per_step = inp.shape[1] // (self.n_history + 1) + inp_last = inp[..., -n_channels_per_step:, :, :] + return inp_last + + def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None, inp: Optional[torch.Tensor] = None): # we assume the following: # if prd is 5D, we assume that the dims are # batch, ensemble, channel, h, w @@ -325,13 +347,48 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens else: prdm = prd + # transform to tendency space if any loss requires it + if inp is not None and any(self.loss_requires_input): + inp_state = self._extract_input_state(inp) + + # validate channel counts for single-step predictions + if self.n_future == 0: + n_pred_channels = prdm.shape[1] + n_inp_channels = inp_state.shape[1] + assert n_pred_channels == n_inp_channels, \ + f"Channel mismatch: prediction has {n_pred_channels} channels but input has {n_inp_channels} channels" + + # transform predictions and targets to tendency space + # this allows ANY loss function to compute tendency-based metrics + prdm_tendency = prdm - inp_state + tar_tendency = tar - inp_state + + # also transform ensemble predictions if present + if prd.dim() == 5: + # expand inp_state to match ensemble dim + inp_state_expanded = inp_state.unsqueeze(1) + prd_tendency = prd - inp_state_expanded + else: + prd_tendency = prdm_tendency + else: + prdm_tendency = prdm + tar_tendency = tar + prd_tendency = prd + # compute loss contributions from each loss loss_vals = [] - for lfn in self.loss_fn: + for lfn, requires_inp in zip(self.loss_fn, self.loss_requires_input): if lfn.type == LossType.Deterministic: - loss_vals.append(lfn(prdm, tar, wgt)) + if requires_inp: + loss_vals.append(lfn(prdm_tendency, tar_tendency, wgt)) + else: + loss_vals.append(lfn(prdm, tar, wgt)) else: - loss_vals.append(lfn(prd, tar, wgt)) + # probabilistic losses: use tendency-transformed ensemble if needed + if requires_inp: + loss_vals.append(lfn(prd_tendency, tar_tendency, wgt)) + else: + loss_vals.append(lfn(prd, tar, wgt)) all_losses = torch.cat(loss_vals, dim=-1) # print(all_losses) diff --git a/makani/utils/losses/amse_loss.py b/makani/utils/losses/amse_loss.py index 89db8fb..8ae392c 100644 --- a/makani/utils/losses/amse_loss.py +++ b/makani/utils/losses/amse_loss.py @@ -61,7 +61,7 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens with amp.autocast(device_type="cuda", enabled=False): xcoeffs = self.sht(prd) ycoeffs = self.sht(tar) - + # compute the SHT: xcoeffssq = torch.square(torch.abs(xcoeffs)) ycoeffssq = torch.square(torch.abs(ycoeffs)) @@ -100,5 +100,5 @@ 
def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens loss = torch.sum(loss, dim=-1) if self.spatial_distributed and (comm.get_size("h") > 1): loss = reduce_from_parallel_region(loss, "h") - + return loss diff --git a/makani/utils/losses/h1_loss.py b/makani/utils/losses/h1_loss.py index 0cac68b..75acff4 100644 --- a/makani/utils/losses/h1_loss.py +++ b/makani/utils/losses/h1_loss.py @@ -134,7 +134,7 @@ def rel(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] tar_norm2 = 2 * torch.sum(tar_coeffssq, dim=-1) if self.spatial_distributed and (comm.get_size("w") > 1): tar_norm2 = reduce_from_parallel_region(tar_norm2, "w") - + # compute target norms tar_norm2 = tar_norm2.reshape(B, C, -1) tar_h1_norm2 = torch.sum(tar_norm2 * self.h1_weights, dim=-1) diff --git a/makani/utils/training/deterministic_trainer.py b/makani/utils/training/deterministic_trainer.py index b3e149d..f69761b 100644 --- a/makani/utils/training/deterministic_trainer.py +++ b/makani/utils/training/deterministic_trainer.py @@ -243,8 +243,10 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = # visualization wrapper: with Timer() as timer: - plot_channel = "q50" - plot_index = self.params.channel_names.index(plot_channel) + plot_channel = "z500" + # plot_channel = "q50" + # plot_index = self.params.channel_names.index(plot_channel) + plot_index = 0 print(self.params.channel_names) plot_list = [{"name": plot_channel, "functor": f"lambda x: x[{plot_index}, ...]", "diverging": False}] out_bias, out_scale = self.train_dataloader.get_output_normalization() @@ -532,12 +534,12 @@ def train_one_epoch(self, profiler=None): if do_update: # regular forward pass including DDP pred = self.model_train(inp) - loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) else: # disable sync step with self.model_train.no_sync(): pred = self.model_train(inp) - loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) loss = loss * loss_scaling_fact # backward pass diff --git a/makani/utils/training/ensemble_trainer.py b/makani/utils/training/ensemble_trainer.py index 752b566..2f6a702 100644 --- a/makani/utils/training/ensemble_trainer.py +++ b/makani/utils/training/ensemble_trainer.py @@ -248,8 +248,10 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = # visualization wrapper: with Timer() as timer: - plot_channel = "q50" - plot_index = self.params.channel_names.index(plot_channel) + plot_channel = "z500" + # plot_channel = "q50" + # plot_index = self.params.channel_names.index(plot_channel) + plot_index = 0 plot_list = [{"name": plot_channel, "functor": f"lambda x: x[{plot_index}, ...]", "diverging": False}] out_bias, out_scale = self.train_dataloader.get_output_normalization() self.visualizer = visualize.VisualizationWrapper( @@ -493,7 +495,7 @@ def _ensemble_step(self, inp: torch.Tensor, tar: torch.Tensor): # stack predictions along new dim (ensemble dim): pred = torch.stack(predlist, dim=1) # compute loss - loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) return pred, loss diff --git a/makani/utils/training/stochastic_trainer.py b/makani/utils/training/stochastic_trainer.py index 9d0ae1e..519cd37 100644 --- a/makani/utils/training/stochastic_trainer.py +++ b/makani/utils/training/stochastic_trainer.py @@ -506,11 +506,11 @@ def train_one_epoch(self): with amp.autocast(device_type="cuda", enabled=self.amp_enabled, dtype=self.amp_dtype): if do_update: pred, tar = 
self.model_train(inp, tar, n_samples=self.params.stochastic_size) - loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) else: with self.model_train.no_sync(): pred, tar = self.model_train(inp, tar, n_samples=self.params.stochastic_size) - loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) loss = loss * loss_scaling_fact self.gscaler.scale(loss).backward() diff --git a/tests/distributed/distributed_helpers.py b/tests/distributed/distributed_helpers.py index 334f269..303383a 100644 --- a/tests/distributed/distributed_helpers.py +++ b/tests/distributed/distributed_helpers.py @@ -61,7 +61,6 @@ def get_default_parameters(): params.N_in_channels = len(params.in_channels) params.N_out_channels = len(params.out_channels) - params.target = "default" params.batch_size = 1 params.valid_autoreg_steps = 0 params.num_data_workers = 1 @@ -93,7 +92,7 @@ def split_helper(tensor, dim=None, group=None): tensor_local = tensor_list_local[grank] else: tensor_local = tensor.clone() - + return tensor_local @@ -117,5 +116,5 @@ def gather_helper(tensor, dim=None, group=None): tensor_gather = torch.cat(tens_gather, dim=dim) else: tensor_gather = tensor.clone() - + return tensor_gather diff --git a/tests/testutils.py b/tests/testutils.py index dd78b08..86f6d85 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -62,7 +62,6 @@ def get_default_parameters(): params.N_in_channels = len(params.in_channels) params.N_out_channels = len(params.out_channels) - params.target = "default" params.batch_size = 1 params.valid_autoreg_steps = 0 params.num_data_workers = 1 From 756a51b362b7ca55f31670467438f923bc16313b Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Mon, 24 Nov 2025 12:40:58 -0800 Subject: [PATCH 19/32] updating other configs --- config/icml_models.yaml | 3 --- config/vit.yaml | 3 --- 2 files changed, 6 deletions(-) diff --git a/config/icml_models.yaml b/config/icml_models.yaml index e135a7f..ddde08e 100644 --- a/config/icml_models.yaml +++ b/config/icml_models.yaml @@ -66,9 +66,6 @@ base_config: &BASE_CONFIG N_grid_channels: 0 gridtype: "sinusoidal" #options "sinusoidal" or "linear" - #options default, residual - target: "default" - channel_names: ["u10m", "v10m", "t2m", "sp", "msl", "t850", "u1000", "v1000", "z1000", "u850", "v850", "z850", "u500", "v500", "z500", "t500", "z50", "r500", "r850", "tcwv", "u100m", "v100m", "u250", "v250", "z250", "t250", "u100", "v100", "z100", "t100", "u900", "v900", "z900", "t900"] normalization: "zscore" #options zscore or minmax or none diff --git a/config/vit.yaml b/config/vit.yaml index 58fcee0..fc2571c 100644 --- a/config/vit.yaml +++ b/config/vit.yaml @@ -65,9 +65,6 @@ full_field: &BASE_CONFIG embed_dim: 384 normalization_layer: "layer_norm" - #options default, residual - target: "default" - channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] normalization: "zscore" #options zscore or minmax hard_thresholding_fraction: 1.0 From 092e7a54c91b20e9e4910789703dcc2daa84bb98 Mon Sep 17 00:00:00 
2001 From: Boris Bonev Date: Mon, 24 Nov 2025 14:04:43 -0800 Subject: [PATCH 20/32] added routine to compute spherical bandlimit --- makani/utils/grids.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/makani/utils/grids.py b/makani/utils/grids.py index 5208529..4446fff 100644 --- a/makani/utils/grids.py +++ b/makani/utils/grids.py @@ -33,6 +33,20 @@ def grid_to_quadrature_rule(grid_type): return grid_to_quad_dict[grid_type] +def compute_spherical_bandlimit(img_shape, grid_type): + + if grid_type == "equiangular": + lmax = (img_shape[0] - 1) // 2 + mmax = img_shape[1] // 2 + return min(lmax, mmax) + elif grid_type == "legendre-gauss": + lmax = img_shape[0] - 1 + mmax = img_shape[1] // 2 + return min(lmax, mmax) + else: + raise NotImplementedError(f"Unknown type {grid_type} not implemented") + + class GridConverter(torch.nn.Module): def __init__(self, src_grid, dst_grid, lat_rad, lon_rad): super(GridConverter, self).__init__() From df70e1ba824402742c222c59283b9680ccd9ec91 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Tue, 2 Dec 2025 00:39:52 -0800 Subject: [PATCH 21/32] changes to registry to pass normalization parameters --- makani/models/model_registry.py | 22 ++- makani/utils/loss.py | 16 +- makani/utils/losses/__init__.py | 7 +- makani/utils/losses/base_loss.py | 79 ++++++++- makani/utils/losses/crps_loss.py | 198 ++++++++++++++-------- makani/utils/losses/lp_loss.py | 98 ++++++----- makani/utils/training/ensemble_trainer.py | 3 +- 7 files changed, 296 insertions(+), 127 deletions(-) diff --git a/makani/models/model_registry.py b/makani/models/model_registry.py index 5b860a3..6d1a8ff 100644 --- a/makani/models/model_registry.py +++ b/makani/models/model_registry.py @@ -166,7 +166,27 @@ def get_model(params: ParamsBase, use_stochastic_interpolation: bool = False, mu if isinstance(model_handle, (EntryPoint, importlib_metadata.EntryPoint)): model_handle = model_handle.load() - model_handle = partial(model_handle, inp_shape=inp_shape, out_shape=out_shape, inp_chans=inp_chans, out_chans=out_chans, **params.to_dict()) + model_kwargs = params.to_dict() + + # pass normalization statistics to the model + if params.get("normalization", "none") in ["zscore", "minmax"]: + try: + bias, scale = get_data_normalization(params) + # Slice the stats to match the model's output channels + # Assuming the model's output corresponds to params.out_channels + if hasattr(params, "out_channels"): + if bias is not None: + bias = bias.flatten()[params.out_channels] + if scale is not None: + scale = scale.flatten()[params.out_channels] + + if bias is not None and scale is not None: + model_kwargs["normalization_means"] = bias + model_kwargs["normalization_stds"] = scale + except Exception as e: + logging.warning(f"Could not load normalization stats. 
Error: {e}") + + model_handle = partial(model_handle, inp_shape=inp_shape, out_shape=out_shape, inp_chans=inp_chans, out_chans=out_chans, **model_kwargs) else: raise KeyError(f"No model is registered under the name {name}") diff --git a/makani/utils/loss.py b/makani/utils/loss.py index 7263f81..8ab1afb 100644 --- a/makani/utils/loss.py +++ b/makani/utils/loss.py @@ -31,10 +31,11 @@ import torch_harmonics as harmonics from torch_harmonics.quadrature import clenshaw_curtiss_weights, legendre_gauss_weights -from .losses import LossType, GeometricLpLoss, SpectralH1Loss, SpectralAMSELoss -from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleVortDivCRPSLoss, EnergyScoreLoss +from .losses import LossType, GeometricLpLoss, SpectralLpLoss, SpectralH1Loss, SpectralAMSELoss +from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleGradientCRPSLoss, EnsembleVortDivCRPSLoss, EnergyScoreLoss from .losses import EnsembleNLLLoss, EnsembleMMDLoss from .losses import DriftRegularization, HydrostaticBalanceLoss +from .losses import RandomizedKernelCRPS class LossHandler(nn.Module): @@ -197,13 +198,18 @@ def _parse_loss_type(self, loss_type: str): loss_type = set(loss_type.split()) + # this can probably all be moved to the loss function itself relative = "relative" in loss_type squared = "squared" in loss_type jacobian = "s2" if "geometric" in loss_type else "flat" # decide which loss to use - if "l2" in loss_type: + if "spectral" in loss_type and "l2" in loss_type: + loss_handle = partial(SpectralLpLoss, p=2, relative=relative, squared=squared) + elif "spectral" in loss_type and "l1" in loss_type: + loss_handle = partial(SpectralLpLoss, p=1, relative=relative, squared=squared) + elif "l2" in loss_type: loss_handle = partial(GeometricLpLoss, p=2, relative=relative, squared=squared, jacobian=jacobian) elif "l1" in loss_type: loss_handle = partial(GeometricLpLoss, p=1, relative=relative, squared=squared, jacobian=jacobian) @@ -229,6 +235,8 @@ def _parse_loss_type(self, loss_type: str): loss_handle = partial(EnsembleSpectralCRPSLoss) elif "ensemble_vort_div_crps" in loss_type: loss_handle = partial(EnsembleVortDivCRPSLoss) + elif "ensemble_gradient_crps" in loss_type: + loss_handle = partial(EnsembleGradientCRPSLoss) elif "ensemble_nll" in loss_type: loss_handle = EnsembleNLLLoss elif "ensemble_mmd" in loss_type: @@ -237,6 +245,8 @@ def _parse_loss_type(self, loss_type: str): loss_handle = partial(EnergyScoreLoss) elif "drift_regularization" in loss_type: loss_handle = DriftRegularization + elif "randomized_kernel" in loss_type: + loss_handle = RandomizedKernelCRPS else: raise NotImplementedError(f"Unknown loss function: {loss_type}") diff --git a/makani/utils/losses/__init__.py b/makani/utils/losses/__init__.py index fff3e7a..c13c236 100644 --- a/makani/utils/losses/__init__.py +++ b/makani/utils/losses/__init__.py @@ -15,11 +15,12 @@ from .base_loss import LossType from .h1_loss import SpectralH1Loss -from .lp_loss import GeometricLpLoss, SpectralL2Loss +from .lp_loss import GeometricLpLoss, SpectralLpLoss from .amse_loss import SpectralAMSELoss from .hydrostatic_loss import HydrostaticBalanceLoss -from .crps_loss import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleVortDivCRPSLoss -from .crps_loss import EnergyScoreLoss +from .crps_loss import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleGradientCRPSLoss, EnsembleVortDivCRPSLoss +from .energy_score import EnergyScoreLoss from .mmd_loss import EnsembleMMDLoss from .likelihood_loss import EnsembleNLLLoss from 
.drift_regularization import DriftRegularization +from .randomized_crps import RandomizedKernelCRPS diff --git a/makani/utils/losses/base_loss.py b/makani/utils/losses/base_loss.py index 5eb8a3c..171fb0a 100644 --- a/makani/utils/losses/base_loss.py +++ b/makani/utils/losses/base_loss.py @@ -323,6 +323,7 @@ def __init__( crop_offset: Tuple[int, int], channel_names: List[str], grid_type: str, + lmax: Optional[int] = None, spatial_distributed: Optional[bool] = False, ): super().__init__() @@ -339,9 +340,9 @@ def __init__( polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h") azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w") thd.init(polar_group, azimuth_group) - self.sht = thd.DistributedRealSHT(*img_shape, grid=grid_type) + self.sht = thd.DistributedRealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type) else: - self.sht = th.RealSHT(*img_shape, grid=grid_type).float() + self.sht = th.RealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() @property def type(self): @@ -373,6 +374,7 @@ def __init__( channel_names: List[str], grid_type: str, pole_mask: int, + lmax: Optional[int] = None, spatial_distributed: Optional[bool] = False, ): super().__init__() @@ -393,11 +395,11 @@ def __init__( polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h") azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w") thd.init(polar_group, azimuth_group) - self.vsht = thd.DistributedRealVectorSHT(*img_shape, grid=grid_type) - self.isht = thd.DistributedInverseRealVectorSHT(nlat=self.vsht.nlat, nlon=self.vsht.nlon, grid=grid_type) + self.vsht = thd.DistributedRealVectorSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type) + self.isht = thd.DistributedInverseRealVectorSHT(nlat=self.vsht.nlat, nlon=self.vsht.nlon, lmax=lmax, mmax=lmax, grid=grid_type) else: - self.vsht = th.RealVectorSHT(*img_shape, grid=grid_type) - self.isht = th.InverseRealVectorSHT(nlat=self.vsht.nlat, nlon=self.vsht.nlon, grid=grid_type) + self.vsht = th.RealVectorSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type) + self.isht = th.InverseRealVectorSHT(nlat=self.vsht.nlat, nlon=self.vsht.nlon, lmax=lmax, mmax=lmax, grid=grid_type) # get the quadrature rule for the corresponding grid quadrature_rule = grid_to_quadrature_rule(grid_type) @@ -434,3 +436,68 @@ def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: t @abstractmethod def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None) -> torch.Tensor: pass + +class GradientBaseLoss(nn.Module, metaclass=ABCMeta): + """ + Gradient base loss class used by all gradient based losses + """ + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + pole_mask: int, + lmax: Optional[int] = None, + spatial_distributed: Optional[bool] = False, + ): + super().__init__() + + self.img_shape = img_shape + self.crop_shape = crop_shape + self.crop_offset = crop_offset + self.channel_names = channel_names + self.pole_mask = pole_mask + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + + if self.spatial_distributed and (comm.get_size("spatial") > 1): + if not thd.is_initialized(): + polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h") + azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w") + thd.init(polar_group, azimuth_group) + self.sht = thd.DistributedRealSHT(*img_shape, lmax=lmax, 
mmax=lmax, grid=grid_type)
+            self.ivsht = thd.DistributedInverseRealVectorSHT(nlat=self.sht.nlat, nlon=self.sht.nlon, lmax=lmax, mmax=lmax, grid=grid_type)
+        else:
+            self.sht = th.RealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type)
+            self.ivsht = th.InverseRealVectorSHT(nlat=self.sht.nlat, nlon=self.sht.nlon, lmax=lmax, mmax=lmax, grid=grid_type)
+
+        # get the quadrature rule for the corresponding grid
+        quadrature_rule = grid_to_quadrature_rule(grid_type)
+        self.quadrature = GridQuadrature(
+            quadrature_rule,
+            img_shape=self.img_shape,
+            crop_shape=self.crop_shape,
+            crop_offset=self.crop_offset,
+            normalize=True,
+            pole_mask=self.pole_mask,
+            distributed=self.spatial_distributed,
+        )
+
+    @property
+    def type(self):
+        return LossType.Deterministic
+
+    @property
+    def n_channels(self):
+        return len(self.channel_names)
+
+    @torch.compiler.disable(recursive=False)
+    def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor:
+        return _compute_channel_weighting_helper(self.channel_names, channel_weight_type, time_diff_scale=time_diff_scale)
+
+
+    @abstractmethod
+    def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None) -> torch.Tensor:
+        pass
diff --git a/makani/utils/losses/crps_loss.py b/makani/utils/losses/crps_loss.py
index 3ef5f55..3657aa4 100644
--- a/makani/utils/losses/crps_loss.py
+++ b/makani/utils/losses/crps_loss.py
@@ -22,7 +22,7 @@
 import torch.nn as nn
 from torch import amp

-from makani.utils.losses.base_loss import GeometricBaseLoss, SpectralBaseLoss, VortDivBaseLoss, LossType
+from makani.utils.losses.base_loss import LossType, GeometricBaseLoss, SpectralBaseLoss, VortDivBaseLoss, GradientBaseLoss
 from makani.utils import comm

 # distributed stuff
@@ -431,6 +431,7 @@ def __init__(
         crop_offset: Tuple[int, int],
         channel_names: List[str],
         grid_type: str,
+        lmax: Optional[int] = None,
         crps_type: str = "skillspread",
         spatial_distributed: Optional[bool] = False,
         ensemble_distributed: Optional[bool] = False,
@@ -447,6 +448,7 @@
             crop_offset=crop_offset,
             channel_names=channel_names,
             grid_type=grid_type,
+            lmax=lmax,
             spatial_distributed=spatial_distributed,
         )
@@ -470,8 +472,8 @@
         # get the local l weights
         lmax = self.sht.lmax
         ls = torch.arange(lmax).reshape(-1, 1)
-        # l_weights = 1 / (2*ls+1)
-        l_weights = torch.ones(lmax).reshape(-1, 1)
+        l_weights = 1 / (2*ls+1)
+        # l_weights = torch.ones(lmax).reshape(-1, 1)
         if comm.get_size("h") > 1:
             l_weights = split_tensor_along_dim(l_weights, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")]
         self.register_buffer("l_weights", l_weights, persistent=False)
@@ -606,7 +608,7 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spectral_
         # the resulting tensor should have dimension B, C, which is what we return
         return crps

-class EnsembleVortDivCRPSLoss(VortDivBaseLoss):
+class EnsembleGradientCRPSLoss(GradientBaseLoss):

     def __init__(
         self,
@@ -616,10 +618,12 @@ def __init__(
         channel_names: List[str],
         grid_type: str,
         pole_mask: int,
+        lmax: Optional[int] = None,
         crps_type: str = "skillspread",
         spatial_distributed: Optional[bool] = False,
         ensemble_distributed: Optional[bool] = False,
         ensemble_weights: Optional[torch.Tensor] = None,
+        absolute: Optional[bool] = True,
         alpha: Optional[float] = 1.0,
         eps: Optional[float] = 1.0e-5,
         **kwargs,
     ):
@@ -632,9 +636,13 @@ def __init__(
             channel_names=channel_names,
             grid_type=grid_type,
             pole_mask=pole_mask,
+            lmax=lmax,
             spatial_distributed=spatial_distributed,
         )

+        # if
absolute is true, the loss is computed only on the absolute value of the gradient + self.absolute = absolute + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed self.crps_type = crps_type @@ -660,6 +668,22 @@ def __init__( def type(self): return LossType.Probabilistic + @property + def n_channels(self): + if self.absolute: + return len(self.channel_names) + else: + return 2 * len(self.channel_names) + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: + chw = super().compute_channel_weighting(channel_weight_type, time_diff_scale=time_diff_scale) + + if self.absolute: + return chw + else: + return [weight for weight in chw for _ in range(2)] + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: # sanity checks @@ -675,12 +699,7 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # we assume the following shapes: # forecasts: batch, ensemble, channels, lat, lon # observations: batch, channels, lat, lon - B, E, _, H, W = forecasts.shape - C = self.wind_chans.shape[0] - - # extract wind channels - forecasts = forecasts[..., self.wind_chans, :, :].reshape(B, E, C//2, 2, H, W) - observations = observations[..., self.wind_chans, :, :].reshape(B, C//2, 2, H, W) + B, E, C, H, W = forecasts.shape # get the data type before stripping amp types dtype = forecasts.dtype @@ -688,10 +707,27 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # before anything else compute the transform # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same with amp.autocast(device_type="cuda", enabled=False): - forecasts = self.isht(self.vsht(forecasts.float())) - observations = self.isht(self.vsht(observations.float())) - # extract wind channels + # compute the SH coefficients of the forecasts and observations + forecasts = self.sht(forecasts.float()).unsqueeze(-3) + observations = self.sht(observations.float()).unsqueeze(-3) + + # append zeros, so that we can use the inverse vector SHT + forecasts = torch.cat([forecasts, torch.zeros_like(forecasts)], dim=-3) + observations = torch.cat([observations, torch.zeros_like(observations)], dim=-3) + + forecasts = self.ivsht(forecasts) + observations = self.ivsht(observations) + + forecasts = forecasts.to(dtype) + observations = observations.to(dtype) + + if self.absolute: + forecasts = forecasts.pow(2).sum(dim=-3).sqrt() + observations = observations.pow(2).sum(dim=-3).sqrt() + else: + C = 2 * C + forecasts = forecasts.reshape(B, E, C, H, W) observations = observations.reshape(B, C, H, W) @@ -766,7 +802,7 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # the resulting tensor should have dimension B, C, which is what we return return crps -class EnergyScoreLoss(GeometricBaseLoss): +class EnsembleVortDivCRPSLoss(VortDivBaseLoss): def __init__( self, @@ -776,11 +812,11 @@ def __init__( channel_names: List[str], grid_type: str, pole_mask: int, + crps_type: str = "skillspread", spatial_distributed: Optional[bool] = False, ensemble_distributed: Optional[bool] = False, ensemble_weights: Optional[torch.Tensor] = None, alpha: Optional[float] = 1.0, - beta: 
Optional[float] = 1.0, eps: Optional[float] = 1.0e-5, **kwargs, ): @@ -797,10 +833,13 @@ def __init__( self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + self.crps_type = crps_type self.alpha = alpha - self.beta = beta self.eps = eps + if (self.crps_type != "skillspread") and (self.alpha < 1.0): + raise NotImplementedError("The alpha parameter (almost fair CRPS factor) is only supported for the skillspread kernel.") + # we also need a variant of the weights split in ensemble direction: quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) if self.ensemble_distributed: @@ -817,14 +856,6 @@ def __init__( def type(self): return LossType.Probabilistic - @property - def n_channels(self): - return 1 - - @torch.compiler.disable(recursive=False) - def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor: - return torch.ones(1) - def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: # sanity checks @@ -840,72 +871,93 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # we assume the following shapes: # forecasts: batch, ensemble, channels, lat, lon # observations: batch, channels, lat, lon - B, E, C, H, W = forecasts.shape + B, E, _, H, W = forecasts.shape + C = self.wind_chans.shape[0] - # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. - # ideally we split spatial dims - forecasts = torch.moveaxis(forecasts, 1, 0) - forecasts = forecasts.reshape(E, B, C, H * W) - if self.ensemble_distributed: - ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] - forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") + # extract wind channels + forecasts = forecasts[..., self.wind_chans, :, :].reshape(B, E, C//2, 2, H, W) + observations = observations[..., self.wind_chans, :, :].reshape(B, C//2, 2, H, W) - # observations does not need a transpose, but just a split - observations = observations.reshape(1, B, C, H * W) - if self.ensemble_distributed: - observations = scatter_to_parallel_region(observations, -1, "ensemble") + # get the data type before stripping amp types + dtype = forecasts.dtype - # for correct spatial reduction we need to do the same with spatial weights - if spatial_weights is not None: - spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1) - spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + # before anything else compute the transform + # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same + with amp.autocast(device_type="cuda", enabled=False): + forecasts = self.isht(self.vsht(forecasts.float())) + observations = self.isht(self.vsht(observations.float())) - if self.ensemble_weights is not None: - raise NotImplementedError("currently only constant ensemble weights are supported") + # extract wind channels + forecasts = forecasts.reshape(B, E, C, H, W) + observations = observations.reshape(B, C, H, W) + + # if ensemble dim is one dimensional then computing the score is quick: + if (not self.ensemble_distributed) and (E == 1): + # in this case, CRPS is straightforward + crps = torch.abs(observations - 
forecasts.squeeze(1)).reshape(B, C, H * W) else: - ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + # transpose forecasts: ensemble, batch, channels, lat, lon + forecasts = torch.moveaxis(forecasts, 1, 0) - # ensemble size - num_ensemble = forecasts.shape[0] + # now we need to transpose the forecasts into ensemble direction. + # ideally we split spatial dims + forecasts = forecasts.reshape(E, B, C, H * W) + if self.ensemble_distributed: + ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] + forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") + # observations does not need a transpose, but just a split + observations = observations.reshape(B, C, H * W) + if self.ensemble_distributed: + observations = scatter_to_parallel_region(observations, -1, "ensemble") + if spatial_weights is not None: + spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1) + spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") - # get nanmask from the observarions - nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(ensemble_weights)) + # run appropriate crps kernel to compute it pointwise + if self.crps_type == "cdf": + # now, E dimension is local and spatial dim is split further + # we need to sort the forecasts now + forecasts, idx = torch.sort(forecasts, dim=0) + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) - # use broadcasting semantics to compute spread and skill and sum over channels (vector norm) - espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().pow(self.beta) - eskill = (observations - forecasts).abs().pow(self.beta) + # compute score + crps = _crps_ensemble_kernel(observations, forecasts, ensemble_weights) + elif self.crps_type == "skillspread": + if self.ensemble_weights is not None: + raise NotImplementedError("currently only constant ensemble weights are supported") + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) - # perform masking before any reduction - espread = torch.where(nanmasks.sum(dim=0) != 0, 0.0, espread) - eskill = torch.where(nanmasks.sum(dim=0) != 0, 0.0, eskill) + # compute score + crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) + elif self.crps_type == "gauss": + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) - # do the channel reduction while ignoring NaNs - # if channel weights are required they should be added here to the reduction - espread = espread.sum(dim=-2, keepdim=True) - eskill = eskill.sum(dim=-2, keepdim=True) + # compute score + crps = _crps_gauss_kernel(observations, forecasts, ensemble_weights, self.eps) + else: + raise ValueError(f"Unknown CRPS crps_type {self.crps_type}") - # do the spatial reduction + # perform ensemble and spatial average of crps score if spatial_weights is not None: - espread = torch.sum(espread * self.quad_weight_split * spatial_weights_split, dim=-1) - eskill = torch.sum(eskill * self.quad_weight_split * spatial_weights_split, dim=-1) + crps = torch.sum(crps * self.quad_weight_split * spatial_weights_split, dim=-1) else: - espread = torch.sum(espread * self.quad_weight_split, dim=-1) - eskill = torch.sum(eskill * self.quad_weight_split, dim=-1) + crps = 
torch.sum(crps * self.quad_weight_split, dim=-1) # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well if self.ensemble_distributed: - espread = reduce_from_parallel_region(espread, "ensemble") - eskill = reduce_from_parallel_region(eskill, "ensemble") + crps = reduce_from_parallel_region(crps, "ensemble") # we need to do the spatial averaging manually since - # we are not calling the quadrature forward function + # we are not calling he quadrature forward function if self.spatial_distributed: - espread = reduce_from_parallel_region(espread, "spatial") - eskill = reduce_from_parallel_region(eskill, "spatial") - - # now we have reduced everything and need to sum appropriately - espread = espread.sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) - eskill = eskill.sum(dim=0) / float(num_ensemble) + crps = reduce_from_parallel_region(crps, "spatial") - # the resulting tensor should have dimension B, C which is what we return - return eskill - 0.5 * espread \ No newline at end of file + # the resulting tensor should have dimension B, C, which is what we return + return crps diff --git a/makani/utils/losses/lp_loss.py b/makani/utils/losses/lp_loss.py index b6aae83..5fda9b5 100644 --- a/makani/utils/losses/lp_loss.py +++ b/makani/utils/losses/lp_loss.py @@ -22,6 +22,8 @@ from makani.utils import comm +from physicsnemo.distributed.mappings import reduce_from_parallel_region + class GeometricLpLoss(GeometricBaseLoss): """ @@ -112,9 +114,9 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens return loss -class SpectralL2Loss(SpectralBaseLoss): +class SpectralLpLoss(SpectralBaseLoss): """ - Computes the geometric L2 loss but using the spherical Harmonic transform + Computes the Lp loss in spectral (SH coefficients) space """ def __init__( @@ -124,6 +126,7 @@ def __init__( crop_offset: Tuple[int, int], channel_names: List[str], grid_type: str, + p: Optional[float] = 2.0, relative: Optional[bool] = False, squared: Optional[bool] = False, spatial_distributed: Optional[bool] = False, @@ -138,6 +141,7 @@ def __init__( spatial_distributed=spatial_distributed, ) + self.p = p self.relative = relative self.squared = squared self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed @@ -145,80 +149,95 @@ def __init__( def abs(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None): B, C, H, W = prd.shape - coeffssq = torch.square(torch.abs(self.sht(prd - tar))) / torch.pi / 4.0 + # compute SH coefficients of the difference + coeffs = self.sht(prd - tar) + + # compute |coeffs|^p (orthonormal convention) + coeffsp = torch.abs(coeffs) ** self.p if wgt is not None: - coeffssq = coeffssq * wgt + coeffsp = coeffsp * wgt + # sum over m: m=0 contributes once, m!=0 contribute twice (due to conjugate symmetry) if comm.get_rank("w") == 0: - norm2 = coeffssq[..., 0] + 2 * torch.sum(coeffssq[..., 1:], dim=-1) + normp = coeffsp[..., 0] + 2 * torch.sum(coeffsp[..., 1:], dim=-1) else: - norm2 = 2 * torch.sum(coeffssq, dim=-1) + normp = 2 * torch.sum(coeffsp, dim=-1) + if self.spatial_distributed and (comm.get_size("w") > 1): - norm2 = reduce_from_parallel_region(norm2, "w") + normp = reduce_from_parallel_region(normp, "w") - # compute norms - norm2 = norm2.reshape(B, C, -1) - norm2 = torch.sum(norm2, dim=-1) + # sum over l (degrees) + normp = normp.reshape(B, C, -1) + normp = torch.sum(normp, dim=-1) if self.spatial_distributed and (comm.get_size("h") > 
1): - norm2 = reduce_from_parallel_region(norm2, "h") + normp = reduce_from_parallel_region(normp, "h") + # take p-th root unless squared is True if not self.squared: - norm2 = torch.sqrt(norm2) + normp = normp ** (1.0 / self.p) - return norm2 + return normp def rel(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None): B, C, H, W = prd.shape - coeffssq = torch.square(torch.abs(self.sht(prd - tar))) / torch.pi / 4.0 + # compute SH coefficients of the difference + coeffs = self.sht(prd - tar) + coeffsp = torch.abs(coeffs) ** self.p if wgt is not None: - coeffssq = coeffssq * wgt + coeffsp = coeffsp * wgt - # sum m != 0 coeffs: + # sum m != 0 coeffs for numerator if comm.get_rank("w") == 0: - norm2 = coeffssq[..., 0] + 2 * torch.sum(coeffssq[..., 1:], dim=-1) + normp = coeffsp[..., 0] + 2 * torch.sum(coeffsp[..., 1:], dim=-1) else: - norm2 = 2 * torch.sum(coeffssq, dim=-1) + normp = 2 * torch.sum(coeffsp, dim=-1) + if self.spatial_distributed and (comm.get_size("w") > 1): - norm2 = reduce_from_parallel_region(norm2, "w") + normp = reduce_from_parallel_region(normp, "w") + + # sum over l + normp = normp.reshape(B, C, -1) + normp = torch.sum(normp, dim=-1) - # compute norms - norm2 = norm2.reshape(B, C, -1) - norm2 = torch.sum(norm2, dim=-1) if self.spatial_distributed and (comm.get_size("h") > 1): - norm2 = reduce_from_parallel_region(norm2, "h") + normp = reduce_from_parallel_region(normp, "h") - # target - tar_coeffssq = torch.square(torch.abs(self.sht(tar))) / torch.pi / 4.0 + # compute target norm + tar_coeffs = self.sht(tar) + tar_coeffsp = torch.abs(tar_coeffs) ** self.p if wgt is not None: - tar_coeffssq = tar_coeffssq * wgt + tar_coeffsp = tar_coeffsp * wgt - # sum m != 0 coeffs: + # sum m != 0 coeffs for denominator if comm.get_rank("w") == 0: - tar_norm2 = tar_coeffssq[..., 0] + 2 * torch.sum(tar_coeffssq[..., 1:], dim=-1) + tar_normp = tar_coeffsp[..., 0] + 2 * torch.sum(tar_coeffsp[..., 1:], dim=-1) else: - tar_norm2 = 2 * torch.sum(tar_coeffssq, dim=-1) + tar_normp = 2 * torch.sum(tar_coeffsp, dim=-1) + if self.spatial_distributed and (comm.get_size("w") > 1): - tar_norm2 = reduce_from_parallel_region(tar_norm2, "w") + tar_normp = reduce_from_parallel_region(tar_normp, "w") + + # sum over l + tar_normp = tar_normp.reshape(B, C, -1) + tar_normp = torch.sum(tar_normp, dim=-1) - # compute target norms - tar_norm2 = tar_norm2.reshape(B, C, -1) - tar_norm2 = torch.sum(tar_norm2, dim=-1) if self.spatial_distributed and (comm.get_size("h") > 1): - tar_norm2 = reduce_from_parallel_region(tar_norm2, "h") + tar_normp = reduce_from_parallel_region(tar_normp, "h") + # take p-th root unless squared is True if not self.squared: - diff_norms = torch.sqrt(norm2) - tar_norms = torch.sqrt(tar_norm2) + diff_norms = normp ** (1.0 / self.p) + tar_norms = tar_normp ** (1.0 / self.p) else: - diff_norms = norm2 - tar_norms = tar_norm2 + diff_norms = normp + tar_norms = tar_normp - # setup return value + # compute relative error retval = diff_norms / tar_norms return retval @@ -231,3 +250,4 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens loss = self.abs(prd, tar, wgt) return loss + diff --git a/makani/utils/training/ensemble_trainer.py b/makani/utils/training/ensemble_trainer.py index 2f6a702..3b60440 100644 --- a/makani/utils/training/ensemble_trainer.py +++ b/makani/utils/training/ensemble_trainer.py @@ -248,8 +248,7 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = # visualization wrapper: with Timer() as 
timer: - plot_channel = "z500" - # plot_channel = "q50" + plot_channel = "sst" # plot_index = self.params.channel_names.index(plot_channel) plot_index = 0 plot_list = [{"name": plot_channel, "functor": f"lambda x: x[{plot_index}, ...]", "diverging": False}] From 2e79a17593b1566f3eafd3ca63dc8aa505772542 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Tue, 2 Dec 2025 00:44:28 -0800 Subject: [PATCH 22/32] adding energy score loss --- makani/utils/losses/energy_score.py | 177 ++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 makani/utils/losses/energy_score.py diff --git a/makani/utils/losses/energy_score.py b/makani/utils/losses/energy_score.py new file mode 100644 index 0000000..018d9e3 --- /dev/null +++ b/makani/utils/losses/energy_score.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, List + +import math +import torch +import torch.nn as nn +from torch import amp + +from makani.utils.losses.base_loss import GeometricBaseLoss, LossType +from makani.utils import comm + +import torch_harmonics as th +import torch_harmonics.distributed as thd + +# distributed stuff +from physicsnemo.distributed.utils import compute_split_shapes, split_tensor_along_dim +from physicsnemo.distributed.mappings import scatter_to_parallel_region, reduce_from_parallel_region +from makani.mpu.mappings import distributed_transpose + + +class EnergyScoreLoss(GeometricBaseLoss): + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + pole_mask: int, + spatial_distributed: Optional[bool] = False, + ensemble_distributed: Optional[bool] = False, + ensemble_weights: Optional[torch.Tensor] = None, + alpha: Optional[float] = 1.0, + beta: Optional[float] = 2.0, + eps: Optional[float] = 1.0e-5, + **kwargs, + ): + + super().__init__( + img_shape=img_shape, + crop_shape=crop_shape, + crop_offset=crop_offset, + channel_names=channel_names, + grid_type=grid_type, + pole_mask=pole_mask, + spatial_distributed=spatial_distributed, + ) + + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + self.alpha = alpha + self.beta = beta + self.eps = eps + + # we also need a variant of the weights split in ensemble direction: + quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) + if self.ensemble_distributed: + quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")] + quad_weight_split = quad_weight_split.contiguous() + self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + + if ensemble_weights is not None: + 
self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) + else: + self.ensemble_weights = ensemble_weights + + @property + def type(self): + return LossType.Probabilistic + + @property + def n_channels(self): + return 1 + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor: + return torch.ones(1) + + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + + # sanity checks + if forecasts.dim() != 5: + raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + + # we assume that spatial_weights have NO ensemble dim + if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): + spdim = spatial_weights.dim() + odim = observations.dim() + raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).") + + # we assume the following shapes: + # forecasts: batch, ensemble, channels, lat, lon + # observations: batch, channels, lat, lon + B, E, C, H, W = forecasts.shape + + # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. + # ideally we split spatial dims + forecasts = torch.moveaxis(forecasts, 1, 0) + forecasts = forecasts.reshape(E, B, C, H * W) + if self.ensemble_distributed: + ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] + forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") + + # observations does not need a transpose, but just a split + observations = observations.reshape(1, B, C, H * W) + if self.ensemble_distributed: + observations = scatter_to_parallel_region(observations, -1, "ensemble") + + # for correct spatial reduction we need to do the same with spatial weights + if spatial_weights is not None: + spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1) + spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + + if self.ensemble_weights is not None: + raise NotImplementedError("currently only constant ensemble weights are supported") + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # ensemble size + num_ensemble = forecasts.shape[0] + + # get nanmask from the observarions + nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(ensemble_weights)) + + # use broadcasting semantics to compute spread and skill and sum over channels (vector norm) + espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().pow(self.beta) + eskill = (observations - forecasts).abs().pow(self.beta) + + # perform masking before any reduction + espread = torch.where(nanmasks.sum(dim=0) != 0, 0.0, espread) + eskill = torch.where(nanmasks.sum(dim=0) != 0, 0.0, eskill) + + # do the spatial reduction + if spatial_weights is not None: + espread = torch.sum(espread * self.quad_weight_split * spatial_weights_split, dim=-1) + eskill = torch.sum(eskill * self.quad_weight_split * spatial_weights_split, dim=-1) + else: + espread = torch.sum(espread * self.quad_weight_split, dim=-1) + eskill = torch.sum(eskill * self.quad_weight_split, dim=-1) + + # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well + if self.ensemble_distributed: + espread = reduce_from_parallel_region(espread, "ensemble") + eskill = 
reduce_from_parallel_region(eskill, "ensemble") + + # we need to do the spatial averaging manually since + # we are not calling the quadrature forward function + if self.spatial_distributed: + espread = reduce_from_parallel_region(espread, "spatial") + eskill = reduce_from_parallel_region(eskill, "spatial") + + # do the channel reduction while ignoring NaNs + # if channel weights are required they should be added here to the reduction + espread = espread.sum(dim=-1, keepdim=True) + eskill = eskill.sum(dim=-1, keepdim=True) + + # now we have reduced everything and need to sum appropriately + espread = espread.sum(dim=(0,1)).pow(1/self.beta) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) + eskill = eskill.sum(dim=0).pow(1/self.beta) / float(num_ensemble) + + # the resulting tensor should have dimension B, C which is what we return + return eskill - 0.5 * espread \ No newline at end of file From 08b8f2eccfc092bce53636c69d1a12541ef974de Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Tue, 2 Dec 2025 00:56:39 -0800 Subject: [PATCH 23/32] removing the random CRPS --- makani/utils/loss.py | 3 --- makani/utils/losses/__init__.py | 3 +-- makani/utils/losses/energy_score.py | 2 +- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/makani/utils/loss.py b/makani/utils/loss.py index 8ab1afb..aaf3358 100644 --- a/makani/utils/loss.py +++ b/makani/utils/loss.py @@ -35,7 +35,6 @@ from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleGradientCRPSLoss, EnsembleVortDivCRPSLoss, EnergyScoreLoss from .losses import EnsembleNLLLoss, EnsembleMMDLoss from .losses import DriftRegularization, HydrostaticBalanceLoss -from .losses import RandomizedKernelCRPS class LossHandler(nn.Module): @@ -245,8 +244,6 @@ def _parse_loss_type(self, loss_type: str): loss_handle = partial(EnergyScoreLoss) elif "drift_regularization" in loss_type: loss_handle = DriftRegularization - elif "randomized_kernel" in loss_type: - loss_handle = RandomizedKernelCRPS else: raise NotImplementedError(f"Unknown loss function: {loss_type}") diff --git a/makani/utils/losses/__init__.py b/makani/utils/losses/__init__.py index c13c236..f6c9f10 100644 --- a/makani/utils/losses/__init__.py +++ b/makani/utils/losses/__init__.py @@ -19,8 +19,7 @@ from .amse_loss import SpectralAMSELoss from .hydrostatic_loss import HydrostaticBalanceLoss from .crps_loss import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleGradientCRPSLoss, EnsembleVortDivCRPSLoss -from .energy_score import EnergyScoreLoss +from .energy_score import LpEnergyScoreLoss from .mmd_loss import EnsembleMMDLoss from .likelihood_loss import EnsembleNLLLoss from .drift_regularization import DriftRegularization -from .randomized_crps import RandomizedKernelCRPS diff --git a/makani/utils/losses/energy_score.py b/makani/utils/losses/energy_score.py index 018d9e3..2a4cfa3 100644 --- a/makani/utils/losses/energy_score.py +++ b/makani/utils/losses/energy_score.py @@ -32,7 +32,7 @@ from makani.mpu.mappings import distributed_transpose -class EnergyScoreLoss(GeometricBaseLoss): +class LpEnergyScoreLoss(GeometricBaseLoss): def __init__( self, From 7bca7d1dd19afe8c88825c4e8a209f843a36b0fb Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Tue, 2 Dec 2025 01:02:07 -0800 Subject: [PATCH 24/32] fixing another bug --- makani/utils/loss.py | 2 +- makani/utils/losses/energy_score.py | 144 ++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 1 deletion(-) diff --git a/makani/utils/loss.py b/makani/utils/loss.py index 
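To see the estimator implemented by this new loss in isolation, here is a single-process sketch with toy shapes, uniform quadrature weights and no channel reduction; it uses the pow-before-ensemble-sum ordering that a later patch in this series settles on:

    import torch

    B, E, C, N = 2, 4, 3, 10            # batch, ensemble, channels, grid points
    alpha, beta = 1.0, 2.0
    fcst = torch.randn(E, B, C, N)      # ensemble-first layout used by the loss
    obs = torch.randn(1, B, C, N)
    quad_w = torch.full((1, 1, N), 1.0 / N)       # toy quadrature weights

    # pairwise spread and skill terms, beta-powered pointwise
    espread = (fcst.unsqueeze(1) - fcst.unsqueeze(0)).abs().pow(beta)   # (E, E, B, C, N)
    eskill = (obs - fcst).abs().pow(beta)                               # (E, B, C, N)

    # quadrature-weighted spatial integral, then the beta-th root gives the norm
    espread = torch.sum(espread * quad_w, dim=-1).pow(1.0 / beta)
    eskill = torch.sum(eskill * quad_w, dim=-1).pow(1.0 / beta)

    # ensemble averages with the alpha-corrected spread normalization
    espread = espread.sum(dim=(0, 1)) * (E - 1.0 + alpha) / (E * E * (E - 1))
    eskill = eskill.sum(dim=0) / E
    score = eskill - 0.5 * espread      # shape (B, C)

For alpha = 1 the spread weight reduces to 1/(M(M-1)), i.e. the fair (unbiased) form of the energy score.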
aaf3358..51551d5 100644 --- a/makani/utils/loss.py +++ b/makani/utils/loss.py @@ -32,7 +32,7 @@ from torch_harmonics.quadrature import clenshaw_curtiss_weights, legendre_gauss_weights from .losses import LossType, GeometricLpLoss, SpectralLpLoss, SpectralH1Loss, SpectralAMSELoss -from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleGradientCRPSLoss, EnsembleVortDivCRPSLoss, EnergyScoreLoss +from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleGradientCRPSLoss, EnsembleVortDivCRPSLoss, LpEnergyScoreLoss from .losses import EnsembleNLLLoss, EnsembleMMDLoss from .losses import DriftRegularization, HydrostaticBalanceLoss diff --git a/makani/utils/losses/energy_score.py b/makani/utils/losses/energy_score.py index 2a4cfa3..b642ee0 100644 --- a/makani/utils/losses/energy_score.py +++ b/makani/utils/losses/energy_score.py @@ -91,6 +91,150 @@ def n_channels(self): def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor: return torch.ones(1) + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + + # sanity checks + if forecasts.dim() != 5: + raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + + # we assume that spatial_weights have NO ensemble dim + if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): + spdim = spatial_weights.dim() + odim = observations.dim() + raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).") + + # we assume the following shapes: + # forecasts: batch, ensemble, channels, lat, lon + # observations: batch, channels, lat, lon + B, E, C, H, W = forecasts.shape + + # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. 
+ # ideally we split spatial dims
+ forecasts = torch.moveaxis(forecasts, 1, 0)
+ forecasts = forecasts.reshape(E, B, C, H * W)
+ if self.ensemble_distributed:
+ ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))]
+ forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble")
+
+ # observations does not need a transpose, but just a split
+ observations = observations.reshape(1, B, C, H * W)
+ if self.ensemble_distributed:
+ observations = scatter_to_parallel_region(observations, -1, "ensemble")
+
+ # for correct spatial reduction we need to do the same with spatial weights
+ if spatial_weights is not None:
+ spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1)
+ spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble")
+
+ if self.ensemble_weights is not None:
+ raise NotImplementedError("currently only constant ensemble weights are supported")
+ else:
+ ensemble_weights = torch.ones_like(forecasts, device=forecasts.device)
+
+ # ensemble size
+ num_ensemble = forecasts.shape[0]
+
+ # get nanmask from the observations
+ nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(ensemble_weights))
+
+ # use broadcasting semantics to compute spread and skill and sum over channels (vector norm)
+ espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().pow(self.beta)
+ eskill = (observations - forecasts).abs().pow(self.beta)
+
+ # perform masking before any reduction
+ espread = torch.where(nanmasks.sum(dim=0) != 0, 0.0, espread)
+ eskill = torch.where(nanmasks.sum(dim=0) != 0, 0.0, eskill)
+
+ # do the spatial reduction
+ if spatial_weights is not None:
+ espread = torch.sum(espread * self.quad_weight_split * spatial_weights_split, dim=-1)
+ eskill = torch.sum(eskill * self.quad_weight_split * spatial_weights_split, dim=-1)
+ else:
+ espread = torch.sum(espread * self.quad_weight_split, dim=-1)
+ eskill = torch.sum(eskill * self.quad_weight_split, dim=-1)
+
+ # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well
+ if self.ensemble_distributed:
+ espread = reduce_from_parallel_region(espread, "ensemble")
+ eskill = reduce_from_parallel_region(eskill, "ensemble")
+
+ # we need to do the spatial averaging manually since
+ # we are not calling the quadrature forward function
+ if self.spatial_distributed:
+ espread = reduce_from_parallel_region(espread, "spatial")
+ eskill = reduce_from_parallel_region(eskill, "spatial")
+
+ # do the channel reduction while ignoring NaNs
+ # if channel weights are required they should be added here to the reduction
+ espread = espread.sum(dim=-1, keepdim=True)
+ eskill = eskill.sum(dim=-1, keepdim=True)
+
+ # now we have reduced everything and need to sum appropriately
+ espread = espread.sum(dim=(0,1)).pow(1/self.beta) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1))
+ eskill = eskill.sum(dim=0).pow(1/self.beta) / float(num_ensemble)
+
+ # the resulting tensor should have dimension B, C which is what we return
+ return eskill - 0.5 * espread
+
+class H1EnergyScoreLoss(GeometricBaseLoss):
+
+ def __init__(
+ self,
+ img_shape: Tuple[int, int],
+ crop_shape: Tuple[int, int],
+ crop_offset: Tuple[int, int],
+ channel_names: List[str],
+ grid_type: str,
+ pole_mask: int,
+ spatial_distributed: Optional[bool] = False,
+ ensemble_distributed: Optional[bool] = False,
+ ensemble_weights: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = 1.0,
+ beta: 
Optional[float] = 2.0, + eps: Optional[float] = 1.0e-5, + **kwargs, + ): + + super().__init__( + img_shape=img_shape, + crop_shape=crop_shape, + crop_offset=crop_offset, + channel_names=channel_names, + grid_type=grid_type, + pole_mask=pole_mask, + spatial_distributed=spatial_distributed, + ) + + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + self.alpha = alpha + self.beta = beta + self.eps = eps + + # we also need a variant of the weights split in ensemble direction: + quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) + if self.ensemble_distributed: + quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")] + quad_weight_split = quad_weight_split.contiguous() + self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + + if ensemble_weights is not None: + self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) + else: + self.ensemble_weights = ensemble_weights + + @property + def type(self): + return LossType.Probabilistic + + @property + def n_channels(self): + return 1 + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor: + return torch.ones(1) + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: # sanity checks From 7f5a59a98752e792eaa5785a3bfc6441dfcbc103 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Tue, 2 Dec 2025 01:20:45 -0800 Subject: [PATCH 25/32] more fixes for energy score loss --- makani/utils/loss.py | 4 ++-- makani/utils/losses/energy_score.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/makani/utils/loss.py b/makani/utils/loss.py index 51551d5..4698799 100644 --- a/makani/utils/loss.py +++ b/makani/utils/loss.py @@ -240,8 +240,8 @@ def _parse_loss_type(self, loss_type: str): loss_handle = EnsembleNLLLoss elif "ensemble_mmd" in loss_type: loss_handle = EnsembleMMDLoss - elif "energy_score" in loss_type: - loss_handle = partial(EnergyScoreLoss) + elif "energy score" in loss_type: + loss_handle = partial(LpEnergyScoreLoss) elif "drift_regularization" in loss_type: loss_handle = DriftRegularization else: diff --git a/makani/utils/losses/energy_score.py b/makani/utils/losses/energy_score.py index b642ee0..a939cdf 100644 --- a/makani/utils/losses/energy_score.py +++ b/makani/utils/losses/energy_score.py @@ -186,9 +186,12 @@ def __init__( channel_names: List[str], grid_type: str, pole_mask: int, + lmax: Optional[int] = None, + crps_type: str = "skillspread", spatial_distributed: Optional[bool] = False, ensemble_distributed: Optional[bool] = False, ensemble_weights: Optional[torch.Tensor] = None, + absolute: Optional[bool] = True, alpha: Optional[float] = 1.0, beta: Optional[float] = 2.0, eps: Optional[float] = 1.0e-5, From e29a588863b6bb18e5e54c4a7fcb121b45c61adb Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Tue, 2 Dec 2025 02:15:24 -0800 Subject: [PATCH 26/32] another fix --- makani/utils/loss.py | 5 +++- makani/utils/losses/__init__.py | 2 +- makani/utils/losses/energy_score.py | 36 +++++++++++++++++++++++++---- 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/makani/utils/loss.py b/makani/utils/loss.py index 4698799..033b6e1 100644 --- a/makani/utils/loss.py +++ 
b/makani/utils/loss.py @@ -32,7 +32,8 @@ from torch_harmonics.quadrature import clenshaw_curtiss_weights, legendre_gauss_weights from .losses import LossType, GeometricLpLoss, SpectralLpLoss, SpectralH1Loss, SpectralAMSELoss -from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleGradientCRPSLoss, EnsembleVortDivCRPSLoss, LpEnergyScoreLoss +from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleGradientCRPSLoss, EnsembleVortDivCRPSLoss +from .losses import LpEnergyScoreLoss, H1EnergyScoreLoss from .losses import EnsembleNLLLoss, EnsembleMMDLoss from .losses import DriftRegularization, HydrostaticBalanceLoss @@ -242,6 +243,8 @@ def _parse_loss_type(self, loss_type: str): loss_handle = EnsembleMMDLoss elif "energy score" in loss_type: loss_handle = partial(LpEnergyScoreLoss) + elif "h1 energy score" in loss_type: + loss_handle = partial(H1EnergyScoreLoss) elif "drift_regularization" in loss_type: loss_handle = DriftRegularization else: diff --git a/makani/utils/losses/__init__.py b/makani/utils/losses/__init__.py index f6c9f10..3af0b58 100644 --- a/makani/utils/losses/__init__.py +++ b/makani/utils/losses/__init__.py @@ -19,7 +19,7 @@ from .amse_loss import SpectralAMSELoss from .hydrostatic_loss import HydrostaticBalanceLoss from .crps_loss import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleGradientCRPSLoss, EnsembleVortDivCRPSLoss -from .energy_score import LpEnergyScoreLoss +from .energy_score import LpEnergyScoreLoss, H1EnergyScoreLoss from .mmd_loss import EnsembleMMDLoss from .likelihood_loss import EnsembleNLLLoss from .drift_regularization import DriftRegularization diff --git a/makani/utils/losses/energy_score.py b/makani/utils/losses/energy_score.py index a939cdf..4052bc5 100644 --- a/makani/utils/losses/energy_score.py +++ b/makani/utils/losses/energy_score.py @@ -20,7 +20,7 @@ import torch.nn as nn from torch import amp -from makani.utils.losses.base_loss import GeometricBaseLoss, LossType +from makani.utils.losses.base_loss import GeometricBaseLoss, GradientBaseLoss, LossType from makani.utils import comm import torch_harmonics as th @@ -108,16 +108,40 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # observations: batch, channels, lat, lon B, E, C, H, W = forecasts.shape + # get the data type before stripping amp types + dtype = forecasts.dtype + + # before anything else compute the transform + # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same + with amp.autocast(device_type="cuda", enabled=False): + + # compute the SH coefficients of the forecasts and observations + forecasts = self.sht(forecasts.float()).unsqueeze(-3) + observations = self.sht(observations.float()).unsqueeze(-3) + + # append zeros, so that we can use the inverse vector SHT + forecasts = torch.cat([forecasts, torch.zeros_like(forecasts)], dim=-3) + observations = torch.cat([observations, torch.zeros_like(observations)], dim=-3) + + forecasts = self.ivsht(forecasts) + observations = self.ivsht(observations) + + forecasts = forecasts.to(dtype) + observations = observations.to(dtype) + + forecasts = forecasts.reshape(B, E, 2*C, H, W) + observations = observations.reshape(B, 2*C, H, W) + # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. 
# ideally we split spatial dims forecasts = torch.moveaxis(forecasts, 1, 0) - forecasts = forecasts.reshape(E, B, C, H * W) + forecasts = forecasts.reshape(E, B, 2*C, H * W) if self.ensemble_distributed: ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") # observations does not need a transpose, but just a split - observations = observations.reshape(1, B, C, H * W) + observations = observations.reshape(1, B, 2*C, H * W) if self.ensemble_distributed: observations = scatter_to_parallel_region(observations, -1, "ensemble") @@ -176,7 +200,7 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # the resulting tensor should have dimension B, C which is what we return return eskill - 0.5 * espread -class H1EnergyScoreLoss(GeometricBaseLoss): +class H1EnergyScoreLoss(GradientBaseLoss): def __init__( self, @@ -205,9 +229,13 @@ def __init__( channel_names=channel_names, grid_type=grid_type, pole_mask=pole_mask, + lmax=lmax, spatial_distributed=spatial_distributed, ) + # if absolute is true, the loss is computed only on the absolute value of the gradient + self.absolute = absolute + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed self.alpha = alpha From 207fa6dbeed34a0176ec7934a7d7c9903e6e43cc Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Tue, 2 Dec 2025 02:48:44 -0800 Subject: [PATCH 27/32] fixing another bug --- makani/utils/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/makani/utils/loss.py b/makani/utils/loss.py index 033b6e1..34cf5f8 100644 --- a/makani/utils/loss.py +++ b/makani/utils/loss.py @@ -241,9 +241,9 @@ def _parse_loss_type(self, loss_type: str): loss_handle = EnsembleNLLLoss elif "ensemble_mmd" in loss_type: loss_handle = EnsembleMMDLoss - elif "energy score" in loss_type: + elif "energy_score" in loss_type: loss_handle = partial(LpEnergyScoreLoss) - elif "h1 energy score" in loss_type: + elif "h1_energy_score" in loss_type: loss_handle = partial(H1EnergyScoreLoss) elif "drift_regularization" in loss_type: loss_handle = DriftRegularization From af59f917ca7fe30f2e34c1f97135d6a4b073d0db Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Tue, 2 Dec 2025 03:34:39 -0800 Subject: [PATCH 28/32] fixing energy score --- makani/utils/losses/energy_score.py | 55 +++++++++++++++-------------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/makani/utils/losses/energy_score.py b/makani/utils/losses/energy_score.py index 4052bc5..2c849bc 100644 --- a/makani/utils/losses/energy_score.py +++ b/makani/utils/losses/energy_score.py @@ -111,37 +111,16 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # get the data type before stripping amp types dtype = forecasts.dtype - # before anything else compute the transform - # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same - with amp.autocast(device_type="cuda", enabled=False): - - # compute the SH coefficients of the forecasts and observations - forecasts = self.sht(forecasts.float()).unsqueeze(-3) - observations = self.sht(observations.float()).unsqueeze(-3) - - # append zeros, so that we can use the inverse vector SHT - forecasts = torch.cat([forecasts, torch.zeros_like(forecasts)], dim=-3) - 
observations = torch.cat([observations, torch.zeros_like(observations)], dim=-3) - - forecasts = self.ivsht(forecasts) - observations = self.ivsht(observations) - - forecasts = forecasts.to(dtype) - observations = observations.to(dtype) - - forecasts = forecasts.reshape(B, E, 2*C, H, W) - observations = observations.reshape(B, 2*C, H, W) - # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. # ideally we split spatial dims forecasts = torch.moveaxis(forecasts, 1, 0) - forecasts = forecasts.reshape(E, B, 2*C, H * W) + forecasts = forecasts.reshape(E, B, C, H * W) if self.ensemble_distributed: ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") # observations does not need a transpose, but just a split - observations = observations.reshape(1, B, 2*C, H * W) + observations = observations.reshape(1, B, C, H * W) if self.ensemble_distributed: observations = scatter_to_parallel_region(observations, -1, "ensemble") @@ -283,16 +262,40 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # observations: batch, channels, lat, lon B, E, C, H, W = forecasts.shape + # get the data type before stripping amp types + dtype = forecasts.dtype + + # before anything else compute the transform + # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same + with amp.autocast(device_type="cuda", enabled=False): + + # compute the SH coefficients of the forecasts and observations + forecasts = self.sht(forecasts.float()).unsqueeze(-3) + observations = self.sht(observations.float()).unsqueeze(-3) + + # append zeros, so that we can use the inverse vector SHT + forecasts = torch.cat([forecasts, torch.zeros_like(forecasts)], dim=-3) + observations = torch.cat([observations, torch.zeros_like(observations)], dim=-3) + + forecasts = self.ivsht(forecasts) + observations = self.ivsht(observations) + + forecasts = forecasts.to(dtype) + observations = observations.to(dtype) + + forecasts = forecasts.reshape(B, E, 2*C, H, W) + observations = observations.reshape(B, 2*C, H, W) + # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. 
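The zero-padding construction above is a compact way to obtain gradient information: the scalar coefficients are paired with a zero toroidal component and pushed through the inverse vector SHT, which returns a two-component tangential field that scales like the spherical gradient of the input (up to the sqrt(l(l+1)) normalization of the vector harmonics). A minimal sketch, assuming torch_harmonics' transform pair:

    import torch
    import torch_harmonics as th

    nlat, nlon = 32, 64
    sht = th.RealSHT(nlat, nlon, grid="equiangular")
    ivsht = th.InverseRealVectorSHT(nlat, nlon, grid="equiangular")

    x = torch.randn(2, 3, nlat, nlon)          # batch, channels, lat, lon
    coeffs = sht(x).unsqueeze(-3)              # (2, 3, 1, lmax, mmax), complex

    # pair the scalar (spheroidal) part with a zero toroidal component
    coeffs = torch.cat([coeffs, torch.zeros_like(coeffs)], dim=-3)

    grad = ivsht(coeffs)                       # (2, 3, 2, nlat, nlon)
    grad = grad.reshape(2, 2 * 3, nlat, nlon)  # fold the two components into channels

Doubling the channel dimension to 2C then lets the unchanged pointwise skill/spread machinery act on gradient components, which is what makes this variant a Sobolev (H1-type) energy score.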
# ideally we split spatial dims forecasts = torch.moveaxis(forecasts, 1, 0) - forecasts = forecasts.reshape(E, B, C, H * W) + forecasts = forecasts.reshape(E, B, 2*C, H * W) if self.ensemble_distributed: ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") # observations does not need a transpose, but just a split - observations = observations.reshape(1, B, C, H * W) + observations = observations.reshape(1, B, 2*C, H * W) if self.ensemble_distributed: observations = scatter_to_parallel_region(observations, -1, "ensemble") @@ -348,5 +351,5 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w espread = espread.sum(dim=(0,1)).pow(1/self.beta) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) eskill = eskill.sum(dim=0).pow(1/self.beta) / float(num_ensemble) - # the resulting tensor should have dimension B, C which is what we return + # the resulting tensor should have dimension B, 1 which is what we return return eskill - 0.5 * espread \ No newline at end of file From ef4efc4c42a1c375e3a7d2fcef9938f67d462519 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Wed, 10 Dec 2025 15:32:38 -0800 Subject: [PATCH 29/32] cleaning up some losses --- makani/utils/loss.py | 17 +- makani/utils/losses/__init__.py | 4 +- makani/utils/losses/crps_loss.py | 8 +- makani/utils/losses/energy_score.py | 8 +- makani/utils/losses/mmd_loss.py | 330 +++++++++++++----- makani/utils/metrics/functions.py | 14 +- tests/distributed/tests_distributed_losses.py | 20 +- tests/test_losses.py | 66 ++-- 8 files changed, 319 insertions(+), 148 deletions(-) diff --git a/makani/utils/loss.py b/makani/utils/loss.py index 34cf5f8..b0056e0 100644 --- a/makani/utils/loss.py +++ b/makani/utils/loss.py @@ -32,9 +32,10 @@ from torch_harmonics.quadrature import clenshaw_curtiss_weights, legendre_gauss_weights from .losses import LossType, GeometricLpLoss, SpectralLpLoss, SpectralH1Loss, SpectralAMSELoss -from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleGradientCRPSLoss, EnsembleVortDivCRPSLoss +from .losses import CRPSLoss, SpectralCRPSLoss, GradientCRPSLoss, VortDivCRPSLoss from .losses import LpEnergyScoreLoss, H1EnergyScoreLoss -from .losses import EnsembleNLLLoss, EnsembleMMDLoss +from .losses import GaussianMMDLoss +from .losses import EnsembleNLLLoss from .losses import DriftRegularization, HydrostaticBalanceLoss @@ -230,17 +231,17 @@ def _parse_loss_type(self, loss_type: str): p_max = int(x.replace("p_max=", "")) loss_handle = partial(HydrostaticBalanceLoss, p_min=p_min, p_max=p_max, use_moist_air_formula=use_moist_air_formula) elif "ensemble_crps" in loss_type: - loss_handle = partial(EnsembleCRPSLoss) + loss_handle = partial(CRPSLoss) elif "ensemble_spectral_crps" in loss_type: - loss_handle = partial(EnsembleSpectralCRPSLoss) + loss_handle = partial(SpectralCRPSLoss) elif "ensemble_vort_div_crps" in loss_type: - loss_handle = partial(EnsembleVortDivCRPSLoss) + loss_handle = partial(VortDivCRPSLoss) elif "ensemble_gradient_crps" in loss_type: - loss_handle = partial(EnsembleGradientCRPSLoss) + loss_handle = partial(GradientCRPSLoss) elif "ensemble_nll" in loss_type: loss_handle = EnsembleNLLLoss - elif "ensemble_mmd" in loss_type: - loss_handle = EnsembleMMDLoss + elif "gaussian_mmd" in loss_type: + loss_handle = GaussianMMDLoss elif "energy_score" in loss_type: loss_handle = partial(LpEnergyScoreLoss) elif 
"h1_energy_score" in loss_type: diff --git a/makani/utils/losses/__init__.py b/makani/utils/losses/__init__.py index 3af0b58..cb11cde 100644 --- a/makani/utils/losses/__init__.py +++ b/makani/utils/losses/__init__.py @@ -18,8 +18,8 @@ from .lp_loss import GeometricLpLoss, SpectralLpLoss from .amse_loss import SpectralAMSELoss from .hydrostatic_loss import HydrostaticBalanceLoss -from .crps_loss import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleGradientCRPSLoss, EnsembleVortDivCRPSLoss +from .crps_loss import CRPSLoss, SpectralCRPSLoss, GradientCRPSLoss, VortDivCRPSLoss from .energy_score import LpEnergyScoreLoss, H1EnergyScoreLoss -from .mmd_loss import EnsembleMMDLoss +from .mmd_loss import GaussianMMDLoss from .likelihood_loss import EnsembleNLLLoss from .drift_regularization import DriftRegularization diff --git a/makani/utils/losses/crps_loss.py b/makani/utils/losses/crps_loss.py index 3657aa4..e71470e 100644 --- a/makani/utils/losses/crps_loss.py +++ b/makani/utils/losses/crps_loss.py @@ -268,7 +268,7 @@ def _crps_gauss_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weigh return crps -class EnsembleCRPSLoss(GeometricBaseLoss): +class CRPSLoss(GeometricBaseLoss): def __init__( self, @@ -422,7 +422,7 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w return crps -class EnsembleSpectralCRPSLoss(SpectralBaseLoss): +class SpectralCRPSLoss(SpectralBaseLoss): def __init__( self, @@ -608,7 +608,7 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spectral_ # the resulting tensor should have dimension B, C, which is what we return return crps -class EnsembleGradientCRPSLoss(GradientBaseLoss): +class GradientCRPSLoss(GradientBaseLoss): def __init__( self, @@ -802,7 +802,7 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # the resulting tensor should have dimension B, C, which is what we return return crps -class EnsembleVortDivCRPSLoss(VortDivBaseLoss): +class VortDivCRPSLoss(VortDivBaseLoss): def __init__( self, diff --git a/makani/utils/losses/energy_score.py b/makani/utils/losses/energy_score.py index 2c849bc..758c548 100644 --- a/makani/utils/losses/energy_score.py +++ b/makani/utils/losses/energy_score.py @@ -173,8 +173,8 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w eskill = eskill.sum(dim=-1, keepdim=True) # now we have reduced everything and need to sum appropriately - espread = espread.sum(dim=(0,1)).pow(1/self.beta) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) - eskill = eskill.sum(dim=0).pow(1/self.beta) / float(num_ensemble) + espread = espread.pow(1/self.beta).sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) + eskill = eskill.pow(1/self.beta).sum(dim=0) / float(num_ensemble) # the resulting tensor should have dimension B, C which is what we return return eskill - 0.5 * espread @@ -348,8 +348,8 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w eskill = eskill.sum(dim=-1, keepdim=True) # now we have reduced everything and need to sum appropriately - espread = espread.sum(dim=(0,1)).pow(1/self.beta) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) - eskill = eskill.sum(dim=0).pow(1/self.beta) / float(num_ensemble) + espread = espread.pow(1/self.beta).sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * 
num_ensemble * (num_ensemble - 1)) + eskill = eskill.pow(1/self.beta).sum(dim=0) / float(num_ensemble) # the resulting tensor should have dimension B, 1 which is what we return return eskill - 0.5 * espread \ No newline at end of file diff --git a/makani/utils/losses/mmd_loss.py b/makani/utils/losses/mmd_loss.py index 21d2d8d..7f0912c 100644 --- a/makani/utils/losses/mmd_loss.py +++ b/makani/utils/losses/mmd_loss.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,58 +15,24 @@ from typing import Optional, Tuple, List -import numpy as np - +import math import torch import torch.nn as nn -from torch.cuda import amp +from torch import amp -from makani.utils.losses.base_loss import GeometricBaseLoss, SpectralBaseLoss, LossType +from makani.utils.losses.base_loss import GeometricBaseLoss, GradientBaseLoss, LossType from makani.utils import comm +import torch_harmonics as th +import torch_harmonics.distributed as thd + # distributed stuff from physicsnemo.distributed.utils import compute_split_shapes, split_tensor_along_dim from physicsnemo.distributed.mappings import scatter_to_parallel_region, reduce_from_parallel_region from makani.mpu.mappings import distributed_transpose -# @torch.compile -# def _mmd_rbf_kernel(x: torch.Tensor, y: torch.Tensor): -# return torch.abs(x - y) - - -@torch.compile -def _mmd_rbf_kernel(x: torch.Tensor, y: torch.Tensor, bandwidth: float = 1.0): - return torch.exp(-0.5 * torch.square(torch.abs(x - y)) / bandwidth) - - -# Computes the squared maximum mean discrepancy -# @torch.compile -def _mmd2_ensemble_kernel(observation: torch.Tensor, forecasts: torch.Tensor) -> torch.Tensor: - - # initial values - spread_term = torch.zeros_like(observation) - disc_term = torch.zeros_like(observation) - - num_forecasts = forecasts.shape[0] - - for m in range(num_forecasts): - # get the forecast - ym = forecasts[m] - - # account for contributions on the off-diasgonal assuming that the kernel is symmetric - spread_term = spread_term + 2.0 * torch.sum(_mmd_rbf_kernel(ym, forecasts[m:]), dim=0) - - # contributions to the discrepancy term - disc_term = disc_term + _mmd_rbf_kernel(ym, observation) - - # compute the squared mmd - mmd2 = spread_term / (num_forecasts - 1) / num_forecasts - 2.0 * disc_term / num_forecasts - - return mmd2 - - -class EnsembleMMDLoss(GeometricBaseLoss): +class GaussianMMDLoss(GeometricBaseLoss): r""" Computes the maximum mean discrepancy loss for a specific kernel. 
For details see [1] @@ -80,10 +46,15 @@ def __init__( crop_offset: Tuple[int, int], channel_names: List[str], grid_type: str, - squared: Optional[bool] = False, - pole_mask: Optional[int] = 0, + pole_mask: int, spatial_distributed: Optional[bool] = False, ensemble_distributed: Optional[bool] = False, + ensemble_weights: Optional[torch.Tensor] = None, + sigma: Optional[float] = 1.0, + alpha: Optional[float] = 1.0, + beta: Optional[float] = 2.0, + eps: Optional[float] = 1.0e-5, + channel_reduction: Optional[bool] = False, **kwargs, ): @@ -97,10 +68,13 @@ def __init__( spatial_distributed=spatial_distributed, ) - self.squared = squared - self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + self.alpha = alpha + self.beta = beta + self.eps = eps + self.channel_reduction = channel_reduction + self.sigma = sigma # we also need a variant of the weights split in ensemble direction: quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) @@ -109,10 +83,26 @@ def __init__( quad_weight_split = quad_weight_split.contiguous() self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + if ensemble_weights is not None: + self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) + else: + self.ensemble_weights = ensemble_weights + @property def type(self): return LossType.Probabilistic + @property + def n_channels(self): + if self.channel_reduction: + return 1 + else: + return len(self.channel_names) + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor: + return torch.ones(1) + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: # sanity checks @@ -121,51 +111,231 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # we assume that spatial_weights have NO ensemble dim if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): - raise ValueError("the weights have to have the same number of dimensions as observations") + spdim = spatial_weights.dim() + odim = observations.dim() + raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).") # we assume the following shapes: # forecasts: batch, ensemble, channels, lat, lon # observations: batch, channels, lat, lon B, E, C, H, W = forecasts.shape - # if ensemble dim is one dimensional then computing the score is quick: - if (not self.ensemble_distributed) and (forecasts.shape[1] == 1): - # in this case, CRPS is straightforward - mmd = _mmd_rbf_kernel(observations, forecasts.squeeze(1)).reshape(B, C, H * W) + # get the data type before stripping amp types + dtype = forecasts.dtype + + # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. 
+ # ideally we split spatial dims + forecasts = torch.moveaxis(forecasts, 1, 0) + forecasts = forecasts.reshape(E, B, C, H * W) + if self.ensemble_distributed: + ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] + forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") + + # observations does not need a transpose, but just a split + observations = observations.reshape(1, B, C, H * W) + if self.ensemble_distributed: + observations = scatter_to_parallel_region(observations, -1, "ensemble") + + # for correct spatial reduction we need to do the same with spatial weights + if spatial_weights is not None: + spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1) + spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + + if self.ensemble_weights is not None: + raise NotImplementedError("currently only constant ensemble weights are supported") else: - # transpose forecasts: ensemble, batch, channels, lat, lon - forecasts = torch.moveaxis(forecasts, 1, 0) - - # now we need to transpose the forecasts into ensemble direction. - # ideally we split spatial dims - forecasts = forecasts.reshape(E, B, C, H * W) - if self.ensemble_distributed: - ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] - forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") - # observations does not need a transpose, but just a split - observations = observations.reshape(B, C, H * W) - if self.ensemble_distributed: - observations = scatter_to_parallel_region(observations, -1, "ensemble") - if spatial_weights is not None: - spatial_weights_split = spatial_weights.flatten(-2, -1) - spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") - - # now, E dimension is local and spatial dim is split further. 
Compute the mmd - mmd = _mmd2_ensemble_kernel(observations, forecasts) - - # perform spatial average of crps score + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # ensemble size + num_ensemble = forecasts.shape[0] + + # get nanmask from the observarions + nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(ensemble_weights)) + + # use broadcasting semantics to compute spread and skill and sum over channels (vector norm) + espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().pow(self.beta) + eskill = (observations - forecasts).abs().pow(self.beta) + + # perform masking before any reduction + espread = torch.where(nanmasks.sum(dim=0) != 0, 0.0, espread) + eskill = torch.where(nanmasks.sum(dim=0) != 0, 0.0, eskill) + + # do the spatial reduction if spatial_weights is not None: - mmd = torch.sum(mmd * self.quad_weight_split * spatial_weights_split, dim=-1) + espread = torch.sum(espread * self.quad_weight_split * spatial_weights_split, dim=-1) + eskill = torch.sum(eskill * self.quad_weight_split * spatial_weights_split, dim=-1) else: - mmd = torch.sum(mmd * self.quad_weight_split, dim=-1) + espread = torch.sum(espread * self.quad_weight_split, dim=-1) + eskill = torch.sum(eskill * self.quad_weight_split, dim=-1) + + # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well if self.ensemble_distributed: - mmd = reduce_from_parallel_region(mmd, "ensemble") + espread = reduce_from_parallel_region(espread, "ensemble") + eskill = reduce_from_parallel_region(eskill, "ensemble") + # we need to do the spatial averaging manually since + # we are not calling the quadrature forward function if self.spatial_distributed: - mmd = reduce_from_parallel_region(mmd, "spatial") + espread = reduce_from_parallel_region(espread, "spatial") + eskill = reduce_from_parallel_region(eskill, "spatial") + + # do the channel reduction while ignoring NaNs + # if channel weights are required they should be added here to the reduction + if self.channel_reduction: + espread = espread.sum(dim=-2, keepdim=True) + eskill = eskill.sum(dim=-2, keepdim=True) + + # apply the Gaussian kernel + espread = torch.exp(-0.5 * torch.square(espread) / self.sigma) + eskill = torch.exp(-0.5 * torch.square(eskill) / self.sigma) + + # mask out the diagonal elements in the spread term + espread = torch.where(torch.eye(num_ensemble, device=espread.device).bool().reshape(num_ensemble, num_ensemble, 1, 1), 0.0, espread) + + # now we have reduced everything and need to sum appropriately + espread = espread.sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) + eskill = eskill.sum(dim=0) / float(num_ensemble) + + # the resulting tensor should have dimension B, C which is what we return + return eskill - 0.5 * espread - if not self.squared: - mmd = torch.sqrt(mmd) +# @torch.compile +# def _mmd_rbf_kernel(x: torch.Tensor, y: torch.Tensor): +# return torch.abs(x - y) + + +# @torch.compile +# def _mmd_rbf_kernel(x: torch.Tensor, y: torch.Tensor, bandwidth: float = 1.0): +# return torch.exp(-0.5 * torch.square(torch.abs(x - y)) / bandwidth) - # the resulting tensor should have dimension B, C, which is what we return - return mmd + +# # Computes the squared maximum mean discrepancy +# # @torch.compile +# def _mmd2_ensemble_kernel(observation: torch.Tensor, forecasts: torch.Tensor) -> torch.Tensor: + +# # initial values +# spread_term = torch.zeros_like(observation) +# disc_term = torch.zeros_like(observation) + 
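For orientation, the textbook squared-MMD estimator that this Gaussian kernel construction relates to is sketched below (single process, inputs already spatially reduced; the constant k(y, y) = 1 term and the sign/weighting conventions are folded differently into the skill/spread decomposition above, so the class returns an affine transformation of this quantity):

    import torch

    M, sigma = 4, 1.0
    x = torch.randn(M, 5)              # ensemble members, reduced feature dim
    y = torch.randn(1, 5)              # single observation

    def k(a, b):
        # Gaussian kernel on squared distances
        return torch.exp(-0.5 * (a - b).square().sum(-1) / sigma)

    k_xx = k(x.unsqueeze(1), x.unsqueeze(0))       # (M, M)
    k_xy = k(x, y)                                 # (M,)

    # unbiased within-ensemble term: drop the diagonal,
    # which is what the torch.eye mask accomplishes in the class
    spread = (k_xx.sum() - k_xx.diagonal().sum()) / (M * (M - 1))
    skill = k_xy.mean()
    mmd2 = spread - 2.0 * skill + 1.0              # k(y, y) = 1 for this kernel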
+# num_forecasts = forecasts.shape[0] + +# for m in range(num_forecasts): + +# # get the forecast +# ym = forecasts[m] + +# # account for contributions on the off-diasgonal assuming that the kernel is symmetric +# spread_term = spread_term + 2.0 * torch.sum(_mmd_rbf_kernel(ym, forecasts[m:]), dim=0) + +# # contributions to the discrepancy term +# disc_term = disc_term + _mmd_rbf_kernel(ym, observation) + +# # compute the squared mmd +# mmd2 = spread_term / (num_forecasts - 1) / num_forecasts - 2.0 * disc_term / num_forecasts + +# return mmd2 + + +# class EnsembleMMDLoss(GeometricBaseLoss): +# r""" +# Computes the maximum mean discrepancy loss for a specific kernel. For details see [1] + +# [1] Dziugaite, Gintare Karolina; Roy, Daniel M.; Ghahramani, Zhoubin; Training generative neural networks via Maximum Mean Discrepancy optimization; arXiv:1505.03906 +# """ + +# def __init__( +# self, +# img_shape: Tuple[int, int], +# crop_shape: Tuple[int, int], +# crop_offset: Tuple[int, int], +# channel_names: List[str], +# grid_type: str, +# squared: Optional[bool] = False, +# pole_mask: Optional[int] = 0, +# spatial_distributed: Optional[bool] = False, +# ensemble_distributed: Optional[bool] = False, +# **kwargs, +# ): + +# super().__init__( +# img_shape=img_shape, +# crop_shape=crop_shape, +# crop_offset=crop_offset, +# channel_names=channel_names, +# grid_type=grid_type, +# pole_mask=pole_mask, +# spatial_distributed=spatial_distributed, +# ) + +# self.squared = squared + +# self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed +# self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + +# # we also need a variant of the weights split in ensemble direction: +# quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) +# if self.ensemble_distributed: +# quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")] +# quad_weight_split = quad_weight_split.contiguous() +# self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + +# @property +# def type(self): +# return LossType.Probabilistic + +# def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + +# # sanity checks +# if forecasts.dim() != 5: +# raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + +# # we assume that spatial_weights have NO ensemble dim +# if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): +# raise ValueError("the weights have to have the same number of dimensions as observations") + +# # we assume the following shapes: +# # forecasts: batch, ensemble, channels, lat, lon +# # observations: batch, channels, lat, lon +# B, E, C, H, W = forecasts.shape + +# # if ensemble dim is one dimensional then computing the score is quick: +# if (not self.ensemble_distributed) and (forecasts.shape[1] == 1): +# # in this case, CRPS is straightforward +# mmd = _mmd_rbf_kernel(observations, forecasts.squeeze(1)).reshape(B, C, H * W) +# else: +# # transpose forecasts: ensemble, batch, channels, lat, lon +# forecasts = torch.moveaxis(forecasts, 1, 0) + +# # now we need to transpose the forecasts into ensemble direction. 
+# # ideally we split spatial dims +# forecasts = forecasts.reshape(E, B, C, H * W) +# if self.ensemble_distributed: +# ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] +# forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") +# # observations does not need a transpose, but just a split +# observations = observations.reshape(B, C, H * W) +# if self.ensemble_distributed: +# observations = scatter_to_parallel_region(observations, -1, "ensemble") +# if spatial_weights is not None: +# spatial_weights_split = spatial_weights.flatten(-2, -1) +# spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + +# # now, E dimension is local and spatial dim is split further. Compute the mmd +# mmd = _mmd2_ensemble_kernel(observations, forecasts) + +# # perform spatial average of crps score +# if spatial_weights is not None: +# mmd = torch.sum(mmd * self.quad_weight_split * spatial_weights_split, dim=-1) +# else: +# mmd = torch.sum(mmd * self.quad_weight_split, dim=-1) +# if self.ensemble_distributed: +# mmd = reduce_from_parallel_region(mmd, "ensemble") + +# if self.spatial_distributed: +# mmd = reduce_from_parallel_region(mmd, "spatial") + +# if not self.squared: +# mmd = torch.sqrt(mmd) + +# # the resulting tensor should have dimension B, C, which is what we return +# return mmd diff --git a/makani/utils/metrics/functions.py b/makani/utils/metrics/functions.py index ea4ae51..d99038c 100644 --- a/makani/utils/metrics/functions.py +++ b/makani/utils/metrics/functions.py @@ -22,7 +22,7 @@ from physicsnemo.distributed.utils import split_tensor_along_dim from makani.mpu.mappings import distributed_transpose -from makani.utils.losses import EnsembleCRPSLoss, LossType +from makani.utils.losses import CRPSLoss, LossType from makani.utils.metrics.base_metric import _sanitize_shapes, _welford_reduction_helper, GeometricBaseMetric class GeometricL1(GeometricBaseMetric): @@ -197,7 +197,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tenso # stack along dim -1: # we form the ratio in the finalization step acc = torch.stack([cov_xy, var_x, var_y], dim=-1) - + # reduce if self.channel_reduction == "mean": acc = torch.mean(acc, dim=1) @@ -252,12 +252,12 @@ def __init__( def combine(self, vals, counts, dim=0): # sanitize shapes vals, counts = _sanitize_shapes(vals, counts, dim=dim) - + # extract parameters covs = vals[..., 0].unsqueeze(-1) m2s = vals[..., 1:3] means = vals[..., 3:5] - + # counts are: n = sum_k n_k counts_agg = torch.sum(counts, dim=0) # means are: mu = sum_i n_i * mu_i / n @@ -280,7 +280,7 @@ def finalize(self, vals, counts): return vals[..., 0] / torch.sqrt(vals[..., 1] * vals[..., 2]) def forward(self, x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor: - + if hasattr(self, "bias"): x = x - self.bias y = y - self.bias @@ -349,7 +349,7 @@ def __init__( @property def type(self): return LossType.Probabilistic - + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor: # sanity checks @@ -498,7 +498,7 @@ def __init__( ): super().__init__() - self.metric_func = EnsembleCRPSLoss( + self.metric_func = CRPSLoss( img_shape=img_shape, crop_shape=crop_shape, crop_offset=crop_offset, diff --git a/tests/distributed/tests_distributed_losses.py b/tests/distributed/tests_distributed_losses.py index ac9a8aa..ba69485 100644 --- a/tests/distributed/tests_distributed_losses.py +++ 
b/tests/distributed/tests_distributed_losses.py @@ -31,7 +31,7 @@ from physicsnemo.distributed.mappings import gather_from_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region from makani.utils.grids import GridQuadrature -from makani.utils.losses import EnsembleCRPSLoss, EnsembleNLLLoss +from makani.utils.losses import CRPSLoss, EnsembleNLLLoss from distributed_helpers import split_helper, gather_helper @@ -116,7 +116,7 @@ def _gather_helper_bwd(self, tensor, ensemble=False): return tensor_gather - + @parameterized.expand( [ [128, 256, 32, 8, 1e-6, "naive", False], @@ -149,7 +149,7 @@ def test_distributed_quadrature(self, nlat, nlon, batch_size, num_chan, tol, qua ograd_full = torch.randn_like(out_full) out_full.backward(ograd_full) igrad_full = inp_full.grad.clone() - + # distributed out_local = quad_dist(inp_local) out_local.backward(ograd_full) @@ -172,7 +172,7 @@ def test_distributed_quadrature(self, nlat, nlon, batch_size, num_chan, tol, qua ############################################################# with torch.no_grad(): igrad_gather_full = self._gather_helper_bwd(igrad_local, False) - + # compute errors err = fn.relative_error(igrad_gather_full, igrad_full) if verbose and (self.world_rank == 0): @@ -204,7 +204,7 @@ def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, tol, if loss_type == "ensemble_crps": # local loss - loss_fn_local = EnsembleCRPSLoss( + loss_fn_local = CRPSLoss( img_shape=(H, W), crop_shape=None, crop_offset=(0, 0), @@ -218,7 +218,7 @@ def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, tol, ).to(self.device) # distributed loss - loss_fn_dist = EnsembleCRPSLoss( + loss_fn_dist = CRPSLoss( img_shape=(H, W), crop_shape=None, crop_offset=(0, 0), @@ -232,7 +232,7 @@ def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, tol, ).to(self.device) elif loss_type == "gauss_crps": # local loss - loss_fn_local = EnsembleCRPSLoss( + loss_fn_local = CRPSLoss( img_shape=(H, W), crop_shape=None, crop_offset=(0, 0), @@ -246,7 +246,7 @@ def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, tol, ).to(self.device) # distributed loss - loss_fn_dist = EnsembleCRPSLoss( + loss_fn_dist = CRPSLoss( img_shape=(H, W), crop_shape=None, crop_offset=(0, 0), @@ -260,7 +260,7 @@ def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, tol, ).to(self.device) elif loss_type == "skillspread_crps": # local loss - loss_fn_local = EnsembleCRPSLoss( + loss_fn_local = CRPSLoss( img_shape=(H, W), crop_shape=None, crop_offset=(0, 0), @@ -274,7 +274,7 @@ def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, tol, ).to(self.device) # distributed loss - loss_fn_dist = EnsembleCRPSLoss( + loss_fn_dist = CRPSLoss( img_shape=(H, W), crop_shape=None, crop_offset=(0, 0), diff --git a/tests/test_losses.py b/tests/test_losses.py index 0062333..9c3ffdc 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -24,7 +24,7 @@ from makani.models import model_registry from makani.utils import LossHandler -from makani.utils.losses import EnsembleCRPSLoss +from makani.utils.losses import CRPSLoss from testutils import get_default_parameters @@ -138,7 +138,7 @@ def test_loss_batchsize_independence(self, losses, uncertainty_weighting=False): # test initialization of loss object loss_obj = LossHandler(self.params) - + shape = (self.params.batch_size, self.params.N_out_channels, self.params.img_shape_x, self.params.img_shape_y) inp = torch.randn(*shape) @@ -150,14 
+150,14 @@ def test_loss_batchsize_independence(self, losses, uncertainty_weighting=False): out2 = loss_obj(tar2, inp2) self.assertTrue(torch.allclose(out, out2)) - + @parameterized.expand(_loss_weighted_params) def test_loss_weighted(self, losses, uncertainty_weighting=False): """ Tests initialization of loss, as well as the forward and backward pass """ - + self.params.losses = losses self.params.uncertainty_weighting = uncertainty_weighting @@ -179,7 +179,7 @@ def test_loss_weighted(self, losses, uncertainty_weighting=False): # compute weighted loss out_weighted = loss_obj(tar, inp, wgt) - + self.assertTrue(torch.allclose(out, out_weighted)) @parameterized.expand(_loss_weighted_params) @@ -211,44 +211,44 @@ def test_loss_multistep(self, losses, uncertainty_weighting=False): # compute weighted loss out_weighted = loss_obj(tar, inp, wgt) - self.assertTrue(torch.allclose(out, out_weighted)) + self.assertTrue(torch.allclose(out, out_weighted)) def test_running_stats(self): """ Tests computation of the running stats """ - + self.params.losses = [{"type": "l2"}] - + # test initialization of loss object loss_obj = LossHandler(self.params, track_running_stats=True) loss_obj.train() - + shape = (self.params.batch_size, self.params.N_out_channels, self.params.img_shape_x, self.params.img_shape_y) - + # this needs to be sufficiently large to mitigate the bias due to the initialization of the running stats num_samples = 100 for i in range(num_samples): - + inp = i * torch.ones(*shape) inp.requires_grad = True tar = torch.zeros(*shape) tar.requires_grad = True - + # forward pass and check shapes out = loss_obj(tar, inp) - + # generate simulated dataset data = torch.arange(num_samples).float().reshape(1, 1, -1).repeat(self.params.batch_size, self.params.N_out_channels, 1) expected_var, expected_mean = torch.var_mean(data, correction=0, dim=(0, -1)) - + var, mean = loss_obj.get_running_stats() - + self.assertTrue(torch.allclose(mean, expected_mean)) self.assertTrue(torch.allclose(var, expected_var)) def test_ensemble_crps(self): - crps_func = EnsembleCRPSLoss( + crps_func = CRPSLoss( img_shape=(self.params.img_shape_x, self.params.img_shape_y), crop_shape=(self.params.img_shape_x, self.params.img_shape_y), crop_offset=(0, 0), @@ -260,24 +260,24 @@ def test_ensemble_crps(self): ensemble_distributed=False, ensemble_weights=None, ) - + for ensemble_size in [1, 10]: with self.subTest(f"{ensemble_size}"): # generate input tensor inp = torch.empty((self.params.batch_size, ensemble_size, self.params.N_in_channels, self.params.img_shape_x, self.params.img_shape_y), dtype=torch.float32) with torch.no_grad(): inp.normal_(1.0, 1.0) - + # target tensor tar = torch.ones((self.params.batch_size, self.params.N_in_channels, self.params.img_shape_x, self.params.img_shape_y), dtype=torch.float32) - + # torch result result = crps_func(inp, tar).cpu().numpy() - + # properscoring result tar_arr = tar.cpu().numpy() inp_arr = inp.cpu().numpy() - + # I think this is a bug with the axis index in properscoring # for the degenerate case: if ensemble_size == 1: @@ -285,19 +285,19 @@ def test_ensemble_crps(self): inp_arr = np.squeeze(inp_arr, axis=1) else: axis = 1 - + result_proper = crps_ensemble(tar_arr, inp_arr, weights=None, issorted=False, axis=axis) quad_weight_arr = crps_func.quadrature.quad_weight.cpu().numpy() result_proper = np.sum(result_proper * quad_weight_arr, axis=(2, 3)) - + self.assertTrue(np.allclose(result, result_proper, rtol=1e-5, atol=0)) def test_gauss_crps(self): - + # protect against sigma=0 eps = 1.0e-5 - - crps_func = EnsembleCRPSLoss( + + crps_func = CRPSLoss( img_shape=(self.params.img_shape_x, self.params.img_shape_y), crop_shape=(self.params.img_shape_x, self.params.img_shape_y), crop_offset=(0, 0), @@ -309,32 +309,32 @@ def test_gauss_crps(self): ensemble_distributed=False, eps=eps, ) - + for ensemble_size in [1, 10]: with self.subTest(f"{ensemble_size}"): # generate input tensor inp = torch.empty((self.params.batch_size, ensemble_size, self.params.N_in_channels, self.params.img_shape_x, self.params.img_shape_y), dtype=torch.float32) with torch.no_grad(): inp.normal_(1.0, 1.0) - + # target tensor tar = torch.ones((self.params.batch_size, self.params.N_in_channels, self.params.img_shape_x, self.params.img_shape_y), dtype=torch.float32) - + # torch result result = crps_func(inp, tar).cpu().numpy() - + # properscoring result tar_arr = tar.cpu().numpy() inp_arr = inp.cpu().numpy() - + # compute mu, sigma, guard against underflows mu = np.mean(inp_arr, axis=1) sigma = np.maximum(np.sqrt(np.var(inp_arr, axis=1)), eps) - + result_proper = crps_gaussian(tar_arr, mu, sigma, grad=False) quad_weight_arr = crps_func.quadrature.quad_weight.cpu().numpy() result_proper = np.sum(result_proper * quad_weight_arr, axis=(2, 3)) - + self.assertTrue(np.allclose(result, result_proper, rtol=1e-5, atol=0))
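For reference, the fair ensemble CRPS that these tests cross-check against properscoring reduces to a skill term minus half a spread term. The following is a minimal, single-device sketch of that estimator; `crps_fair` is a hypothetical helper written for illustration (it is not part of the makani API), and it omits the NaN masking, quadrature weighting, per-member weights, and the `alpha` debiasing knob used by the skillspread kernels:

```python
import torch

def crps_fair(forecasts: torch.Tensor, observation: torch.Tensor) -> torch.Tensor:
    # forecasts: (E, ...) ensemble members; observation: (...) verifying target
    E = forecasts.shape[0]  # needs E >= 2 for the fair spread normalization
    # skill: mean absolute error of the members against the observation
    skill = (forecasts - observation.unsqueeze(0)).abs().mean(dim=0)
    # spread: sum over all ordered member pairs (the diagonal is zero),
    # normalized by E*(E-1) instead of E^2, which removes the finite-ensemble bias
    spread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().sum(dim=(0, 1)) / (E * (E - 1))
    return skill - 0.5 * spread
```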
From ea87f873c29352e43461b7eeae2a612a0380ed73 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Sun, 21 Dec 2025 15:38:11 -0800 Subject: [PATCH 30/32] implemented improved Sobolev energy score --- makani/models/noise.py | 2 +- makani/utils/grids.py | 77 ++++++- makani/utils/loss.py | 8 +- makani/utils/losses/__init__.py | 2 +- makani/utils/losses/crps_loss.py | 43 +--- makani/utils/losses/energy_score.py | 309 ++++++++++++++++++++-------- 6 files changed, 314 insertions(+), 127 deletions(-) diff --git a/makani/models/noise.py b/makani/models/noise.py index a1f3d52..5f49daa 100644 --- a/makani/models/noise.py +++ b/makani/models/noise.py @@ -198,7 +198,7 @@ def __init__( power_spectrum = torch.pow(2 * ls + 1, -alpha) norm_factor = torch.sum((2 * ls + 1) * power_spectrum / 4.0 / math.pi) sigma_l = sigma * torch.sqrt(power_spectrum / norm_factor) - sigma_l = torch.where(ls <= ms, sigma_l, 0.0) + sigma_l = torch.where(ms <= ls, sigma_l, 0.0) # the new shape is B, T, C, L, M sigma_l = sigma_l.reshape((1, 1, 1, self.lmax, self.mmax)).to(dtype=torch.float32) diff --git a/makani/utils/grids.py b/makani/utils/grids.py index 4446fff..b3ff154 100644 --- a/makani/utils/grids.py +++ b/makani/utils/grids.py @@ -16,10 +16,12 @@ import numpy as np import torch -from torch_harmonics.quadrature import legendre_gauss_weights, clenshaw_curtiss_weights +import torch.amp as amp + +from torch_harmonics.quadrature import legendre_gauss_weights, clenshaw_curtiss_weights, _precompute_latitudes from makani.utils import comm -from physicsnemo.distributed.utils import compute_split_shapes +from physicsnemo.distributed.utils import compute_split_shapes, split_tensor_along_dim from physicsnemo.distributed.mappings import reduce_from_parallel_region @@ -49,7 +51,7 @@ def compute_spherical_bandlimit(img_shape, grid_type): class GridConverter(torch.nn.Module): def __init__(self, src_grid, dst_grid, lat_rad, lon_rad): - super(GridConverter, self).__init__() + super().__init__() self.src = src_grid self.dst = dst_grid self.src_lat = lat_rad @@ -137,7 +139,7 @@ def __init__(self, quadrature_rule, img_shape, crop_shape=None, crop_offset=(0, # apply pole mask if (pole_mask is not None) and (pole_mask > 0): quad_weight[:pole_mask, :] = 0.0 - quad_weight[sizes[0] - pole_mask :, :] = 0.0 + quad_weight[img_shape[0] - pole_mask :, :] = 0.0 # if distributed, make sure to split correctly across ranks: # in case of model parallelism, we need to make sure that we use the correct shapes per rank @@ -179,3 +181,70 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: quad = reduce_from_parallel_region(quad.contiguous(), "spatial") return quad + + +class BandLimitMask(torch.nn.Module): + def __init__(self, img_shape, grid_type, lmax=None, type="sht"): + super().__init__() + self.img_shape = img_shape + self.grid_type = grid_type + self.lmax = lmax if lmax is not None else compute_spherical_bandlimit(img_shape, grid_type) + self.type = type + + if self.type == "sht": + # SHT for the computation of SH coefficients + if (comm.get_size("spatial") > 1): + from torch_harmonics.distributed import DistributedRealSHT, DistributedInverseRealSHT + import torch_harmonics.distributed as thd + if not thd.is_initialized(): + polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h") + azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w") + thd.init(polar_group, azimuth_group) + self.forward_transform = DistributedRealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() + self.inverse_transform = DistributedInverseRealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() + else: + from torch_harmonics import RealSHT, InverseRealSHT + + self.forward_transform = RealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() + self.inverse_transform = InverseRealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() + + elif self.type == "fft": + + # get the cutoff frequency in m for each latitude + lats, _ = _precompute_latitudes(self.img_shape[0], grid=self.grid_type) + # get the grid spacing at the equator + delta_equator = 2 * torch.pi / (self.lmax-1) + mlim = torch.ceil(2 * torch.pi * torch.sin(lats) / delta_equator).reshape(self.img_shape[0], 1) + ms = torch.arange(self.lmax).reshape(1, -1) + mask = (ms <= mlim) + mask = split_tensor_along_dim(mask, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")] + mask = split_tensor_along_dim(mask, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")] + self.register_buffer("mask", mask, persistent=False) + + if (comm.get_size("spatial") > 1): + from makani.mpu.fft import DistributedRealFFT1, DistributedInverseRealFFT1 + self.forward_transform = DistributedRealFFT1(img_shape[1], lmax=lmax, mmax=lmax).float() + self.inverse_transform = DistributedInverseRealFFT1(img_shape[1], lmax=lmax, mmax=lmax).float() + else: + from makani.models.common.fft import RealFFT1, InverseRealFFT1 + self.forward_transform = RealFFT1(img_shape[1], lmax=lmax, mmax=lmax).float() + self.inverse_transform = InverseRealFFT1(img_shape[1], lmax=lmax, mmax=lmax).float() + else: + raise ValueError(f"Unknown truncation type {self.type}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + with amp.autocast(device_type="cuda", enabled=False): + dtype = x.dtype + x = x.float() + + x = self.forward_transform(x) + + if hasattr(self, "mask"): + x = torch.where(self.mask, x, torch.zeros_like(x)) + + x = self.inverse_transform(x) + + x = x.to(dtype=dtype) + + return x \ No newline at end of file
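The new `BandLimitMask` module above low-passes a field by transforming it to spectral space and back, either through an (optionally distributed) SHT truncated at `lmax`, or through a per-latitude FFT mask. A minimal single-process usage sketch, assuming makani's communicators report a spatial size of 1 so the non-distributed transforms are selected (shapes are illustrative):

```python
import torch
from makani.utils.grids import BandLimitMask

# 181 x 360 equiangular field, truncated at spherical-harmonic degree lmax=60
mask = BandLimitMask(img_shape=(181, 360), grid_type="equiangular", lmax=60, type="sht")

x = torch.randn(2, 3, 181, 360)  # batch, channels, lat, lon
x_lowpass = mask(x)              # same shape, modes above the cutoff removed
assert x_lowpass.shape == x.shape
```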
diff --git a/makani/utils/loss.py b/makani/utils/loss.py index b0056e0..22e30f8 100644 --- a/makani/utils/loss.py +++ b/makani/utils/loss.py @@ -23,7 +23,7 @@ from torch import nn from makani.utils import comm -from makani.utils.grids import GridQuadrature +from makani.utils.grids import GridQuadrature, BandLimitMask from makani.utils.dataloaders.data_helpers import get_data_normalization, get_time_diff_stds from physicsnemo.distributed.utils import compute_split_shapes from physicsnemo.distributed.mappings import gather_from_parallel_region, reduce_from_parallel_region @@ -33,7 +33,7 @@ from .losses import LossType, GeometricLpLoss, SpectralLpLoss, SpectralH1Loss, SpectralAMSELoss from .losses import CRPSLoss, SpectralCRPSLoss, GradientCRPSLoss, VortDivCRPSLoss -from .losses import LpEnergyScoreLoss, H1EnergyScoreLoss +from .losses import LpEnergyScoreLoss, SobolevEnergyScoreLoss from .losses import GaussianMMDLoss from .losses import EnsembleNLLLoss from .losses import DriftRegularization, HydrostaticBalanceLoss @@ -244,8 +244,8 @@ def _parse_loss_type(self, loss_type: str): loss_handle = GaussianMMDLoss - elif "energy_score" in loss_type: - loss_handle = partial(LpEnergyScoreLoss) - elif "h1_energy_score" in loss_type: - loss_handle = partial(H1EnergyScoreLoss) + elif "sobolev_energy_score" in loss_type: + loss_handle = partial(SobolevEnergyScoreLoss) + elif "energy_score" in loss_type: + loss_handle = partial(LpEnergyScoreLoss) elif "drift_regularization" in loss_type: loss_handle = DriftRegularization else:
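With the renaming above, the new score is selected through the loss-type string parsed by `_parse_loss_type`. Because the matching is substring-based, the `"sobolev_energy_score"` branch has to be tested before the generic `"energy_score"` branch (the hunk above orders them accordingly; otherwise the more specific handler is unreachable, since `"energy_score"` is a substring of `"sobolev_energy_score"`). A hypothetical configuration fragment, following the `params.losses` convention used in the tests; whether extra keys such as `fraction` are forwarded verbatim depends on `LossHandler`'s plumbing:

```python
# hypothetical params fragment: combine the Sobolev energy score with a plain l2 term
params.losses = [
    {"type": "sobolev_energy_score", "fraction": 0.5, "alpha": 1.0, "beta": 2.0},
    {"type": "l2"},
]
```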
diff --git a/makani/utils/losses/__init__.py b/makani/utils/losses/__init__.py index cb11cde..048481b 100644 --- a/makani/utils/losses/__init__.py +++ b/makani/utils/losses/__init__.py @@ -19,7 +19,7 @@ from .amse_loss import SpectralAMSELoss from .hydrostatic_loss import HydrostaticBalanceLoss from .crps_loss import CRPSLoss, SpectralCRPSLoss, GradientCRPSLoss, VortDivCRPSLoss -from .energy_score import LpEnergyScoreLoss, H1EnergyScoreLoss +from .energy_score import LpEnergyScoreLoss, SobolevEnergyScoreLoss from .mmd_loss import GaussianMMDLoss from .likelihood_loss import EnsembleNLLLoss from .drift_regularization import DriftRegularization diff --git a/makani/utils/losses/crps_loss.py b/makani/utils/losses/crps_loss.py index e71470e..47af667 100644 --- a/makani/utils/losses/crps_loss.py +++ b/makani/utils/losses/crps_loss.py @@ -176,37 +176,6 @@ def _crps_probability_weighted_moment_kernel(observation: torch.Tensor, forecast return crps - -def _crps_independent_skillspread_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: - """ - CRPS which uses separate samples for the estimation of spread and skill. Only one sample is used for the estimation of the skill - """ - - observation = observation.unsqueeze(0) - - # get nanmask from the observarions - nanmasks = torch.logical_or(torch.isnan(observation), torch.isnan(weights)) - - # compute total weights - nweights = torch.where(nanmasks, 0.0, weights) - total_weight = torch.sum(nweights, dim=0, keepdim=True) - - # ensemble size - num_ensemble = forecasts.shape[0] - - # use broadcasting semantics to compute spread and skill - espread = (forecasts[1:].unsqueeze(1) - forecasts[1:].unsqueeze(0)).abs().sum(dim=(0,1)) / float((num_ensemble - 1)*(num_ensemble - 2)) - eskill = (observation - forecasts[0:1]).abs().mean(dim=0) - - # crps = torch.where(nanmasks.sum(dim=0) != 0, torch.nan, eskill - 0.5 * espread) - crps = eskill - 0.5 * espread - - # set to nan for first forecasts nan - crps = torch.where(nanmasks.sum(dim=0) != 0, 0.0, crps) - - return crps - - def _crps_naive_skillspread_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weights: torch.Tensor, alpha: float) -> torch.Tensor: """ alternative fair CRPS variant that uses spread and skill. Uses naive computation which is O(N^2) in the number of ensemble members.
Useful for complex @@ -392,6 +361,14 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # compute score crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) + elif self.crps_type == "naive skillspread": + if self.ensemble_weights is not None: + raise NotImplementedError("currently only constant ensemble weights are supported") + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_naive_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) elif self.crps_type == "gauss": if self.ensemble_weights is not None: ensemble_weights = self.ensemble_weights[idx] @@ -472,8 +449,8 @@ def __init__( # get the local l weights lmax = self.sht.lmax ls = torch.arange(lmax).reshape(-1, 1) - l_weights = 1 / (2*ls+1) - # l_weights = torch.ones(lmax).reshape(-1, 1) + # l_weights = 1 / (2*ls+1) + l_weights = torch.ones(lmax).reshape(-1, 1) if comm.get_size("h") > 1: l_weights = split_tensor_along_dim(l_weights, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")] self.register_buffer("l_weights", l_weights, persistent=False) diff --git a/makani/utils/losses/energy_score.py b/makani/utils/losses/energy_score.py index 758c548..a3d8cb2 100644 --- a/makani/utils/losses/energy_score.py +++ b/makani/utils/losses/energy_score.py @@ -20,12 +20,9 @@ import torch.nn as nn from torch import amp -from makani.utils.losses.base_loss import GeometricBaseLoss, GradientBaseLoss, LossType +from makani.utils.losses.base_loss import GeometricBaseLoss, SpectralBaseLoss, LossType from makani.utils import comm -import torch_harmonics as th -import torch_harmonics.distributed as thd - # distributed stuff from physicsnemo.distributed.utils import compute_split_shapes, split_tensor_along_dim from physicsnemo.distributed.mappings import scatter_to_parallel_region, reduce_from_parallel_region @@ -173,13 +170,13 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w eskill = eskill.sum(dim=-1, keepdim=True) # now we have reduced everything and need to sum appropriately - espread = espread.pow(1/self.beta).sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) - eskill = eskill.pow(1/self.beta).sum(dim=0) / float(num_ensemble) + espread = espread.sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) + eskill = eskill.sum(dim=0) / float(num_ensemble) # the resulting tensor should have dimension B, C which is what we return return eskill - 0.5 * espread -class H1EnergyScoreLoss(GradientBaseLoss): +class SobolevEnergyScoreLoss(SpectralBaseLoss): def __init__( self, @@ -188,15 +185,13 @@ def __init__( crop_offset: Tuple[int, int], channel_names: List[str], grid_type: str, - pole_mask: int, lmax: Optional[int] = None, - crps_type: str = "skillspread", spatial_distributed: Optional[bool] = False, ensemble_distributed: Optional[bool] = False, ensemble_weights: Optional[torch.Tensor] = None, - absolute: Optional[bool] = True, alpha: Optional[float] = 1.0, beta: Optional[float] = 2.0, + fraction: Optional[float] = 1.0, eps: Optional[float] = 1.0e-5, **kwargs, ): @@ -207,32 +202,34 @@ def __init__( crop_offset=crop_offset, channel_names=channel_names, grid_type=grid_type, - pole_mask=pole_mask, lmax=lmax, spatial_distributed=spatial_distributed, ) - # if absolute is true, the loss is computed only on the absolute value of the gradient - self.absolute 
= absolute - - self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed - self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + self.spatial_distributed = spatial_distributed and comm.is_distributed("spatial") + self.ensemble_distributed = ensemble_distributed and comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) self.alpha = alpha self.beta = beta + self.fraction = fraction self.eps = eps - # we also need a variant of the weights split in ensemble direction: - quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) - if self.ensemble_distributed: - quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")] - quad_weight_split = quad_weight_split.contiguous() - self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) - if ensemble_weights is not None: self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) else: self.ensemble_weights = ensemble_weights + # get the local lm weights + ls = torch.arange(self.sht.lmax).reshape(-1, 1) + ms = torch.arange(self.sht.mmax).reshape(1, -1) + lm_weights = (1 + ls * (ls + 1)).pow(self.fraction).tile(1, self.sht.mmax) + lm_weights[:, 1:] *= 2.0 + lm_weights = torch.where(ms > ls, 0.0, lm_weights) + if comm.get_size("h") > 1: + lm_weights = split_tensor_along_dim(lm_weights, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")] + if comm.get_size("w") > 1: + lm_weights = split_tensor_along_dim(lm_weights, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")] + self.register_buffer("lm_weights", lm_weights, persistent=False) + @property def type(self): return LossType.Probabilistic @@ -245,91 +242,67 @@ def n_channels(self): def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor: return torch.ones(1) - def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, ensemble_weights: Optional[torch.Tensor] = None) -> torch.Tensor: # sanity checks if forecasts.dim() != 5: raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") - # we assume that spatial_weights have NO ensemble dim - if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): - spdim = spatial_weights.dim() - odim = observations.dim() - raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).") - - # we assume the following shapes: - # forecasts: batch, ensemble, channels, lat, lon - # observations: batch, channels, lat, lon - B, E, C, H, W = forecasts.shape - # get the data type before stripping amp types dtype = forecasts.dtype # before anything else compute the transform # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same with amp.autocast(device_type="cuda", enabled=False): + # TODO: check 4 pi normalization + forecasts = self.sht(forecasts.float()) / 4.0 / math.pi + observations = self.sht(observations.float()) / 4.0 / math.pi - # compute the SH coefficients of the forecasts and observations - forecasts = self.sht(forecasts.float()).unsqueeze(-3) - observations = self.sht(observations.float()).unsqueeze(-3) - - # append zeros, so that 
we can use the inverse vector SHT - forecasts = torch.cat([forecasts, torch.zeros_like(forecasts)], dim=-3) - observations = torch.cat([observations, torch.zeros_like(observations)], dim=-3) - - forecasts = self.ivsht(forecasts) - observations = self.ivsht(observations) - - forecasts = forecasts.to(dtype) - observations = observations.to(dtype) + # cast back to original dtype + forecasts = forecasts.to(dtype=dtype) + observations = observations.to(dtype=dtype) - forecasts = forecasts.reshape(B, E, 2*C, H, W) - observations = observations.reshape(B, 2*C, H, W) + # we assume the following shapes: + # forecasts: batch, ensemble, channels, mmax, lmax + # observations: batch, channels, mmax, lmax + B, E, C, H, W = forecasts.shape # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. # ideally we split spatial dims forecasts = torch.moveaxis(forecasts, 1, 0) - forecasts = forecasts.reshape(E, B, 2*C, H * W) + forecasts = forecasts.reshape(E, B, C, H * W) if self.ensemble_distributed: ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] - forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") + forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") # for correct spatial reduction we need to do the same with spatial weights + + lm_weights_split = self.lm_weights.flatten(start_dim=-2, end_dim=-1) + lm_weights_split = scatter_to_parallel_region(lm_weights_split, -1, "ensemble") # observations does not need a transpose, but just a split - observations = observations.reshape(1, B, 2*C, H * W) + observations = observations.reshape(1, B, C, H * W) if self.ensemble_distributed: observations = scatter_to_parallel_region(observations, -1, "ensemble") - # for correct spatial reduction we need to do the same with spatial weights - if spatial_weights is not None: - spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1) - spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") - - if self.ensemble_weights is not None: - raise NotImplementedError("currently only constant ensemble weights are supported") - else: - ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) - - # ensemble size num_ensemble = forecasts.shape[0] # get nanmask from the observarions - nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(ensemble_weights)) + nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(forecasts)) - # use broadcasting semantics to compute spread and skill and sum over channels (vector norm) - espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().pow(self.beta) - eskill = (observations - forecasts).abs().pow(self.beta) + # compute the individual distances + espread = lm_weights_split * (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().pow(self.beta) + eskill = lm_weights_split * (observations - forecasts).abs().pow(self.beta) # perform masking before any reduction espread = torch.where(nanmasks.sum(dim=0) != 0, 0.0, espread) eskill = torch.where(nanmasks.sum(dim=0) != 0, 0.0, eskill) + # do the channel reduction first + espread = espread.sum(dim=-3, keepdim=True) + eskill = eskill.sum(dim=-3, keepdim=True) + # do the spatial reduction - if spatial_weights is not None: - espread = torch.sum(espread * self.quad_weight_split * spatial_weights_split, dim=-1) - eskill = torch.sum(eskill * self.quad_weight_split * spatial_weights_split, dim=-1) - 
else: - espread = torch.sum(espread * self.quad_weight_split, dim=-1) - eskill = torch.sum(eskill * self.quad_weight_split, dim=-1) + espread = espread.sum(dim=-1, keepdim=False) + eskill = eskill.sum(dim=-1, keepdim=False) # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well if self.ensemble_distributed: @@ -337,19 +310,187 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w eskill = reduce_from_parallel_region(eskill, "ensemble") # we need to do the spatial averaging manually since - # we are not calling the quadrature forward function if self.spatial_distributed: espread = reduce_from_parallel_region(espread, "spatial") eskill = reduce_from_parallel_region(eskill, "spatial") - # do the channel reduction while ignoring NaNs - # if channel weights are required they should be added here to the reduction - espread = espread.sum(dim=-1, keepdim=True) - eskill = eskill.sum(dim=-1, keepdim=True) - # now we have reduced everything and need to sum appropriately - espread = espread.pow(1/self.beta).sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) - eskill = eskill.pow(1/self.beta).sum(dim=0) / float(num_ensemble) + espread = espread.sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) + eskill = eskill.sum(dim=0) / float(num_ensemble) + + return eskill - 0.5 * espread - # the resulting tensor should have dimension B, 1 which is what we return - return eskill - 0.5 * espread \ No newline at end of file +# class H1EnergyScoreLoss(GradientBaseLoss): + +# def __init__( +# self, +# img_shape: Tuple[int, int], +# crop_shape: Tuple[int, int], +# crop_offset: Tuple[int, int], +# channel_names: List[str], +# grid_type: str, +# pole_mask: int, +# lmax: Optional[int] = None, +# crps_type: str = "skillspread", +# spatial_distributed: Optional[bool] = False, +# ensemble_distributed: Optional[bool] = False, +# ensemble_weights: Optional[torch.Tensor] = None, +# absolute: Optional[bool] = True, +# alpha: Optional[float] = 1.0, +# beta: Optional[float] = 2.0, +# eps: Optional[float] = 1.0e-5, +# **kwargs, +# ): + +# super().__init__( +# img_shape=img_shape, +# crop_shape=crop_shape, +# crop_offset=crop_offset, +# channel_names=channel_names, +# grid_type=grid_type, +# pole_mask=pole_mask, +# lmax=lmax, +# spatial_distributed=spatial_distributed, +# ) + +# # if absolute is true, the loss is computed only on the absolute value of the gradient +# self.absolute = absolute + +# self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed +# self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed +# self.alpha = alpha +# self.beta = beta +# self.eps = eps + +# # we also need a variant of the weights split in ensemble direction: +# quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) +# if self.ensemble_distributed: +# quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")] +# quad_weight_split = quad_weight_split.contiguous() +# self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + +# if ensemble_weights is not None: +# self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) +# else: +# self.ensemble_weights = ensemble_weights + +# @property +# def type(self): +# return LossType.Probabilistic + +# 
@property +# def n_channels(self): +# return 1 + +# @torch.compiler.disable(recursive=False) +# def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor: +# return torch.ones(1) + +# def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + +# # sanity checks +# if forecasts.dim() != 5: +# raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + +# # we assume that spatial_weights have NO ensemble dim +# if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): +# spdim = spatial_weights.dim() +# odim = observations.dim() +# raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).") + +# # we assume the following shapes: +# # forecasts: batch, ensemble, channels, lat, lon +# # observations: batch, channels, lat, lon +# B, E, C, H, W = forecasts.shape + +# # get the data type before stripping amp types +# dtype = forecasts.dtype + +# # before anything else compute the transform +# # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same +# with amp.autocast(device_type="cuda", enabled=False): + +# # compute the SH coefficients of the forecasts and observations +# forecasts = self.sht(forecasts.float()).unsqueeze(-3) +# observations = self.sht(observations.float()).unsqueeze(-3) + +# # append zeros, so that we can use the inverse vector SHT +# forecasts = torch.cat([forecasts, torch.zeros_like(forecasts)], dim=-3) +# observations = torch.cat([observations, torch.zeros_like(observations)], dim=-3) + +# forecasts = self.ivsht(forecasts) +# observations = self.ivsht(observations) + +# forecasts = forecasts.to(dtype) +# observations = observations.to(dtype) + +# forecasts = forecasts.reshape(B, E, 2*C, H, W) +# observations = observations.reshape(B, 2*C, H, W) + +# # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. 
+# # ideally we split spatial dims +# forecasts = torch.moveaxis(forecasts, 1, 0) +# forecasts = forecasts.reshape(E, B, 2*C, H * W) +# if self.ensemble_distributed: +# ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] +# forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") + +# # observations does not need a transpose, but just a split +# observations = observations.reshape(1, B, 2*C, H * W) +# if self.ensemble_distributed: +# observations = scatter_to_parallel_region(observations, -1, "ensemble") + +# # for correct spatial reduction we need to do the same with spatial weights +# if spatial_weights is not None: +# spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1) +# spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + +# if self.ensemble_weights is not None: +# raise NotImplementedError("currently only constant ensemble weights are supported") +# else: +# ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + +# # ensemble size +# num_ensemble = forecasts.shape[0] + +# # get nanmask from the observarions +# nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(ensemble_weights)) + +# # use broadcasting semantics to compute spread and skill and sum over channels (vector norm) +# espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().pow(self.beta) +# eskill = (observations - forecasts).abs().pow(self.beta) + +# # perform masking before any reduction +# espread = torch.where(nanmasks.sum(dim=0) != 0, 0.0, espread) +# eskill = torch.where(nanmasks.sum(dim=0) != 0, 0.0, eskill) + +# # do the spatial reduction +# if spatial_weights is not None: +# espread = torch.sum(espread * self.quad_weight_split * spatial_weights_split, dim=-1) +# eskill = torch.sum(eskill * self.quad_weight_split * spatial_weights_split, dim=-1) +# else: +# espread = torch.sum(espread * self.quad_weight_split, dim=-1) +# eskill = torch.sum(eskill * self.quad_weight_split, dim=-1) + +# # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well +# if self.ensemble_distributed: +# espread = reduce_from_parallel_region(espread, "ensemble") +# eskill = reduce_from_parallel_region(eskill, "ensemble") + +# # we need to do the spatial averaging manually since +# # we are not calling the quadrature forward function +# if self.spatial_distributed: +# espread = reduce_from_parallel_region(espread, "spatial") +# eskill = reduce_from_parallel_region(eskill, "spatial") + +# # do the channel reduction while ignoring NaNs +# # if channel weights are required they should be added here to the reduction +# espread = espread.sum(dim=-1, keepdim=True) +# eskill = eskill.sum(dim=-1, keepdim=True) + +# # now we have reduced everything and need to sum appropriately +# espread = espread.pow(1/self.beta).sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) +# eskill = eskill.pow(1/self.beta).sum(dim=0) / float(num_ensemble) + +# # the resulting tensor should have dimension B, 1 which is what we return +# return eskill - 0.5 * espread \ No newline at end of file From 790ee031027f2411dff77cf2fa63f808ce38e702 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Tue, 23 Dec 2025 05:56:52 -0800 Subject: [PATCH 31/32] added powerspectrum to stats computation --- data_process/get_stats.py | 108 ++++++++++++++++++++++++++++---------- 1 file changed, 79 insertions(+), 29 deletions(-) diff --git 
a/data_process/get_stats.py index 59d9d15..634d2e5 100644 --- a/data_process/get_stats.py +++ b/data_process/get_stats.py @@ -32,6 +32,7 @@ import torch import torch.distributed as dist +from torch_harmonics import RealSHT from makani.utils.grids import GridQuadrature from wb2_helpers import DistributedProgressBar @@ -50,7 +51,7 @@ def allgather_dict(stats, group): for _ in range(dist.get_world_size(group)): substats = {varname: {} for varname in stats.keys()} stats_gather.append(substats) - + # iterate over full dict for varname, substats in stats.items(): for k,v in substats.items(): @@ -179,6 +180,14 @@ def welford_combine(stats1, stats2): return stats +def compute_powerspectrum(x, sht): + coeffs = sht(x).abs().pow(2) + # account for hermitian symmetry + coeffs[..., 1:] *= 2.0 + power_spectrum = coeffs.sum(dim=-1) + return power_spectrum + + def get_file_stats(filename, file_slice, wind_indices, @@ -187,7 +196,8 @@ def get_file_stats(filename, dt=1, batch_size=16, device=torch.device("cpu"), - progress=None): + progress=None, + sht=None): stats = None with h5.File(filename, 'r') as f: @@ -203,7 +213,7 @@ def get_file_stats(filename, if batch_size is None: batch_size = slc_stop - slc_start - + for batch_start in range(slc_start, slc_stop, batch_size): batch_stop = min(batch_start+batch_size, slc_stop) sub_slc = slice(batch_start, batch_stop) @@ -227,7 +237,7 @@ def get_file_stats(filename, counts_time = tdata.shape[0] valid_count = torch.sum(quadrature(valid_mask), dim=0) counts_time_space = valid_count - + # Basic observables # compute mean and variance # the mean needs to be divided by number of valid samples: @@ -235,6 +245,25 @@ def get_file_stats(filename, # we compute m2 directly, so we do not need to divide by number of valid samples: tm2 = torch.sum(quadrature(torch.square(tdata_masked - tmean)), dim=0, keepdim=False).reshape(1, -1, 1, 1) + # compute PSD stats + psd_stats = None + if sht is not None: + # compute psd (B, C, L) + psd = compute_powerspectrum(tdata_masked, sht) + # compute mean and m2 over batch + psd_mean = torch.mean(psd, dim=0, keepdim=True) # (1, C, L) + psd_m2 = torch.sum(torch.square(psd - psd_mean), dim=0, keepdim=True) # (1, C, L) + + # reshape to (1, C, L, 1) for welford compatibility + psd_mean = psd_mean.unsqueeze(-1) + psd_m2 = psd_m2.unsqueeze(-1) + + psd_stats = { + "type": "meanvar", + "counts": float(counts_time) * torch.ones((data.shape[1]), dtype=torch.float64, device=device), + "values": torch.stack([psd_mean, psd_m2], dim=0).contiguous() + } + # fill the dict tmpstats = dict( maxs = { @@ -261,6 +290,9 @@ def get_file_stats(filename, } ) + if psd_stats is not None: + tmpstats["psd_meanvar"] = psd_stats + # time diffs: read one more sample for these, if possible # TODO: tile it for dt < batch_size if batch_start >= dt: @@ -288,13 +320,13 @@ def get_file_stats(filename, # we need the shapes tshape = tmean.shape tmpstats["time_diff_meanvar"] = { - "type": "meanvar", + "type": "meanvar", "counts": torch.zeros(data.shape[1], dtype=torch.float64, device=device), "values": torch.stack( [ - torch.zeros(tshape, dtype=torch.float64, device=device), + torch.zeros(tshape, dtype=torch.float64, device=device), torch.zeros(tshape, dtype=torch.float64, device=device) - ], + ], dim=0 ).contiguous(), } @@ -333,7 +365,7 @@ def get_file_stats(filename, "counts": torch.zeros(wdiffshape[1], dtype=torch.float64, device=device), "values": torch.stack( [ - torch.zeros(wdiffshape, dtype=torch.float64, device=device), + torch.zeros(wdiffshape, 
dtype=torch.float64, device=device), torch.zeros(wdiffshape, dtype=torch.float64, device=device) ], dim=0 @@ -379,7 +411,7 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, dt: int, quadrature_rule: str, wind_angle_aware: bool, fail_on_nan: bool=False, batch_size: Optional[int]=16, reduction_group_size: Optional[int]=8): - """Function to compute various statistics of all variables of a makani HDF5 dataset. + """Function to compute various statistics of all variables of a makani HDF5 dataset. This function reads data from input_path and computes minimum, maximum, mean and standard deviation for all variables in the dataset. This is done globally, meaning averaged over space and time. @@ -410,7 +442,7 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, coords: this is a dictionary which contains two lists, latitude and longitude coordinates in degrees as well as channel names. Example: coords = dict(lat=[-90.0, ..., 90.], lon=[0, ..., 360], channel=["t2m", "u500", "v500", ...]) Note that the number of entries in coords["lat"] has to match dimension -2 of the dataset, and coords["lon"] dimension -1. - The length of the channel names has to match dimension -3 (or dimension 1, which is the same) of the dataset. + The length of the channel names has to match dimension -3 (or dimension 1, which is the same) of the dataset. dt : int The temporal difference for which the temporal means and standard deviations should be computed. Note that this is in units of dhours (see metadata file), quadrature_rule : str @@ -456,14 +488,14 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, # init torch distributed device = torch.device(f"cuda:{comm_local_rank}") if torch.cuda.is_available() else torch.device("cpu") dist.init_process_group( - backend="nccl" if torch.cuda.is_available() else "gloo", + backend="nccl" if torch.cuda.is_available() else "gloo", init_method="env://", world_size=comm_size, rank=comm_rank, device_id=device, ) mesh = dist.init_device_mesh( - device_type=device.type, + device_type=device.type, mesh_shape=[reduction_group_size, comm_size // reduction_group_size], mesh_dim_names=["reduction", "tree"], ) @@ -524,6 +556,16 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, crop_shape=None, crop_offset=(0, 0), normalize=False, pole_mask=None).to(device) + # Initialize SHT + grid_type_map = { + "naive": "equiangular", + "clenshaw-curtiss": "clenshaw-curtiss", + "gauss-legendre": "legendre-gauss" + } + grid_type = grid_type_map[quadrature_rule] + sht = RealSHT(height, width, grid=grid_type).to(device) + lmax = sht.lmax + if comm_rank == 0: print(f"Found {len(filelist)} files with a total of {num_samples_total} samples. 
Each sample has the shape {num_channels}x{height}x{width} (CxHxW).") @@ -564,50 +606,55 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, # initialize arrays stats = dict( global_meanvar = { - "type": "meanvar", - "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), + "type": "meanvar", + "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), "values": torch.zeros((2, 1, num_channels, 1, 1), dtype=torch.float64, device=device), }, mins = { - "type": "min", - "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), + "type": "min", + "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), "values": torch.full((1, num_channels, 1, 1), torch.inf, dtype=torch.float64, device=device) }, maxs = { - "type": "max", - "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), + "type": "max", + "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), "values": torch.full((1, num_channels, 1, 1), -torch.inf, dtype=torch.float64, device=device) }, time_means = { - "type": "mean", - "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), + "type": "mean", + "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), "values": torch.zeros((1, num_channels, height, width), dtype=torch.float64, device=device) }, time_diff_meanvar = { - "type": "meanvar", - "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), - "values": torch.zeros((2, 1, num_channels, 1, 1), dtype=torch.float64, device=device), + "type": "meanvar", + "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), + "values": torch.zeros((2, 1, num_channels, 1, 1), dtype=torch.float64, device=device), + }, + psd_meanvar = { + "type": "meanvar", + "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), + "values": torch.zeros((2, 1, num_channels, lmax, 1), dtype=torch.float64, device=device), } ) if wind_channels is not None: num_wind_channels = len(wind_channels[0]) stats["wind_meanvar"] = { - "type": "meanvar", - "counts": torch.zeros((num_wind_channels), dtype=torch.float64, device=device), - "values": torch.zeros((2, 1, num_wind_channels, 1, 1), dtype=torch.float64, device=device), + "type": "meanvar", + "counts": torch.zeros((num_wind_channels), dtype=torch.float64, device=device), + "values": torch.zeros((2, 1, num_wind_channels, 1, 1), dtype=torch.float64, device=device), } stats["winddiff_meanvar"] = { - "type": "meanvar", + "type": "meanvar", "counts": torch.zeros((num_wind_channels), dtype=torch.float64, device=device), - "values": torch.zeros((2, 1, num_wind_channels, 1, 1), dtype=torch.float64, device=device), + "values": torch.zeros((2, 1, num_wind_channels, 1, 1), dtype=torch.float64, device=device), } # compute local stats progress = DistributedProgressBar(num_samples_total, comm) start = time.time() for filename, index_bounds in mapping.items(): - tmpstats = get_file_stats(filename, slice(index_bounds[0], index_bounds[1]+1), wind_channels, quadrature, fail_on_nan, dt, batch_size, device, progress) + tmpstats = get_file_stats(filename, slice(index_bounds[0], index_bounds[1]+1), wind_channels, quadrature, fail_on_nan, dt, batch_size, device, progress, sht) stats = welford_combine(stats, tmpstats) # wait for everybody else @@ -648,6 +695,7 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, # compute global stds: stats["global_meanvar"]["values"][1, ...] 
= np.sqrt(stats["global_meanvar"]["values"][1, ...] / stats["global_meanvar"]["counts"][None, :, None, None]) stats["time_diff_meanvar"]["values"][1, ...] = np.sqrt(stats["time_diff_meanvar"]["values"][1, ...] / stats["time_diff_meanvar"]["counts"][None, :, None, None]) + stats["psd_meanvar"]["values"][1, ...] = np.sqrt(stats["psd_meanvar"]["values"][1, ...] / stats["psd_meanvar"]["counts"][None, :, None, None]) # overwrite the wind channels if wind_channels is not None: @@ -672,6 +720,8 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, np.save(os.path.join(output_path, 'time_means.npy'), stats["time_means"]["values"].astype(np.float32)) np.save(os.path.join(output_path, f'time_diff_means_dt{dt}.npy'), stats["time_diff_meanvar"]["values"][0, ...].astype(np.float32)) np.save(os.path.join(output_path, f'time_diff_stds_dt{dt}.npy'), stats["time_diff_meanvar"]["values"][1, ...].astype(np.float32)) + np.save(os.path.join(output_path, 'psd_means.npy'), stats["psd_meanvar"]["values"][0, ..., 0].astype(np.float32)) + np.save(os.path.join(output_path, 'psd_stds.npy'), stats["psd_meanvar"]["values"][1, ..., 0].astype(np.float32)) duration = time.time() - start print(f"Saving stats done. Duration: {duration:.2f}s", flush=True) From 15f06a5416e02c00f8ce5726d1eebf46e37796bb Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Wed, 24 Dec 2025 02:23:41 -0800 Subject: [PATCH 32/32] added psd normalization to energy score --- data_process/get_stats.py | 1 + makani/utils/dataloaders/data_helpers.py | 19 +++++++++++++ makani/utils/loss.py | 15 +++++++++- makani/utils/losses/energy_score.py | 36 +++++++++++++++++++++--- 4 files changed, 66 insertions(+), 5 deletions(-) diff --git a/data_process/get_stats.py b/data_process/get_stats.py index 634d2e5..6fe46be 100644 --- a/data_process/get_stats.py +++ b/data_process/get_stats.py @@ -552,6 +552,7 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, height, width = (data_shape[2], data_shape[3]) # quadrature: + # we normalize the quadrature rule to 4pi quadrature = GridQuadrature(quadrature_rule, (height, width), crop_shape=None, crop_offset=(0, 0), normalize=False, pole_mask=None).to(device) diff --git a/makani/utils/dataloaders/data_helpers.py b/makani/utils/dataloaders/data_helpers.py index a52f684..c4d2fe0 100644 --- a/makani/utils/dataloaders/data_helpers.py +++ b/makani/utils/dataloaders/data_helpers.py @@ -70,6 +70,25 @@ def get_time_diff_stds(params): return time_diff_stds +def get_psd_stats(params): + + psd_means = None + psd_stds = None + + if hasattr(params, "psd_means_path") and hasattr(params, "psd_stds_path"): + psd_means = np.load(params.psd_means_path) + psd_stds = np.load(params.psd_stds_path) + + # filter channels if requested + if hasattr(params, "out_channels"): + psd_means = psd_means[..., params.out_channels, :] + psd_stds = psd_stds[..., params.out_channels, :] + else: + raise ValueError(f"psd_means_path or psd_stds_path not defined.") + + return psd_means, psd_stds + + def get_climatology(params): """ routine for fetching climatology and normalization factors diff --git a/makani/utils/loss.py b/makani/utils/loss.py index 22e30f8..64f1774 100644 --- a/makani/utils/loss.py +++ b/makani/utils/loss.py @@ -24,7 +24,7 @@ from makani.utils import comm from makani.utils.grids import GridQuadrature, BandLimitMask -from makani.utils.dataloaders.data_helpers import get_data_normalization, get_time_diff_stds +from makani.utils.dataloaders.data_helpers import get_data_normalization, get_time_diff_stds, 
get_psd_stats from physicsnemo.distributed.utils import compute_split_shapes from physicsnemo.distributed.mappings import gather_from_parallel_region, reduce_from_parallel_region @@ -91,6 +91,17 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps else: scale = torch.ones((1, len(params.out_channels), 1, 1), dtype=torch.float32) + # load PSD stats + try: + psd_means, psd_stds = get_psd_stats(params) + if psd_means is not None: + psd_means = torch.from_numpy(psd_means).to(torch.float32) + if psd_stds is not None: + psd_stds = torch.from_numpy(psd_stds).to(torch.float32) + except ValueError: + psd_means = None + psd_stds = None + # create module list self.loss_fn = nn.ModuleList([]) self.loss_requires_input = [] # track which losses need input state @@ -119,6 +130,8 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps pole_mask=pole_mask, bias=bias, scale=scale, + psd_means=psd_means, + psd_stds=psd_stds, grid_type=params.model_grid_type, spatial_distributed=self.spatial_distributed, ensemble_distributed=self.ensemble_distributed, diff --git a/makani/utils/losses/energy_score.py b/makani/utils/losses/energy_score.py index a3d8cb2..b63ae77 100644 --- a/makani/utils/losses/energy_score.py +++ b/makani/utils/losses/energy_score.py @@ -192,7 +192,11 @@ def __init__( alpha: Optional[float] = 1.0, beta: Optional[float] = 2.0, fraction: Optional[float] = 1.0, - eps: Optional[float] = 1.0e-5, + eps: Optional[float] = 1.0e-3, + psd_normalization: Optional[bool] = False, + bias: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + psd_means: Optional[torch.Tensor] = None, **kwargs, ): @@ -219,9 +223,32 @@ def __init__( self.ensemble_weights = ensemble_weights # get the local lm weights + if psd_normalization and psd_means is not None and scale is not None: + + # ensure shapes match for broadcasting + if psd_means.dim() == 3: + psd_means = psd_means.unsqueeze(-1) + + # normalize PSD + # Since PSD scales with square of signal, we divide by scale^2 + norm_psd = psd_means / (scale.reshape(1, -1, 1, 1)**2) + + # fix the 0-th component using the bias + normalized_bias_sq = (bias.reshape(1, -1) / scale.reshape(1, -1))**2 + norm_psd = norm_psd / (4.0 * math.pi) + norm_psd[..., 0, 0] = torch.clamp(norm_psd[..., 0, 0] - normalized_bias_sq, min=0.0) + + # Compute spectral weights: 1 / (PSD + eps^2) + # Add epsilon for numerical stability + psd_weights = 1.0 / (norm_psd + self.eps**2) + else: + psd_weights = torch.ones((1, 1, self.sht.lmax, 1)) + ls = torch.arange(self.sht.lmax).reshape(-1, 1) ms = torch.arange(self.sht.mmax).reshape(1, -1) lm_weights = (1 + ls * (ls + 1)).pow(self.fraction).tile(1, self.sht.mmax) + # account for the 4 pi normalization coming from the SHT + lm_weights = psd_weights * lm_weights / 4.0 / math.pi lm_weights[:, 1:] *= 2.0 lm_weights = torch.where(ms > ls, 0.0, lm_weights) if comm.get_size("h") > 1: @@ -255,8 +282,8 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, ensemble_ # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same with amp.autocast(device_type="cuda", enabled=False): # TODO: check 4 pi normalization - forecasts = self.sht(forecasts.float()) / 4.0 / math.pi - observations = self.sht(observations.float()) / 4.0 / math.pi + forecasts = self.sht(forecasts.float()) + observations = self.sht(observations.float()) # cast back to original dtype forecasts = forecasts.to(dtype=dtype)
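The effect of the weighting constructed above is easiest to see with toy numbers: spectral bands that carry a lot of climatological power are down-weighted, so all resolved scales contribute comparably to the score. A standalone sketch mirroring the hunk above, with made-up values standing in for the statistics loaded via `get_psd_stats`:

```python
import math
import torch

# toy climatological PSD for four degrees of one channel, shaped (1, C, L, 1)
psd_means = torch.tensor([4.0, 2.0, 1.0, 0.5]).reshape(1, 1, 4, 1)
scale = torch.tensor([2.0])  # per-channel std used for data normalization
bias = torch.tensor([1.0])   # per-channel mean, removed from the l=0 mode
eps = 1.0e-3

norm_psd = psd_means / (scale.reshape(1, -1, 1, 1) ** 2)
normalized_bias_sq = (bias.reshape(1, -1) / scale.reshape(1, -1)) ** 2
norm_psd = norm_psd / (4.0 * math.pi)
norm_psd[..., 0, 0] = torch.clamp(norm_psd[..., 0, 0] - normalized_bias_sq, min=0.0)

# reciprocal weights: low-power (small-scale) bands receive larger weights
psd_weights = 1.0 / (norm_psd + eps ** 2)
```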
@@ -276,7 +303,8 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, ensemble_ forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") # for correct spatial reduction we need to do the same with spatial weights lm_weights_split = self.lm_weights.flatten(start_dim=-2, end_dim=-1) - lm_weights_split = scatter_to_parallel_region(lm_weights_split, -1, "ensemble") + if self.ensemble_distributed: + lm_weights_split = scatter_to_parallel_region(lm_weights_split, -1, "ensemble") # observations does not need a transpose, but just a split observations = observations.reshape(1, B, C, H * W)
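Taken together, patches 30 and 32 compute a spectrally weighted energy score. The sketch below condenses the non-distributed essence of `SobolevEnergyScoreLoss` under simplifying assumptions: `alpha=1`, no PSD normalization, no NaN masking or ensemble weights, an equiangular grid, and at least two ensemble members. It is an illustration of the estimator, not the distributed implementation from the diffs above:

```python
import torch
from torch_harmonics import RealSHT

def sobolev_energy_score(forecasts: torch.Tensor, observations: torch.Tensor,
                         fraction: float = 1.0, beta: float = 2.0) -> torch.Tensor:
    # forecasts: (B, E, C, H, W) with E >= 2; observations: (B, C, H, W)
    B, E, C, H, W = forecasts.shape
    sht = RealSHT(H, W, grid="equiangular").float()

    fc = sht(forecasts)     # (B, E, C, lmax, mmax), complex coefficients
    oc = sht(observations)  # (B, C, lmax, mmax)

    # Sobolev weights (1 + l(l+1))^s, doubled for m > 0 (Hermitian symmetry)
    ls = torch.arange(sht.lmax).reshape(-1, 1)
    ms = torch.arange(sht.mmax).reshape(1, -1)
    w = (1 + ls * (ls + 1)).float().pow(fraction).tile(1, sht.mmax)
    w[:, 1:] *= 2.0
    w = torch.where(ms > ls, torch.zeros_like(w), w)

    # skill: weighted spectral distance of each member to the observation
    eskill = (w * (fc - oc.unsqueeze(1)).abs().pow(beta)).sum(dim=(-3, -2, -1))  # (B, E)
    # spread: weighted pairwise distances between members (diagonal is zero)
    espread = (w * (fc.unsqueeze(2) - fc.unsqueeze(1)).abs().pow(beta)).sum(dim=(-3, -2, -1))  # (B, E, E)

    # fair normalization; corresponds to alpha=1 in the patched loss
    return eskill.sum(dim=1) / E - 0.5 * espread.sum(dim=(1, 2)) / (E * (E - 1))
```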