diff --git a/config/afnonet.yaml b/config/afnonet.yaml index 6eab5f8..02a1d88 100644 --- a/config/afnonet.yaml +++ b/config/afnonet.yaml @@ -57,9 +57,6 @@ full_field: &BASE_CONFIG nested_skip_fno: !!bool True # whether to nest the inner skip connection or have it be sequential, inside the AFNO block verbose: False - #options default, residual - target: "default" - channel_names: ["u10m", "v10m", "t2m", "sp", "msl", "t850", "u1000", "v1000", "z1000", "u850", "v850", "z850", "u500", "v500", "z500", "t500", "z50", "r500", "r850", "tcwv", "u100m", "v100m", "u250", "v250", "z250", "t250", "u100", "v100", "z100", "t100", "u900", "v900", "z900", "t900"] normalization: "zscore" #options zscore or minmax hard_thresholding_fraction: 1.0 diff --git a/config/debug.yaml b/config/debug.yaml index 8474473..cbd7b19 100644 --- a/config/debug.yaml +++ b/config/debug.yaml @@ -90,9 +90,6 @@ base_config: &BASE_CONFIG add_noise: !!bool False noise_std: 0. - target: "default" # options default, residual - normalize_residual: false - # define channels to be read from data channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] # normalization mode zscore but for q diff --git a/config/fourcastnet3.yaml b/config/fourcastnet3.yaml index d27c629..21925c9 100644 --- a/config/fourcastnet3.yaml +++ b/config/fourcastnet3.yaml @@ -87,9 +87,6 @@ base_config: &BASE_CONFIG add_noise: !!bool False noise_std: 0. - target: "default" # options default, residual - normalize_residual: false - # define channels to be read from data. 
sp has been removed here channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] # normalization mode zscore but for q diff --git a/config/icml_models.yaml b/config/icml_models.yaml index e135a7f..ddde08e 100644 --- a/config/icml_models.yaml +++ b/config/icml_models.yaml @@ -66,9 +66,6 @@ base_config: &BASE_CONFIG N_grid_channels: 0 gridtype: "sinusoidal" #options "sinusoidal" or "linear" - #options default, residual - target: "default" - channel_names: ["u10m", "v10m", "t2m", "sp", "msl", "t850", "u1000", "v1000", "z1000", "u850", "v850", "z850", "u500", "v500", "z500", "t500", "z50", "r500", "r850", "tcwv", "u100m", "v100m", "u250", "v250", "z250", "t250", "u100", "v100", "z100", "t100", "u900", "v900", "z900", "t900"] normalization: "zscore" #options zscore or minmax or none diff --git a/config/pangu.yaml b/config/pangu.yaml index c2852a1..7f49193 100644 --- a/config/pangu.yaml +++ b/config/pangu.yaml @@ -78,9 +78,6 @@ base_config: &BASE_CONFIG add_noise: !!bool False noise_std: 0. - target: "default" # options default, residual - normalize_residual: !!bool False - # define channels to be read from data channel_names: ["u10m", "v10m", "t2m", "msl", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] normalization: "zscore" # options zscore or minmax or none @@ -131,10 +128,10 @@ base_onnx: &BASE_ONNX # ONNX wrapper related overwrite nettype: "/makani/makani/makani/models/networks/pangu_onnx.py:PanguOnnx" onnx_file: '/model/pangu_weather_6.onnx' - + amp_mode: "none" disable_ddp: True - + # Set Pangu ONNX channel order channel_names: ["msl", "u10m", "v10m", "t2m", "z1000", "z925", "z850", "z700", "z600", "z500", "z400", "z300", "z250", "z200", "z150", "z100", "z50", "q1000", "q925", "q850", "q700", "q600", "q500", "q400", "q300", "q250", "q200", "q150", "q100", "q50", "t1000", "t925", "t850", "t700", "t600", "t500", "t400", "t300", "t250", "t200", "t150", "t100", "t50", "u1000", "u925", "u850", "u700", "u600", "u500", "u400", "u300", "u250", "u200", "u150", "u100", "u50", "v1000", "v925", "v850", "v700", "v600", "v500", "v400", "v300", "v250", "v200", "v150", "v100", "v50"] # Remove input/output normalization diff --git a/config/sfnonet.yaml b/config/sfnonet.yaml index 271a0ca..b70b1d4 100644 --- a/config/sfnonet.yaml +++ b/config/sfnonet.yaml @@ -86,9 +86,6 @@ base_config: &BASE_CONFIG add_noise: !!bool False noise_std: 0. 
- target: "default" # options default, residual - normalize_residual: !!bool False - # define channels to be read from data channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] normalization: "zscore" # options zscore or minmax or none diff --git a/config/vit.yaml b/config/vit.yaml index 58fcee0..fc2571c 100644 --- a/config/vit.yaml +++ b/config/vit.yaml @@ -65,9 +65,6 @@ full_field: &BASE_CONFIG embed_dim: 384 normalization_layer: "layer_norm" - #options default, residual - target: "default" - channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] normalization: "zscore" #options zscore or minmax hard_thresholding_fraction: 1.0 diff --git a/data_process/get_stats.py b/data_process/get_stats.py index 59d9d15..6fe46be 100644 --- a/data_process/get_stats.py +++ b/data_process/get_stats.py @@ -32,6 +32,7 @@ import torch import torch.distributed as dist +from torch_harmonics import RealSHT from makani.utils.grids import GridQuadrature from wb2_helpers import DistributedProgressBar @@ -50,7 +51,7 @@ def allgather_dict(stats, group): for _ in range(dist.get_world_size(group)): substats = {varname: {} for varname in stats.keys()} stats_gather.append(substats) - + # iterate over full dict for varname, substats in stats.items(): for k,v in substats.items(): @@ -179,6 +180,14 @@ def welford_combine(stats1, stats2): return stats +def compute_powerspectrum(x, sht): + coeffs = sht(x).abs().pow(2) + # account for hermitian symetry + coeffs[..., 1:] *= 2.0 + power_spectrum = coeffs.sum(dim=-1) + return power_spectrum + + def get_file_stats(filename, file_slice, wind_indices, @@ -187,7 +196,8 @@ def get_file_stats(filename, dt=1, batch_size=16, device=torch.device("cpu"), - progress=None): + progress=None, + sht=None): stats = None with h5.File(filename, 'r') as f: @@ -203,7 +213,7 @@ def get_file_stats(filename, if batch_size is None: batch_size = slc_stop - slc_start - + for batch_start in range(slc_start, slc_stop, batch_size): batch_stop = min(batch_start+batch_size, slc_stop) sub_slc = slice(batch_start, batch_stop) @@ -227,7 +237,7 @@ def get_file_stats(filename, counts_time = tdata.shape[0] valid_count = torch.sum(quadrature(valid_mask), dim=0) counts_time_space = valid_count - + # Basic observables # compute mean and variance # the mean needs to be divided by number of valid samples: @@ -235,6 +245,25 @@ def get_file_stats(filename, # we compute m2 directly, so we do not need to divide by number of valid samples: tm2 = 
torch.sum(quadrature(torch.square(tdata_masked - tmean)), dim=0, keepdim=False).reshape(1, -1, 1, 1) + # compute PSD stats + psd_stats = None + if sht is not None: + # compute psd (B, C, L) + psd = compute_powerspectrum(tdata_masked, sht) + # compute mean and m2 over batch + psd_mean = torch.mean(psd, dim=0, keepdim=True) # (1, C, L) + psd_m2 = torch.sum(torch.square(psd - psd_mean), dim=0, keepdim=True) # (1, C, L) + + # reshape to (1, C, L, 1) for welford compatibility + psd_mean = psd_mean.unsqueeze(-1) + psd_m2 = psd_m2.unsqueeze(-1) + + psd_stats = { + "type": "meanvar", + "counts": float(counts_time) * torch.ones((data.shape[1]), dtype=torch.float64, device=device), + "values": torch.stack([psd_mean, psd_m2], dim=0).contiguous() + } + # fill the dict tmpstats = dict( maxs = { @@ -261,6 +290,9 @@ def get_file_stats(filename, } ) + if psd_stats is not None: + tmpstats["psd_meanvar"] = psd_stats + # time diffs: read one more sample for these, if possible # TODO: tile it for dt < batch_size if batch_start >= dt: @@ -288,13 +320,13 @@ def get_file_stats(filename, # we need the shapes tshape = tmean.shape tmpstats["time_diff_meanvar"] = { - "type": "meanvar", + "type": "meanvar", "counts": torch.zeros(data.shape[1], dtype=torch.float64, device=device), "values": torch.stack( [ - torch.zeros(tshape, dtype=torch.float64, device=device), + torch.zeros(tshape, dtype=torch.float64, device=device), torch.zeros(tshape, dtype=torch.float64, device=device) - ], + ], dim=0 ).contiguous(), } @@ -333,7 +365,7 @@ def get_file_stats(filename, "counts": torch.zeros(wdiffshape[1], dtype=torch.float64, device=device), "values": torch.stack( [ - torch.zeros(wdiffshape, dtype=torch.float64, device=device), + torch.zeros(wdiffshape, dtype=torch.float64, device=device), torch.zeros(wdiffshape, dtype=torch.float64, device=device) ], dim=0 @@ -379,7 +411,7 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, dt: int, quadrature_rule: str, wind_angle_aware: bool, fail_on_nan: bool=False, batch_size: Optional[int]=16, reduction_group_size: Optional[int]=8): - """Function to compute various statistics of all variables of a makani HDF5 dataset. + """Function to compute various statistics of all variables of a makani HDF5 dataset. This function reads data from input_path and computes minimum, maximum, mean and standard deviation for all variables in the dataset. This is done globally, meaning averaged over space and time. @@ -410,7 +442,7 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, coords: this is a dictionary which contains two lists, latitude and longitude coordinates in degrees as well as channel names. Example: coords = dict(lat=[-90.0, ..., 90.], lon=[0, ..., 360], channel=["t2m", "u500", "v500", ...]) Note that the number of entries in coords["lat"] has to match dimension -2 of the dataset, and coords["lon"] dimension -1. - The length of the channel names has to match dimension -3 (or dimension 1, which is the same) of the dataset. + The length of the channel names has to match dimension -3 (or dimension 1, which is the same) of the dataset. dt : int The temporal difference for which the temporal means and standard deviations should be computed. 
Note that this is in units of dhours (see metadata file), quadrature_rule : str @@ -456,14 +488,14 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, # init torch distributed device = torch.device(f"cuda:{comm_local_rank}") if torch.cuda.is_available() else torch.device("cpu") dist.init_process_group( - backend="nccl" if torch.cuda.is_available() else "gloo", + backend="nccl" if torch.cuda.is_available() else "gloo", init_method="env://", world_size=comm_size, rank=comm_rank, device_id=device, ) mesh = dist.init_device_mesh( - device_type=device.type, + device_type=device.type, mesh_shape=[reduction_group_size, comm_size // reduction_group_size], mesh_dim_names=["reduction", "tree"], ) @@ -520,10 +552,21 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, height, width = (data_shape[2], data_shape[3]) # quadrature: + # we normalize the quadrature rule to 4pi quadrature = GridQuadrature(quadrature_rule, (height, width), crop_shape=None, crop_offset=(0, 0), normalize=False, pole_mask=None).to(device) + # Initialize SHT + grid_type_map = { + "naive": "equiangular", + "clenshaw-curtiss": "clenshaw-curtiss", + "gauss-legendre": "legendre-gauss" + } + grid_type = grid_type_map[quadrature_rule] + sht = RealSHT(height, width, grid=grid_type).to(device) + lmax = sht.lmax + if comm_rank == 0: print(f"Found {len(filelist)} files with a total of {num_samples_total} samples. Each sample has the shape {num_channels}x{height}x{width} (CxHxW).") @@ -564,50 +607,55 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, # initialize arrays stats = dict( global_meanvar = { - "type": "meanvar", - "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), + "type": "meanvar", + "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), "values": torch.zeros((2, 1, num_channels, 1, 1), dtype=torch.float64, device=device), }, mins = { - "type": "min", - "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), + "type": "min", + "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), "values": torch.full((1, num_channels, 1, 1), torch.inf, dtype=torch.float64, device=device) }, maxs = { - "type": "max", - "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), + "type": "max", + "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), "values": torch.full((1, num_channels, 1, 1), -torch.inf, dtype=torch.float64, device=device) }, time_means = { - "type": "mean", - "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), + "type": "mean", + "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), "values": torch.zeros((1, num_channels, height, width), dtype=torch.float64, device=device) }, time_diff_meanvar = { - "type": "meanvar", - "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), - "values": torch.zeros((2, 1, num_channels, 1, 1), dtype=torch.float64, device=device), + "type": "meanvar", + "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), + "values": torch.zeros((2, 1, num_channels, 1, 1), dtype=torch.float64, device=device), + }, + psd_meanvar = { + "type": "meanvar", + "counts": torch.zeros((num_channels), dtype=torch.float64, device=device), + "values": torch.zeros((2, 1, num_channels, lmax, 1), dtype=torch.float64, device=device), } ) if wind_channels is not None: num_wind_channels = len(wind_channels[0]) stats["wind_meanvar"] = { - "type": "meanvar", - 
"counts": torch.zeros((num_wind_channels), dtype=torch.float64, device=device), - "values": torch.zeros((2, 1, num_wind_channels, 1, 1), dtype=torch.float64, device=device), + "type": "meanvar", + "counts": torch.zeros((num_wind_channels), dtype=torch.float64, device=device), + "values": torch.zeros((2, 1, num_wind_channels, 1, 1), dtype=torch.float64, device=device), } stats["winddiff_meanvar"] = { - "type": "meanvar", + "type": "meanvar", "counts": torch.zeros((num_wind_channels), dtype=torch.float64, device=device), - "values": torch.zeros((2, 1, num_wind_channels, 1, 1), dtype=torch.float64, device=device), + "values": torch.zeros((2, 1, num_wind_channels, 1, 1), dtype=torch.float64, device=device), } # compute local stats progress = DistributedProgressBar(num_samples_total, comm) start = time.time() for filename, index_bounds in mapping.items(): - tmpstats = get_file_stats(filename, slice(index_bounds[0], index_bounds[1]+1), wind_channels, quadrature, fail_on_nan, dt, batch_size, device, progress) + tmpstats = get_file_stats(filename, slice(index_bounds[0], index_bounds[1]+1), wind_channels, quadrature, fail_on_nan, dt, batch_size, device, progress, sht) stats = welford_combine(stats, tmpstats) # wait for everybody else @@ -648,6 +696,7 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, # compute global stds: stats["global_meanvar"]["values"][1, ...] = np.sqrt(stats["global_meanvar"]["values"][1, ...] / stats["global_meanvar"]["counts"][None, :, None, None]) stats["time_diff_meanvar"]["values"][1, ...] = np.sqrt(stats["time_diff_meanvar"]["values"][1, ...] / stats["time_diff_meanvar"]["counts"][None, :, None, None]) + stats["psd_meanvar"]["values"][1, ...] = np.sqrt(stats["psd_meanvar"]["values"][1, ...] / stats["psd_meanvar"]["counts"][None, :, None, None]) # overwrite the wind channels if wind_channels is not None: @@ -672,6 +721,8 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, np.save(os.path.join(output_path, 'time_means.npy'), stats["time_means"]["values"].astype(np.float32)) np.save(os.path.join(output_path, f'time_diff_means_dt{dt}.npy'), stats["time_diff_meanvar"]["values"][0, ...].astype(np.float32)) np.save(os.path.join(output_path, f'time_diff_stds_dt{dt}.npy'), stats["time_diff_meanvar"]["values"][1, ...].astype(np.float32)) + np.save(os.path.join(output_path, 'psd_means.npy'), stats["psd_meanvar"]["values"][0, ..., 0].astype(np.float32)) + np.save(os.path.join(output_path, 'psd_stds.npy'), stats["psd_meanvar"]["values"][1, ..., 0].astype(np.float32)) duration = time.time() - start print(f"Saving stats done. 
Duration: {duration:.2f}s", flush=True) diff --git a/makani/convert_checkpoint.py b/makani/convert_checkpoint.py index be4f95c..19fed0a 100644 --- a/makani/convert_checkpoint.py +++ b/makani/convert_checkpoint.py @@ -100,7 +100,7 @@ def consolidate_checkpoints(input_path, output_path, checkpoint_version=0): print(checkpoint_paths) # load the static data necessary for instantiating the preprocessor (required due to the way the registry works) - LocalPackage._load_static_data(input_path, params) + LocalPackage._load_static_data(LocalPackage(input_path), params) # get the model multistep = params.n_future > 0 diff --git a/makani/models/common/__init__.py b/makani/models/common/__init__.py index 0341114..7c14ff0 100644 --- a/makani/models/common/__init__.py +++ b/makani/models/common/__init__.py @@ -18,3 +18,4 @@ from .fft import RealFFT1, InverseRealFFT1, RealFFT2, InverseRealFFT2, RealFFT3, InverseRealFFT3 from .layer_norm import GeometricInstanceNormS2 from .spectral_convolution import SpectralConv, SpectralAttention +from .pos_embedding import LearnablePositionEmbedding diff --git a/makani/models/common/layers.py b/makani/models/common/layers.py index 49c5910..46e394a 100644 --- a/makani/models/common/layers.py +++ b/makani/models/common/layers.py @@ -605,5 +605,5 @@ def forward(self, x): x = self.norm(x) x = self.linear(x) - + return x diff --git a/makani/models/common/pos_embedding.py b/makani/models/common/pos_embedding.py new file mode 100644 index 0000000..0f65da5 --- /dev/null +++ b/makani/models/common/pos_embedding.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import torch +import torch.nn as nn + +from makani.utils import comm +from physicsnemo.distributed.utils import compute_split_shapes + +class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta): + """ + Abstract base class for position embeddings. + + This class defines the interface for position embedding modules + that add positional information to input tensors. + + Parameters + ---------- + img_shape : tuple, optional + Image shape (height, width), by default (480, 960) + grid : str, optional + Grid type, by default "equiangular" + num_chans : int, optional + Number of channels, by default 1 + """ + + def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1): + super().__init__() + + self.img_shape = img_shape + self.num_chans = num_chans + + def forward(self): + + return self.position_embeddings + +class LearnablePositionEmbedding(PositionEmbedding): + """ + Learnable position embeddings for spherical transformers. + + This module provides learnable position embeddings that can be either + latitude-only or full latitude-longitude embeddings. 
+ + Parameters + ---------- + img_shape : tuple, optional + Image shape (height, width), by default (480, 960) + grid : str, optional + Grid type, by default "equiangular" + num_chans : int, optional + Number of channels, by default 1 + embed_type : str, optional + Embedding type ("lat" or "latlon"), by default "lat" + """ + + def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"): + super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans) + + # if distributed, make sure to split correctly across ranks: + # in case of model parallelism, we need to make sure that we use the correct shapes per rank + # for h + if comm.get_size("h") > 1: + self.local_shape_h = compute_split_shapes(img_shape[0], comm.get_size("h"))[comm.get_rank("h")] + else: + self.local_shape_h = img_shape[0] + + # for w + if comm.get_size("w") > 1: + self.local_shape_w = compute_split_shapes(img_shape[1], comm.get_size("w"))[comm.get_rank("w")] + else: + self.local_shape_w = img_shape[1] + + if embed_type == "latlon": + self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.local_shape_h, self.local_shape_w)) + self.position_embeddings.is_shared_mp = [] + self.position_embeddings.sharded_dims_mp = [None, None, "h", "w"] + elif embed_type == "lat": + self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.local_shape_h, 1)) + self.position_embeddings.is_shared_mp = ["w"] + self.position_embeddings.sharded_dims_mp = [None, None, "h", None] + else: + raise ValueError(f"Unknown learnable position embedding type {embed_type}") + + def forward(self): + return self.position_embeddings.expand(-1,-1,self.local_shape_h, self.local_shape_w) diff --git a/makani/models/model_package.py b/makani/models/model_package.py index b674133..94ac71d 100644 --- a/makani/models/model_package.py +++ b/makani/models/model_package.py @@ -38,7 +38,7 @@ class LocalPackage: """ - Implements the earth2mip/modulus Package interface. + Implements the modulus Package interface. 
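# --- Editor's note: illustrative sketch, not part of the patch. It shows the tensor shapes produced
# --- by the new LearnablePositionEmbedding, assuming a single-rank setup in which comm.get_size("h")
# --- and comm.get_size("w") are both 1; the channel count and image shape are arbitrary examples.
import torch
from makani.models.common import LearnablePositionEmbedding

pos_embed = LearnablePositionEmbedding(img_shape=(480, 960), num_chans=8, embed_type="lat")
x_pos = pos_embed()
# the "lat" variant stores a (1, 8, 480, 1) parameter and broadcasts it along longitude in forward()
assert x_pos.shape == (1, 8, 480, 960)
# embed_type="latlon" would instead store (and return) a full (1, 8, 480, 960) parameter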
""" # These define the model package in terms of where makani expects the files to be located @@ -49,7 +49,7 @@ class LocalPackage: MEANS_FILE = "global_means.npy" STDS_FILE = "global_stds.npy" OROGRAPHY_FILE = "orography.nc" - LANDMASK_FILE = "land_mask.nc" + LANDMASK_FILE = "land_sea_mask.nc" SOILTYPE_FILE = "soil_type.nc" def __init__(self, root): @@ -148,11 +148,11 @@ def timestep(self): def update_state(self, replace_state=True): self.model.preprocessor.update_internal_state(replace_state=replace_state) return - + def set_rng(self, reset=True, seed=333): self.model.preprocessor.set_rng(reset=reset, seed=seed) return - + def forward(self, x, time, normalized_data=True, replace_state=None): if not normalized_data: x = (x - self.in_bias) / self.in_scale diff --git a/makani/models/model_registry.py b/makani/models/model_registry.py index 5b860a3..6d1a8ff 100644 --- a/makani/models/model_registry.py +++ b/makani/models/model_registry.py @@ -166,7 +166,27 @@ def get_model(params: ParamsBase, use_stochastic_interpolation: bool = False, mu if isinstance(model_handle, (EntryPoint, importlib_metadata.EntryPoint)): model_handle = model_handle.load() - model_handle = partial(model_handle, inp_shape=inp_shape, out_shape=out_shape, inp_chans=inp_chans, out_chans=out_chans, **params.to_dict()) + model_kwargs = params.to_dict() + + # pass normalization statistics to the model + if params.get("normalization", "none") in ["zscore", "minmax"]: + try: + bias, scale = get_data_normalization(params) + # Slice the stats to match the model's output channels + # Assuming the model's output corresponds to params.out_channels + if hasattr(params, "out_channels"): + if bias is not None: + bias = bias.flatten()[params.out_channels] + if scale is not None: + scale = scale.flatten()[params.out_channels] + + if bias is not None and scale is not None: + model_kwargs["normalization_means"] = bias + model_kwargs["normalization_stds"] = scale + except Exception as e: + logging.warning(f"Could not load normalization stats. 
Error: {e}") + + model_handle = partial(model_handle, inp_shape=inp_shape, out_shape=out_shape, inp_chans=inp_chans, out_chans=out_chans, **model_kwargs) else: raise KeyError(f"No model is registered under the name {name}") diff --git a/makani/models/networks/fourcastnet3.py b/makani/models/networks/fourcastnet3.py index 9244e14..4eed2a6 100644 --- a/makani/models/networks/fourcastnet3.py +++ b/makani/models/networks/fourcastnet3.py @@ -24,7 +24,7 @@ from itertools import groupby # helpers -from makani.models.common import DropPath, LayerScale, MLP, EncoderDecoder, SpectralConv +from makani.models.common import DropPath, LayerScale, MLP, EncoderDecoder, SpectralConv, LearnablePositionEmbedding from makani.utils.features import get_water_channels, get_channel_groups # get spectral transforms and spherical convolutions from torch_harmonics @@ -32,7 +32,6 @@ import torch_harmonics.distributed as thd # get pre-formulated layers -#from makani.models.common import GeometricInstanceNormS2 from makani.mpu.layers import DistributedMLP, DistributedEncoderDecoder # more distributed stuff @@ -57,6 +56,47 @@ def _soft_clamp(x: torch.Tensor, offset: float = 0.0): y = torch.where(x >= 0.5, x - 0.25, y) return y +# helper module to handle imputation of SST +# class MLPImputer(nn.Module): +# def __init__( +# self, +# inp_chans = 2, +# out_chans = 2, +# mlp_ratio = 2.0, +# activation_function=nn.GELU, +# ): + +# self.mlp = EncoderDecoder( +# num_layers=1, +# input_dim=inp_chans, +# output_dim=out_chans, +# hidden_dim=int(mlp_ratio * out_chans), +# act_layer=activation_function, +# input_format="nchw", +# ) + +# def forward(self, inp, out): +# return torch.where(torch.isnan(out), self.mlp(inp), out) + +class ConstantImputation(nn.Module): + def __init__( + self, + inp_chans = 2, + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(inp_chans, 1, 1)) + + if comm.get_size("spatial") > 1: + self.weight.is_shared_mp = ["spatial"] + self.weight.sharded_dims_mp = [None, None, None] + + def forward(self, x, mask = None): + if mask is None: + mask = torch.isnan(x) + else: + mask = torch.logical_or(mask, torch.isnan(x)) + return torch.where(mask, self.weight, x) class DiscreteContinuousEncoder(nn.Module): def __init__( @@ -408,7 +448,9 @@ def __init__( n_history=0, atmo_embed_dim=8, surf_embed_dim=8, - aux_embed_dim=8, + dyn_aux_embed_dim=8, + stat_aux_embed_dim=8, + pos_embed_dim=0, num_layers=4, num_groups=1, use_mlp=True, @@ -424,6 +466,7 @@ def __init__( sfno_block_frequency=2, big_skip=False, clamp_water=False, + encoder_bias=False, bias=False, checkpointing_level=0, freeze_encoder=False, @@ -436,12 +479,12 @@ def __init__( self.out_shape = out_shape self.atmo_embed_dim = atmo_embed_dim self.surf_embed_dim = surf_embed_dim - self.aux_embed_dim = aux_embed_dim + self.dyn_aux_embed_dim = dyn_aux_embed_dim + self.stat_aux_embed_dim = stat_aux_embed_dim + self.pos_embed_dim = pos_embed_dim self.big_skip = big_skip self.checkpointing_level = checkpointing_level - - # currently doesn't support neither history nor future: - assert n_history == 0 + self.n_history = n_history # compute the downscaled image size self.h = int(self.inp_shape[0] // scale_factor) @@ -451,11 +494,12 @@ def __init__( self._init_spectral_transforms(model_grid_type, sht_grid_type, hard_thresholding_fraction, max_modes) # compute static permutations to extract - self._precompute_channel_groups(channel_names, aux_channel_names) + self._precompute_channel_groups(channel_names, aux_channel_names, n_history) # compute the total number 
of internal groups self.n_out_chans = self.n_atmo_groups * self.n_atmo_chans + self.n_surf_chans self.total_embed_dim = self.n_atmo_groups * self.atmo_embed_dim + self.surf_embed_dim + self.total_aux_embed_dim = (self.n_dyn_aux_chans > 0) * self.dyn_aux_embed_dim + (self.n_stat_aux_chans > 0) * self.stat_aux_embed_dim + self.pos_embed_dim # convert kernel shape to tuple kernel_shape = tuple(kernel_shape) @@ -471,11 +515,10 @@ def __init__( raise ValueError(f"Unknown activation function {activation_function}") # encoder for the atmospheric channels - # TODO: add the groups self.atmo_encoder = DiscreteContinuousEncoder( inp_shape=inp_shape, out_shape=(self.h, self.w), - inp_chans=self.n_atmo_chans, + inp_chans=self.n_atmo_chans * (self.n_history + 1), out_chans=self.atmo_embed_dim, grid_in=model_grid_type, grid_out=sht_grid_type, @@ -484,16 +527,16 @@ def __init__( basis_norm_mode=filter_basis_norm_mode, activation_function=activation_function, groups=math.gcd(self.n_atmo_chans, self.atmo_embed_dim), - bias=bias, + bias=encoder_bias, use_mlp=encoder_mlp, ) - # encoder for the auxiliary channels + # encoder for the surface channels if self.n_surf_chans > 0: self.surf_encoder = DiscreteContinuousEncoder( inp_shape=inp_shape, out_shape=(self.h, self.w), - inp_chans=self.n_surf_chans, + inp_chans=self.n_surf_chans * (self.n_history + 1), out_chans=self.surf_embed_dim, grid_in=model_grid_type, grid_out=sht_grid_type, @@ -502,10 +545,52 @@ def __init__( basis_norm_mode=filter_basis_norm_mode, activation_function=activation_function, groups=math.gcd(self.n_surf_chans, self.surf_embed_dim), - bias=bias, + bias=encoder_bias, + use_mlp=encoder_mlp, + ) + + if self.sst_channels_in.shape[0] > 0: + self.sst_imputation = ConstantImputation( + inp_chans=self.sst_channels_in.shape[0], + ) + + # encoder for the auxiliary channels + if self.n_dyn_aux_chans > 0: + self.dyn_aux_encoder = DiscreteContinuousEncoder( + inp_shape=inp_shape, + out_shape=(self.h, self.w), + inp_chans=self.n_dyn_aux_chans * (self.n_history + 1), + out_chans=self.dyn_aux_embed_dim, + grid_in=model_grid_type, + grid_out=sht_grid_type, + kernel_shape=kernel_shape, + basis_type=filter_basis_type, + basis_norm_mode=filter_basis_norm_mode, + activation_function=activation_function, + groups=math.gcd(self.n_dyn_aux_chans, self.dyn_aux_embed_dim), + bias=encoder_bias, + use_mlp=encoder_mlp, + ) + + # encoder for the auxiliary channels + if self.n_stat_aux_chans > 0: + self.stat_aux_encoder = DiscreteContinuousEncoder( + inp_shape=inp_shape, + out_shape=(self.h, self.w), + inp_chans=self.n_stat_aux_chans, + out_chans=self.stat_aux_embed_dim, + grid_in=model_grid_type, + grid_out=sht_grid_type, + kernel_shape=kernel_shape, + basis_type=filter_basis_type, + basis_norm_mode=filter_basis_norm_mode, + activation_function=activation_function, + groups=math.gcd(self.n_stat_aux_chans, self.stat_aux_embed_dim), + bias=encoder_bias, use_mlp=encoder_mlp, ) + # decoder for the atmospheric variables self.atmo_decoder = DiscreteContinuousDecoder( inp_shape=(self.h, self.w), @@ -519,7 +604,7 @@ def __init__( basis_norm_mode=filter_basis_norm_mode, activation_function=activation_function, groups=math.gcd(self.n_atmo_chans, self.atmo_embed_dim), - bias=bias, + bias=encoder_bias, use_mlp=encoder_mlp, upsample_sht=upsample_sht, ) @@ -538,35 +623,21 @@ def __init__( basis_norm_mode=filter_basis_norm_mode, activation_function=activation_function, groups=math.gcd(self.n_surf_chans, self.surf_embed_dim), - bias=bias, + bias=encoder_bias, use_mlp=encoder_mlp, 
upsample_sht=upsample_sht, ) - # encoder for the auxiliary channels - if self.n_aux_chans > 0: - self.aux_encoder = DiscreteContinuousEncoder( - inp_shape=inp_shape, - out_shape=(self.h, self.w), - inp_chans=self.n_aux_chans, - out_chans=self.aux_embed_dim, - grid_in=model_grid_type, - grid_out=sht_grid_type, - kernel_shape=kernel_shape, - basis_type=filter_basis_type, - basis_norm_mode=filter_basis_norm_mode, - activation_function=activation_function, - groups=math.gcd(self.n_aux_chans, self.aux_embed_dim), - bias=bias, - use_mlp=encoder_mlp, - ) + # position embedding + if self.pos_embed_dim > 0: + self.pos_embed = LearnablePositionEmbedding(img_shape=(self.h, self.w), grid=sht_grid_type, num_chans=self.pos_embed_dim, embed_type="lat") # dropout self.pos_drop = nn.Dropout(p=pos_drop_rate) if pos_drop_rate > 0.0 else nn.Identity() dpr = [x.item() for x in torch.linspace(0, path_drop_rate, num_layers)] # get the handle for the normalization layer - norm_layer = self._get_norm_layer_handle(self.h, self.w, self.total_embed_dim, normalization_layer=normalization_layer, sht_grid_type=sht_grid_type) + norm_layer = self._get_norm_layer_handle(self.h, self.w, self.total_embed_dim + self.total_aux_embed_dim, normalization_layer=normalization_layer, sht_grid_type=sht_grid_type) # Internal NO blocks self.blocks = nn.ModuleList([]) @@ -583,7 +654,7 @@ def __init__( block = NeuralOperatorBlock( self.sht, self.isht, - self.total_embed_dim + (self.n_aux_chans > 0) * self.aux_embed_dim, + self.total_embed_dim + self.total_aux_embed_dim, self.total_embed_dim, conv_type=conv_type, mlp_ratio=mlp_ratio, @@ -603,17 +674,6 @@ def __init__( self.blocks.append(block) - # residual prediction - if self.big_skip: - self.residual_transform = nn.Conv2d(self.n_out_chans, self.n_out_chans, 1, bias=False) - self.residual_transform.weight.is_shared_mp = ["spatial"] - self.residual_transform.weight.sharded_dims_mp = [None, None, None, None] - if self.residual_transform.bias is not None: - self.residual_transform.bias.is_shared_mp = ["spatial"] - self.residual_transform.bias.sharded_dims_mp = [None] - scale = math.sqrt(0.5 / self.n_out_chans) - nn.init.normal_(self.residual_transform.weight, mean=0.0, std=scale) - # controlled output normalization of q and tcwv if clamp_water: water_chans = get_water_channels(channel_names) @@ -726,27 +786,57 @@ def _precompute_channel_groups( self, channel_names=[], aux_channel_names=[], + n_history=0, ): """ group the channels appropriately into atmospheric pressure levels and surface variables """ - atmo_chans, surf_chans, aux_chans, pressure_lvls = get_channel_groups(channel_names, aux_channel_names) + atmo_chans, surf_chans, dyn_aux_chans, stat_aux_chans, pressure_lvls = get_channel_groups(channel_names, aux_channel_names) + sst_chans = [channel_names.index("sst")] if "sst" in channel_names else [] + lsml_chans = [len(channel_names) + aux_channel_names.index("xlsml")] if "xlsml" in aux_channel_names else [] # compute how many channel groups will be kept internally self.n_atmo_groups = len(pressure_lvls) self.n_atmo_chans = len(atmo_chans) // self.n_atmo_groups + self.n_surf_chans = len(surf_chans) + self.n_dyn_aux_chans = len(dyn_aux_chans) + self.n_stat_aux_chans= len(stat_aux_chans) + self.n_aux_chans = self.n_dyn_aux_chans + self.n_stat_aux_chans # make sure they are divisible. Attention! 
This does not guarantee that the grouping is correct if len(atmo_chans) % self.n_atmo_groups: raise ValueError(f"Expected number of atmospheric variables to be divisible by number of atmospheric groups but got {len(atmo_chans)} and {self.n_atmo_groups}") - self.register_buffer("atmo_channels", torch.LongTensor(atmo_chans), persistent=False) - self.register_buffer("surf_channels", torch.LongTensor(surf_chans), persistent=False) - self.register_buffer("aux_channels", torch.LongTensor(aux_chans), persistent=False) - - self.n_surf_chans = self.surf_channels.shape[0] - self.n_aux_chans = self.aux_channels.shape[0] + # if history is included, adapt the channel lists to include the offsets + self.n_atmo_groups = self.n_atmo_groups + n_dyn_chans = len(atmo_chans) + len(surf_chans) + len(dyn_aux_chans) + atmo_chans_in = atmo_chans.copy() + surf_chans_in = surf_chans.copy() + sst_chans_in = sst_chans.copy() + # keep a snapshot of the dynamic aux channels so the history offsets are computed from the original list + dyn_aux_chans_base = dyn_aux_chans.copy() + for ih in range(1, n_history+1): + atmo_chans_in += [(c + ih*n_dyn_chans) for c in atmo_chans] + surf_chans_in += [(c + ih*n_dyn_chans) for c in surf_chans] + sst_chans_in += [(c + ih*n_dyn_chans) for c in sst_chans] + dyn_aux_chans += [(c + ih*n_dyn_chans) for c in dyn_aux_chans_base] + # account for the history offset in the static aux channels + stat_aux_chans = [c + n_history*n_dyn_chans for c in stat_aux_chans] + lsml_chans = [c + n_history*n_dyn_chans for c in lsml_chans] + + self.register_buffer("atmo_channels_in", torch.LongTensor(atmo_chans_in), persistent=False) + self.register_buffer("atmo_channels_out", torch.LongTensor(atmo_chans), persistent=False) + self.register_buffer("surf_channels_in", torch.LongTensor(surf_chans_in), persistent=False) + self.register_buffer("surf_channels_out", torch.LongTensor(surf_chans), persistent=False) + self.register_buffer("sst_channels_in", torch.LongTensor(sst_chans_in), persistent=False) + self.register_buffer("sst_channels_out", torch.LongTensor(sst_chans), persistent=False) + self.register_buffer("dyn_aux_channels", torch.LongTensor(dyn_aux_chans), persistent=False) + self.register_buffer("stat_aux_channels", torch.LongTensor(stat_aux_chans), persistent=False) + self.register_buffer("land_mask_channels", torch.LongTensor(lsml_chans), persistent=False) + self.register_buffer("pred_channels", torch.LongTensor(surf_chans + atmo_chans), persistent=False) + + # print(f"in atmo: {self.atmo_channels_in}") + # print(f"out atmo: {self.atmo_channels_out}") + # print(f"q50: {channel_names.index('q50')}") return @@ -756,13 +846,24 @@ def encode(self, x): """ batchdims = x.shape[:-3] - # for atmospheric channels the same encoder is applied to each atmospheric level - x_atmo = x[..., self.atmo_channels, :, :].contiguous().reshape(-1, self.n_atmo_chans, *x.shape[-2:]) + if hasattr(self, "sst_imputation"): + if self.land_mask_channels.nelement() > 0: + mask = x[..., self.land_mask_channels, :, :] + else: + mask = None + x[..., self.sst_channels_in, :, :] = self.sst_imputation(x[..., self.sst_channels_in, :, :], mask=mask) + + # for atmospheric channels the same encoder is applied to each atmospheric level and takes the entire history into account + x_atmo = x[..., self.atmo_channels_in, :, :].reshape(-1, 1 + self.n_history, self.n_atmo_groups, self.n_atmo_chans, *x.shape[-2:]) + # move the history backwards and fold it into the channel dimension + x_atmo = x_atmo.permute(0,2,3,1,4,5).reshape(-1, self.n_atmo_chans * (1 + self.n_history), *x.shape[-2:]).contiguous() x_out = self.atmo_encoder(x_atmo) x_out = x_out.reshape(*batchdims, self.n_atmo_groups *
self.atmo_embed_dim, *x_out.shape[-2:]) if hasattr(self, "surf_encoder"): - x_surf = x[..., self.surf_channels, :, :].contiguous() + x_surf = x[..., self.surf_channels_in, :, :].reshape(-1, 1 + self.n_history, self.n_surf_chans, *x.shape[-2:]) + # move the history backwards and fold it into the channel dimension + x_surf = x_surf.transpose(-3,-4).reshape(-1, self.n_surf_chans * (1 + self.n_history), *x.shape[-2:]).contiguous() x_surf = self.surf_encoder(x_surf) x_out = torch.cat((x_out, x_surf), dim=-3) @@ -776,10 +877,27 @@ def encode_auxiliary_channels(self, x): """ batchdims = x.shape[:-3] - if hasattr(self, "aux_encoder"): - x_aux = x[..., self.aux_channels, :, :] - x_aux = self.aux_encoder(x_aux) - x_aux = x_aux.reshape(*batchdims, self.aux_embed_dim, *x_aux.shape[-2:]) + aux_tensors = [] + + if hasattr(self, "dyn_aux_encoder"): + x_aux = x[..., self.dyn_aux_channels, :, :].reshape(-1, 1 + self.n_history, self.n_dyn_aux_chans, *x.shape[-2:]) + x_aux = x_aux.transpose(-3,-4).reshape(-1, self.n_dyn_aux_chans * (1 + self.n_history), *x.shape[-2:]).contiguous() + x_aux = self.dyn_aux_encoder(x_aux) + x_aux = x_aux.reshape(*batchdims, self.dyn_aux_embed_dim, *x_aux.shape[-2:]) + aux_tensors.append(x_aux) + + if hasattr(self, "stat_aux_encoder"): + x_aux = x[..., self.stat_aux_channels, :, :].contiguous() + x_aux = self.stat_aux_encoder(x_aux) + x_aux = x_aux.reshape(*batchdims, self.stat_aux_embed_dim, *x_aux.shape[-2:]) + aux_tensors.append(x_aux) + + if hasattr(self, "pos_embed"): + x_pos = self.pos_embed() + aux_tensors.append(x_pos) + + if len(aux_tensors) > 0: + x_aux = torch.cat(aux_tensors, dim=-3) else: x_aux = None @@ -795,12 +913,13 @@ def decode(self, x): x_atmo = x[..., : (self.n_atmo_groups * self.atmo_embed_dim), :, :].reshape(-1, self.atmo_embed_dim, *x.shape[-2:]) x_atmo = self.atmo_decoder(x_atmo) x_out = torch.zeros(*batchdims, self.n_out_chans, *x_atmo.shape[-2:], dtype=x.dtype, device=x.device) - x_out[..., self.atmo_channels, :, :] = x_atmo.reshape(*batchdims, -1, *x_atmo.shape[-2:]) + x_out[..., self.atmo_channels_out, :, :] = x_atmo.reshape(*batchdims, -1, *x_atmo.shape[-2:]) + if hasattr(self, "surf_decoder"): x_surf = x[..., -self.surf_embed_dim :, :, :] x_surf = self.surf_decoder(x_surf) - x_out[..., self.surf_channels, :, :] = x_surf.reshape(*batchdims, -1, *x_surf.shape[-2:]) + x_out[..., self.surf_channels_out, :, :] = x_surf.reshape(*batchdims, -1, *x_surf.shape[-2:]) return x_out @@ -836,7 +955,7 @@ def forward(self, x): # save big skip if self.big_skip: - residual = x[..., : self.n_out_chans, :, :].contiguous() + residual = x[..., self.pred_channels, :, :].contiguous() # extract embeddings for the auxiliary embeddings x_aux = self.encode_auxiliary_channels(x) @@ -850,6 +969,13 @@ def forward(self, x): # run the processor x = self.processor_blocks(x, x_aux) + # for debugging print the activations + atmo_activations = x[..., : (self.n_atmo_groups * self.atmo_embed_dim), :, :].reshape(-1, self.n_atmo_groups, self.atmo_embed_dim, *x.shape[-2:]) + s, m = torch.std_mean(atmo_activations, dim=(0, -1, -2)) + print(f"group 0, stds: {s[0]} means: {m[0]}") + print(f"group -1, stds: {s[-1]} means: {m[-1]}") + + # run the decoder if self.checkpointing_level >= 1: x = checkpoint(self.decode, x, use_reentrant=False) @@ -857,7 +983,7 @@ def forward(self, x): x = self.decode(x) if self.big_skip: - x = x + self.residual_transform(residual) + x[..., self.pred_channels, :, :] = x + residual # apply output transform x = self.clamp_water_channels(x) diff --git 
a/makani/models/networks/pangu.py b/makani/models/networks/pangu.py index b5b8e55..ad6b352 100644 --- a/makani/models/networks/pangu.py +++ b/makani/models/networks/pangu.py @@ -452,7 +452,7 @@ def forward(self, x: torch.Tensor, mask=None): x: input features with shape of (B * num_lon, num_pl*num_lat, N, C) mask: (0/-inf) mask with shape of (num_lon, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon) """ - + B_, nW_, N, C = x.shape qkv = ( self.qkv(x) @@ -478,7 +478,7 @@ def forward(self, x: torch.Tensor, mask=None): attn = self.attn_drop_fn(attn) x = self.apply_attention(attn, v, B_, nW_, N, C) - + else: if mask is not None: bias = mask.unsqueeze(1).unsqueeze(0) + earth_position_bias.unsqueeze(0).unsqueeze(0) @@ -486,10 +486,10 @@ def forward(self, x: torch.Tensor, mask=None): #bias = bias.squeeze(2) else: bias = earth_position_bias.unsqueeze(0) - + # extract batch size for q,k,v nLon = self.num_lon - q = q.view(B_ // nLon, nLon, q.shape[1], q.shape[2], q.shape[3], q.shape[4]) + q = q.view(B_ // nLon, nLon, q.shape[1], q.shape[2], q.shape[3], q.shape[4]) k = k.view(B_ // nLon, nLon, k.shape[1], k.shape[2], k.shape[3], k.shape[4]) v = v.view(B_ // nLon, nLon, v.shape[1], v.shape[2], v.shape[3], v.shape[4]) #### @@ -736,7 +736,7 @@ class Pangu(nn.Module): - https://arxiv.org/abs/2211.02556 """ - def __init__(self, + def __init__(self, inp_shape=(721,1440), out_shape=(721,1440), grid_in="equiangular", @@ -773,14 +773,14 @@ def __init__(self, self.checkpointing_level = checkpointing_level drop_path = np.linspace(0, drop_path_rate, 8).tolist() - + # Add static channels to surface self.num_aux = len(self.aux_channel_names) N_total_surface = self.num_aux + self.num_surface # compute static permutations to extract self._precompute_channel_groups(self.channel_names, self.aux_channel_names) - + # Patch embeddings are 2D or 3D convolutions, mapping the data to the required patches self.patchembed2d = PatchEmbed2D( img_size=self.inp_shape, @@ -791,7 +791,7 @@ def __init__(self, flatten=False, norm_layer=None, ) - + self.patchembed3d = PatchEmbed3D( img_size=(num_levels, self.inp_shape[0], self.inp_shape[1]), patch_size=patch_size, @@ -870,7 +870,7 @@ def __init__(self, self.patchrecovery3d = PatchRecovery3D( (num_levels, self.inp_shape[0], self.inp_shape[1]), patch_size, 2 * embed_dim, num_atmospheric ) - + def _precompute_channel_groups( self, channel_names=[], @@ -901,7 +901,7 @@ def _precompute_channel_groups( def prepare_input(self, input): """ - Prepares the input tensor for the Pangu model by splitting it into surface * static variables and atmospheric, + Prepares the input tensor for the Pangu model by splitting it into surface * static variables and atmospheric, and reshaping the atmospheric variables into the required format. """ @@ -932,23 +932,23 @@ def prepare_output(self, output_surface, output_atmospheric): level_dict = {level: [idx for idx, value in enumerate(self.channel_names) if value[1:] == level] for level in levels} reordered_ids = [idx for level in levels for idx in level_dict[level]] check_reorder = [f'{level}_{idx}' for level in levels for idx in level_dict[level]] - + # Flatten & reorder the output atmospheric to original order (doublechecked that this is working correctly!) 
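# --- Editor's note: illustrative toy example, not part of the patch. It shows how the level_dict /
# --- reordered_ids indices in Pangu.prepare_output map the level-grouped model output back to the
# --- original channel order; the channel names and levels below are hypothetical.
channel_names = ["u10m", "t2m", "u500", "t500", "u850", "t850"]
levels = ["500", "850"]
level_dict = {lvl: [i for i, name in enumerate(channel_names) if name[1:] == lvl] for lvl in levels}
reordered_ids = [i for lvl in levels for i in level_dict[lvl]]
# level_dict == {"500": [2, 3], "850": [4, 5]}, so reordered_ids == [2, 3, 4, 5]: the i-th flattened
# atmospheric output channel is scattered back to position reordered_ids[i] of the full channel list.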
flattened_atmospheric = output_atmospheric.reshape(output_atmospheric.shape[0], -1, output_atmospheric.shape[3], output_atmospheric.shape[4]) reordered_atmospheric = torch.cat([torch.zeros_like(output_surface), torch.zeros_like(flattened_atmospheric)], dim=1) for i in range(len(reordered_ids)): reordered_atmospheric[:, reordered_ids[i], :, :] = flattened_atmospheric[:, i, :, :] - + # Append the surface output, this has not been reordered. if output_surface is not None: - _, surf_chans, _, _ = features.get_channel_groups(self.channel_names, self.aux_channel_names) + _, surf_chans, _, _, _ = features.get_channel_groups(self.channel_names, self.aux_channel_names) reordered_atmospheric[:, surf_chans, :, :] = output_surface output = reordered_atmospheric else: output = reordered_atmospheric return output - + def forward(self, input): # Prep the input by splitting into surface and atmospheric variables @@ -959,7 +959,7 @@ def forward(self, input): surface = checkpoint(self.patchembed2d, surface_aux, use_reentrant=False) atmospheric = checkpoint(self.patchembed3d, atmospheric, use_reentrant=False) else: - surface = self.patchembed2d(surface_aux) + surface = self.patchembed2d(surface_aux) atmospheric = self.patchembed3d(atmospheric) if surface.shape[1] == 0: @@ -1011,11 +1011,5 @@ def forward(self, input): output_atmospheric = self.patchrecovery3d(output_atmospheric) output = self.prepare_output(output_surface, output_atmospheric) - - return output - - - - - + return output diff --git a/makani/models/networks/pangu_onnx.py b/makani/models/networks/pangu_onnx.py index 0805bad..bf3a006 100644 --- a/makani/models/networks/pangu_onnx.py +++ b/makani/models/networks/pangu_onnx.py @@ -38,7 +38,7 @@ class PanguOnnx(OnnxWrapper): channel_order_PL: List containing the names of the pressure levels with the ordering that the ONNX model expects onnx_file: Path to the ONNX file containing the model ''' - def __init__(self, + def __init__(self, channel_names=[], aux_channel_names=[], onnx_file=None, @@ -58,7 +58,7 @@ def _precompute_channel_groups( group the channels appropriately into atmospheric pressure levels and surface variables """ - atmo_chans, surf_chans, _, pressure_lvls = get_channel_groups(channel_names, aux_channel_names) + atmo_chans, surf_chans, _, _, pressure_lvls = get_channel_groups(channel_names, aux_channel_names) # compute how many channel groups will be kept internally self.n_atmo_groups = len(pressure_lvls) @@ -78,12 +78,12 @@ def prepare_input(self, input): B,V,Lat,Long=input.shape if B>1: - raise NotImplementedError("Not implemented yet for batch size greater than 1") + raise NotImplementedError("Not implemented yet for batch size greater than 1") input=input.squeeze(0) surface_aux_inp=input[self.surf_channels] atmospheric_inp=input[self.atmo_channels].reshape(self.n_atmo_groups,self.n_atmo_chans,Lat,Long).transpose(1,0) - + return surface_aux_inp, atmospheric_inp def prepare_output(self, output_surface, output_atmospheric): @@ -99,9 +99,9 @@ def prepare_output(self, output_surface, output_atmospheric): return output.unsqueeze(0) - + def forward(self, input): - + surface, atmospheric = self.prepare_input(input) @@ -109,5 +109,5 @@ def forward(self, input): output = self.prepare_output(output_surface, output) - + return output diff --git a/makani/models/noise.py b/makani/models/noise.py index 56e3232..5f49daa 100644 --- a/makani/models/noise.py +++ b/makani/models/noise.py @@ -100,7 +100,7 @@ def reset(self, batch_size=None): # this routine generates a noise sample for a single time step 
and updates the state accordingly, by appending the last time step def update(self, replace_state=False, batch_size=None): - # Update should always create a new state, so + # Update should always create a new state, so # we don't need to check for replace_state # create single occurence with torch.no_grad(): @@ -193,21 +193,28 @@ def __init__( alpha = float(alpha) # Compute ls, angular power spectrum and sigma_l: - ls = torch.arange(self.lmax) + ls = torch.arange(self.lmax).reshape(-1 ,1) + ms = torch.arange(self.mmax) power_spectrum = torch.pow(2 * ls + 1, -alpha) norm_factor = torch.sum((2 * ls + 1) * power_spectrum / 4.0 / math.pi) sigma_l = sigma * torch.sqrt(power_spectrum / norm_factor) + sigma_l = torch.where(ms <= ls, sigma_l, 0.0) # the new shape is B, T, C, L, M - sigma_l = sigma_l.reshape((1, 1, 1, self.lmax, 1)).to(dtype=torch.float32) + sigma_l = sigma_l.reshape((1, 1, 1, self.lmax, self.mmax)).to(dtype=torch.float32) # split tensor if comm.get_size("h") > 1: sigma_l = split_tensor_along_dim(sigma_l, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")] + # split tensor + if comm.get_size("w") > 1: + sigma_l = split_tensor_along_dim(sigma_l, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")] + # register buffer if learnable: self.register_parameter("sigma_l", nn.Parameter(sigma_l)) + self.sigma_l.sharded_dims_mp = [None, None, None, "h", "w"] else: self.register_buffer("sigma_l", sigma_l, persistent=False) diff --git a/makani/models/preprocessor.py b/makani/models/preprocessor.py index 59dbc2f..071de92 100644 --- a/makani/models/preprocessor.py +++ b/makani/models/preprocessor.py @@ -53,13 +53,7 @@ def __init__(self, params): self.history_eps = 1e-6 # residual normalization - self.learn_residual = params.target == "residual" - if self.learn_residual and (params.normalize_residual): - with torch.no_grad(): - residual_scale = torch.as_tensor(np.load(params.time_diff_stds_path)).to(torch.float32) - self.register_buffer("residual_scale", residual_scale, persistent=False) - else: - self.residual_scale = None + self.residual_scale = None # image shape self.img_shape = [params.img_shape_x, params.img_shape_y] @@ -178,20 +172,6 @@ def expand_history(self, x, nhist): x = torch.reshape(x, (b_, nhist, ct_ // nhist, h_, w_)) return x - def add_residual(self, x, dx): - if self.learn_residual: - if self.residual_scale is not None: - dx = dx * self.residual_scale - - # add residual: deal with history - x = self.expand_history(x, nhist=self.n_history + 1) - x[:, -1, ...] = x[:, -1, ...] 
+ dx - x = self.flatten_history(x) - else: - x = dx - - return x - def add_static_features(self, x): if self.do_add_static_features: # we need to replicate the grid for each batch: diff --git a/makani/models/stepper.py b/makani/models/stepper.py index f04590b..f7edfea 100644 --- a/makani/models/stepper.py +++ b/makani/models/stepper.py @@ -49,9 +49,6 @@ def forward(self, inp, update_state=True, replace_state=True): # undo normalization y = self.preprocessor.history_denormalize(yn, target=True) - # add residual (for residual learning, no-op for direct learning - y = self.preprocessor.add_residual(inp, y) - return y @@ -60,7 +57,6 @@ def __init__(self, params, model_handle): super().__init__() self.preprocessor = Preprocessor2D(params) self.model = model_handle() - self.residual_mode = True if (params.target == "target") else False self.push_forward_mode = params.get("multistep_push_forward", False) # collect parameters for history @@ -102,9 +98,6 @@ def _forward_train(self, inp, update_state=True, replace_state=True): # will have been updated later: pred = self.preprocessor.history_denormalize(predn, target=True) - # add residual (for residual learning, no-op for direct learning - pred = self.preprocessor.add_residual(inpt, pred) - # append output result.append(pred) @@ -148,15 +141,12 @@ def _forward_eval(self, inp, update_state=True, replace_state=True): # because otherwise normalization stats are already outdated y = self.preprocessor.history_denormalize(yn, target=True) - # add residual (for residual learning, no-op for direct learning - y = self.preprocessor.add_residual(inp, y) - return y def forward(self, inp, update_state=True, replace_state=True): # decide which routine to call if self.training: - y = self._forward_train(inp, update_state=True, replace_state=replace_state) + y = self._forward_train(inp, update_state=update_state, replace_state=replace_state) else: y = self._forward_eval(inp, update_state=update_state, replace_state=replace_state) diff --git a/makani/mpu/layers.py b/makani/mpu/layers.py index 18f470b..dd67a96 100644 --- a/makani/mpu/layers.py +++ b/makani/mpu/layers.py @@ -266,6 +266,134 @@ def forward(self, x): else: return self.fwd(x) +# Stochastic MLP needs comm datastructure +class StochasticMLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + output_bias=True, + input_format="nchw", + drop_rate=0.0, + drop_type="iid", + checkpointing=False, + gain=1.0, + seed=333, + **kwargs, + ): + super().__init__() + + self.checkpointing = checkpointing + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + # generator objects: + self.set_rng(seed=seed) + + # First fully connected layer + if input_format == "nchw": + self.fc1_weight_std = nn.Parameter(torch.zeros(hidden_features, in_features, 1, 1)) + self.fc1_weight_mean = nn.Parameter(torch.zeros(hidden_features, in_features, 1, 1)) + self.fc1_bias = nn.Parameter(torch.zeros(hidden_features)) + else: + raise NotImplementedError(f"Error, input format {input_format} not supported.") + + # sharing settings + self.fc1_weight_std.is_shared_mp = ["spatial"] + self.fc1_weight_mean.is_shared_mp = ["spatial"] + self.fc1_bias.is_shared_mp = ["spatial"] + + # initialize the weights correctly + scale = math.sqrt(1.0 / in_features) + nn.init.normal_(self.fc1_weight_std, mean=0.0, std=scale) + nn.init.normal_(self.fc1_weight_mean, mean=0.0, std=scale) + + # activation + self.act = act_layer() + + # sanity checks + if 
(input_format == "traditional") and (drop_type == "features"): + raise NotImplementedError(f"Error, traditional input format and feature dropout cannot be selected simultaneously") + + # output layer + if input_format == "nchw": + self.fc2_weight_std = nn.Parameter(torch.zeros(out_features, hidden_features, 1, 1)) + self.fc2_weight_mean = nn.Parameter(torch.zeros(out_features, hidden_features, 1, 1)) + self.fc2_bias = nn.Parameter(torch.zeros(out_features)) if output_bias else None + else: + raise NotImplementedError(f"Error, input format {input_format} not supported.") + + # sharing settings + self.fc2_weight_std.is_shared_mp = ["spatial"] + self.fc2_weight_mean.is_shared_mp = ["spatial"] + if self.fc2_bias is not None: + self.fc2_bias.is_shared_mp = ["spatial"] + + # gain factor for the output determines the scaling of the output init + scale = math.sqrt(gain / hidden_features / 2) + nn.init.normal_(self.fc2_weight_std, mean=0.0, std=scale) + nn.init.normal_(self.fc2_weight_mean, mean=0.0, std=scale) + if self.fc2_bias is not None: + nn.init.constant_(self.fc2_bias, 0.0) + + if drop_rate > 0.0: + if drop_type == "iid": + self.drop = nn.Dropout(drop_rate) + elif drop_type == "features": + self.drop = nn.Dropout2d(drop_rate) + else: + raise NotImplementedError(f"Error, drop_type {drop_type} not supported") + else: + self.drop = nn.Identity() + + @torch.compiler.disable(recursive=False) + def set_rng(self, seed=333): + self.rng_cpu = torch.Generator(device=torch.device("cpu")) + self.rng_cpu.manual_seed(seed) + if torch.cuda.is_available(): + self.rng_gpu = torch.Generator(device=torch.device(f"cuda:{comm.get_local_rank()}")) + self.rng_gpu.manual_seed(seed) + + @torch.compiler.disable(recursive=False) + def checkpoint_forward(self, x): + return checkpoint(self.fwd, x, use_reentrant=False) + + def fwd(self, x): + + # generate weight1 + weight1 = torch.empty_like(self.fc1_weight_mean) + weight1.normal_(mean=0.0, std=1.0, generator=self.rng_gpu if weight1.is_cuda else self.rng_cpu) + weight1 = self.fc1_weight_std * weight1 + self.fc1_weight_mean + + # fully connected 1 + x = nn.functional.conv2d(x, weight1, bias=self.fc1_bias) + + # activation + x = self.act(x) + + # dropout + x = self.drop(x) + + # generate weight1 + weight2 = torch.empty_like(self.fc2_weight_mean) + weight2.normal_(mean=0.0, std=1.0, generator=self.rng_gpu if weight2.is_cuda else self.rng_cpu) + weight2 = self.fc2_weight_std * weight2 + self.fc2_weight_mean + + # fully connected 2 + x = nn.functional.conv2d(x, weight2, bias=self.fc2_bias) + + # dropout + x = self.drop(x) + + return x + + def forward(self, x): + if self.checkpointing: + return self.checkpoint_forward(x) + else: + return self.fwd(x) class DistributedPatchEmbed(nn.Module): def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768, input_is_matmul_parallel=False, output_is_matmul_parallel=True): diff --git a/makani/utils/dataloaders/dali_es_helper_2d.py b/makani/utils/dataloaders/dali_es_helper_2d.py index 5a8b252..bd6627d 100644 --- a/makani/utils/dataloaders/dali_es_helper_2d.py +++ b/makani/utils/dataloaders/dali_es_helper_2d.py @@ -183,7 +183,7 @@ def _generate_indexlist(self, timestamp_boundary_list): if timestamp_boundary_list: #compute list of allowed timestamps timestamp_boundary_list = [get_date_from_string(timestamp_string) for timestamp_string in timestamp_boundary_list] - + # now, based on dt, dh, n_history and n_future, we can build regions where no data is allowed timestamp_exclusion_list = 
get_date_ranges(timestamp_boundary_list, lookback_hours = dt_total * (self.n_future + 1), lookahead_hours = dt_total * self.n_history) @@ -521,7 +521,7 @@ def _compute_zenith_angle(self, inp_times, tar_times): # nvtx range torch.cuda.nvtx.range_pop() - return cos_zenith_inp, cos_zenith_tar + return cos_zenith_inp, cos_zenith_tar def __getstate__(self): del self.aws_connector diff --git a/makani/utils/dataloaders/dali_es_helper_concat_2d.py b/makani/utils/dataloaders/dali_es_helper_concat_2d.py index 21f8b6f..c2e30a8 100644 --- a/makani/utils/dataloaders/dali_es_helper_concat_2d.py +++ b/makani/utils/dataloaders/dali_es_helper_concat_2d.py @@ -159,7 +159,7 @@ def _generate_indexlist(self, timestamp_boundary_list): if timestamp_boundary_list: #compute list of allowed timestamps timestamp_boundary_list = [get_date_from_string(timestamp_string) for timestamp_string in timestamp_boundary_list] - + # now, based on dt, dh, n_history and n_future, we can build regions where no data is allowed timestamp_exclusion_list = get_date_ranges(timestamp_boundary_list, lookback_hours = dt_total * (self.n_future + 1), lookahead_hours = dt_total * self.n_history) diff --git a/makani/utils/dataloaders/data_helpers.py b/makani/utils/dataloaders/data_helpers.py index 753e398..c4d2fe0 100644 --- a/makani/utils/dataloaders/data_helpers.py +++ b/makani/utils/dataloaders/data_helpers.py @@ -58,6 +58,36 @@ def get_data_normalization(params): return bias, scale +def get_time_diff_stds(params): + + time_diff_stds = None + + if hasattr(params, "time_diff_stds_path"): + time_diff_stds = np.load(params.time_diff_stds_path) + else: + raise ValueError(f"time_diff_std_path not defined.") + + return time_diff_stds + + +def get_psd_stats(params): + + psd_means = None + psd_stds = None + + if hasattr(params, "psd_means_path") and hasattr(params, "psd_stds_path"): + psd_means = np.load(params.psd_means_path) + psd_stds = np.load(params.psd_stds_path) + + # filter channels if requested + if hasattr(params, "out_channels"): + psd_means = psd_means[..., params.out_channels, :] + psd_stds = psd_stds[..., params.out_channels, :] + else: + raise ValueError(f"psd_means_path or psd_stds_path not defined.") + + return psd_means, psd_stds + def get_climatology(params): """ diff --git a/makani/utils/driver.py b/makani/utils/driver.py index 6e507c8..766df5a 100644 --- a/makani/utils/driver.py +++ b/makani/utils/driver.py @@ -632,11 +632,11 @@ def get_optimizer(self, model, params): if params.optimizer_type == "Adam": if self.log_to_screen: self.logger.info("using Adam optimizer") - optimizer = optim.Adam(all_parameters, betas=betas, lr=params.get("lr", 1e-3), weight_decay=params.get("weight_decay", 0), foreach=True) + optimizer = optim.Adam(all_parameters, lr=params.get("lr", 1e-3), betas=betas, eps=params.get("optimizer_eps", 1e-8), weight_decay=params.get("weight_decay", 0), foreach=True) elif params.optimizer_type == "AdamW": if self.log_to_screen: self.logger.info("using AdamW optimizer") - optimizer = optim.AdamW(all_parameters, betas=betas, lr=params.get("lr", 1e-3), weight_decay=params.get("weight_decay", 0), foreach=True) + optimizer = optim.AdamW(all_parameters, lr=params.get("lr", 1e-3), betas=betas, eps=params.get("optimizer_eps", 1e-8), weight_decay=params.get("weight_decay", 0), foreach=True) elif params.optimizer_type == "SGD": if self.log_to_screen: self.logger.info("using SGD optimizer") diff --git a/makani/utils/features.py b/makani/utils/features.py index 9b17750..cab61c3 100644 --- a/makani/utils/features.py +++ 
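# --- Illustrative sketch (not part of the diff) -------------------------------
# get_psd_stats() above assumes the .npy files put the channel axis second to
# last, so params.out_channels can slice them directly. A small stand-in for
# that contract (array shapes and the channel subset are made-up values):
import numpy as np

n_channels, n_wavenumbers = 4, 181                 # assumed shapes
psd_means = np.zeros((n_channels, n_wavenumbers))  # stand-in for np.load(params.psd_means_path)
psd_stds = np.ones((n_channels, n_wavenumbers))    # stand-in for np.load(params.psd_stds_path)

out_channels = [0, 2]                              # hypothetical predicted-channel subset
psd_means = psd_means[..., out_channels, :]        # -> (2, 181)
psd_stds = psd_stds[..., out_channels, :]
# ------------------------------------------------------------------------------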
b/makani/utils/features.py @@ -88,7 +88,7 @@ def get_wind_channels(channel_names): wind_chans = [] for c, ch in enumerate(channel_names): - if ch[0] == "u": + if ch[0] == "u" and ("v" + ch[1:]) in channel_names: vc = channel_names.index("v" + ch[1:]) wind_chans = wind_chans + [c, vc] @@ -97,13 +97,15 @@ def get_wind_channels(channel_names): def get_channel_groups(channel_names, aux_channel_names=[]): """ - Helper routine to extract indices of atmospheric, surface and auxiliary variables and group them into their respective groups + Helper routine to extract indices of atmospheric, surface and auxiliary variables and group them into their respective groups. + The resulting numbering does NOT respect history. """ atmo_groups = OrderedDict() atmo_chans = [] surf_chans = [] - aux_chans = [] + dyn_aux_chans = [] + stat_aux_chans = [] # parse channel names and group variables by pressure level/surface variables for idx, chn in enumerate(channel_names): @@ -127,6 +129,10 @@ def get_channel_groups(channel_names, aux_channel_names=[]): atmo_chans += idx # append the auxiliary variable to the surface channels - aux_chans = [idx + len(channel_names) for idx in range(len(aux_channel_names))] + for idx, chn in enumerate(aux_channel_names): + if chn in ["xoro", "xlsml", "xlsms"]: + stat_aux_chans.append(idx + len(channel_names)) + else: + dyn_aux_chans.append(idx + len(channel_names)) - return atmo_chans, surf_chans, aux_chans, atmo_groups.keys() + return atmo_chans, surf_chans, dyn_aux_chans, stat_aux_chans, atmo_groups.keys() diff --git a/makani/utils/grids.py b/makani/utils/grids.py index 5208529..b3ff154 100644 --- a/makani/utils/grids.py +++ b/makani/utils/grids.py @@ -16,10 +16,12 @@ import numpy as np import torch -from torch_harmonics.quadrature import legendre_gauss_weights, clenshaw_curtiss_weights +import torch.amp as amp + +from torch_harmonics.quadrature import legendre_gauss_weights, clenshaw_curtiss_weights, _precompute_latitudes from makani.utils import comm -from physicsnemo.distributed.utils import compute_split_shapes +from physicsnemo.distributed.utils import compute_split_shapes, split_tensor_along_dim from physicsnemo.distributed.mappings import reduce_from_parallel_region @@ -33,9 +35,23 @@ def grid_to_quadrature_rule(grid_type): return grid_to_quad_dict[grid_type] +def compute_spherical_bandlimit(img_shape, grid_type): + + if grid_type == "equiangular": + lmax = (img_shape[0] - 1) // 2 + mmax = img_shape[1] // 2 + return min(lmax, mmax) + elif grid_type == "legendre-gauss": + lmax = img_shape[0] - 1 + mmax = img_shape[1] // 2 + return min(lmax, mmax) + else: + raise NotImplementedError(f"Unknown type {grid_type} not implemented") + + class GridConverter(torch.nn.Module): def __init__(self, src_grid, dst_grid, lat_rad, lon_rad): - super(GridConverter, self).__init__() + super().__init__() self.src = src_grid self.dst = dst_grid self.src_lat = lat_rad @@ -123,7 +139,7 @@ def __init__(self, quadrature_rule, img_shape, crop_shape=None, crop_offset=(0, # apply pole mask if (pole_mask is not None) and (pole_mask > 0): quad_weight[:pole_mask, :] = 0.0 - quad_weight[sizes[0] - pole_mask :, :] = 0.0 + quad_weight[img_shape[0] - pole_mask :, :] = 0.0 # if distributed, make sure to split correctly across ranks: # in case of model parallelism, we need to make sure that we use the correct shapes per rank @@ -165,3 +181,70 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: quad = reduce_from_parallel_region(quad.contiguous(), "spatial") return quad + + +class 
BandLimitMask(torch.nn.Module): + def __init__(self, img_shape, grid_type, lmax = None, type="sht"): + super().__init__() + self.img_shape = img_shape + self.grid_type = grid_type + self.lmax = lmax if lmax is not None else compute_spherical_bandlimit(img_shape, grid_type) + self.type = type + + if self.type == "sht": + # SHT for the computation of SH coefficients + if (comm.get_size("spatial") > 1): + from torch_harmonics.distributed import DistributedRealSHT, DistributedInverseRealSHT + import torch_harmonics.distributed as thd + if not thd.is_initialized(): + polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h") + azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w") + thd.init(polar_group, azimuth_group) + self.forward_transform = DistributedRealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() + self.inverse_transform = DistributedInverseRealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() + else: + from torch_harmonics import RealSHT, InverseRealSHT + + self.forward_transform = RealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() + self.inverse_transform = InverseRealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() + + elif self.type == "fft": + + # get the cutoff frequency in m for each latitude + lats, _ = _precompute_latitudes(self.img_shape[0], grid=self.grid_type) + # get the grid spacing at the equator + delta_equator = 2 * torch.pi / (self.lmax-1) + mlim = torch.ceil(2 * torch.pi * torch.sin(lats) / delta_equator).reshape(self.img_shape[0], 1) + ms = torch.arange(self.lmax).reshape(1, -1) + mask = (ms <= mlim) + mask = split_tensor_along_dim(mask, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")] + mask = split_tensor_along_dim(mask, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")] + self.register_buffer("mask", mask, persistent=False) + + if (comm.get_size("spatial") > 1): + from makani.mpu.fft import DistributedRealFFT1, DistributedInverseRealFFT1 + self.forward_transform = DistributedRealFFT1(img_shape[1], lmax=lmax, mmax=lmax).float() + self.inverse_transform = DistributedInverseRealFFT1(img_shape[1], lmax=lmax, mmax=lmax).float() + else: + from makani.models.common.fft import RealFFT1, InverseRealFFT1 + self.forward_transform = RealFFT1(img_shape[1], lmax=lmax, mmax=lmax).float() + self.inverse_transform = InverseRealFFT1(img_shape[1], lmax=lmax, mmax=lmax).float() + else: + raise ValueError(f"Unknown truncation type {self.type}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + with amp.autocast(device_type="cuda", enabled=False): + dtype = x.dtype + x = x.float() + + x = self.forward_transform(x) + + if hasattr(self, "mask"): + x = torch.where(self.mask, x, torch.zeros_like(x)) + + x = self.inverse_transform(x) + + x = x.to(dtype=dtype) + + return x \ No newline at end of file diff --git a/makani/utils/inference/inferencer.py b/makani/utils/inference/inferencer.py index 7f8999b..78bb0c9 100644 --- a/makani/utils/inference/inferencer.py +++ b/makani/utils/inference/inferencer.py @@ -446,10 +446,11 @@ def inference_indexlist( return logs - def _initialize_noise_states(self): + def _initialize_noise_states(self, seed_offset=666): noise_states = [] - for _ in range(self.params.local_ensemble_size): - self.preprocessor.update_internal_state(replace_state=True) + for ide in range(self.params.local_ensemble_size): + member_seed = seed_offset + self.preprocessor.get_base_seed(default=333) * ide + self.preprocessor.set_rng(seed=member_seed, reset=True) 
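# --- Illustrative sketch (not part of the diff) -------------------------------
# The "fft" branch of BandLimitMask above keeps, at each latitude row, only the
# zonal wavenumbers that the shrinking circle circumference can support. A
# standalone re-derivation of that mask on an equiangular grid (grid size and
# the colatitude convention are assumptions; single process, no splitting):
import torch

nlat, nlon = 91, 180
lmax = nlon // 2                                   # reference zonal cutoff

colat = torch.linspace(0.0, torch.pi, nlat)        # colatitude: 0 at the pole, pi/2 at the equator

# zonal spacing at the equator, then the per-latitude wavenumber limit
delta_equator = 2 * torch.pi / (lmax - 1)
mlim = torch.ceil(2 * torch.pi * torch.sin(colat) / delta_equator).reshape(nlat, 1)

ms = torch.arange(lmax).reshape(1, -1)
mask = ms <= mlim                                  # (nlat, lmax) boolean mask

# given zonal FFT coefficients X with shape (..., nlat, lmax):
#     X = torch.where(mask, X, torch.zeros_like(X))
# ------------------------------------------------------------------------------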
noise_states.append(self.preprocessor.get_internal_state(tensor=True)) return noise_states @@ -515,7 +516,7 @@ def _inference_indexlist( climatology_iterator = iter(self.climatology_dataloader) # create loader for the full epoch - noise_states = [] + noise_states = self._initialize_noise_states() inptlist = None idt = 0 with torch.inference_mode(): @@ -572,7 +573,7 @@ def _inference_indexlist( self.preprocessor.update_internal_state(replace_state=True, batch_size=inp.shape[0]) # reset noise states and input list - noise_states = self._initialize_noise_states() + noise_states = self._initialize_noise_states(seed_offset=idt) inptlist = [inp.clone() for _ in range(self.params.local_ensemble_size)] if rollout_buffer is not None: @@ -602,9 +603,8 @@ def _inference_indexlist( # retrieve input inpt = inptlist[e] - # this is different, depending on local ensemble size + # restore noise state if (self.params.local_ensemble_size > 1): - # restore noise belonging to this ensemble member self.preprocessor.set_internal_state(noise_states[e]) # forward pass: never replace state since we do that manually diff --git a/makani/utils/loss.py b/makani/utils/loss.py index 887c19f..64f1774 100644 --- a/makani/utils/loss.py +++ b/makani/utils/loss.py @@ -23,17 +23,20 @@ from torch import nn from makani.utils import comm -from makani.utils.grids import GridQuadrature -from makani.utils.dataloaders.data_helpers import get_data_normalization +from makani.utils.grids import GridQuadrature, BandLimitMask +from makani.utils.dataloaders.data_helpers import get_data_normalization, get_time_diff_stds, get_psd_stats from physicsnemo.distributed.utils import compute_split_shapes from physicsnemo.distributed.mappings import gather_from_parallel_region, reduce_from_parallel_region import torch_harmonics as harmonics from torch_harmonics.quadrature import clenshaw_curtiss_weights, legendre_gauss_weights -from .losses import LossType, GeometricLpLoss, SpectralH1Loss, SpectralAMSELoss, HydrostaticBalanceLoss -from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleNLLLoss, EnsembleMMDLoss -from .losses import DriftRegularization +from .losses import LossType, GeometricLpLoss, SpectralLpLoss, SpectralH1Loss, SpectralAMSELoss +from .losses import CRPSLoss, SpectralCRPSLoss, GradientCRPSLoss, VortDivCRPSLoss +from .losses import LpEnergyScoreLoss, SobolevEnergyScoreLoss +from .losses import GaussianMMDLoss +from .losses import EnsembleNLLLoss +from .losses import DriftRegularization, HydrostaticBalanceLoss class LossHandler(nn.Module): @@ -47,6 +50,7 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps self.rank = comm.get_rank("matmul") self.n_future = params.n_future + self.n_history = params.n_history self.spatial_distributed = comm.is_distributed("spatial") and (comm.get_size("spatial") > 1) self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) @@ -57,11 +61,12 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps # check whether dynamic loss weighting is required self.uncertainty_weighting = params.get("uncertainty_weighting", False) + self.balanced_weighting = params.get("balanced_weighting", False) self.randomized_loss_weights = params.get("randomized_loss_weights", False) self.random_slice_loss = params.get("random_slice_loss", False) # whether to keep running stats - self.track_running_stats = track_running_stats or self.uncertainty_weighting + self.track_running_stats = track_running_stats or 
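# --- Illustrative sketch (not part of the diff) -------------------------------
# _initialize_noise_states() above derives one seed per local ensemble member so
# that each member's noise is reproducible and distinct. The preprocessor API is
# specific to makani; a generic stand-in with plain torch generators (seed values
# and tensor shapes are made up) behaves analogously:
import torch

base_seed, seed_offset, local_ensemble_size = 333, 666, 4

gens, noise_states = [], []
for ide in range(local_ensemble_size):
    member_seed = seed_offset + base_seed * ide    # same member -> same seed -> same draws
    g = torch.Generator(device="cpu").manual_seed(member_seed)
    gens.append(g)
    noise_states.append(g.get_state())             # snapshot, akin to get_internal_state()

# before advancing member e, restore its snapshot so the rollout stays reproducible
e = 2
gens[e].set_state(noise_states[e])
noise = torch.randn(1, 3, 8, 16, generator=gens[e])
# ------------------------------------------------------------------------------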
self.uncertainty_weighting or self.balanced_weighting self.eps = eps n_channels = len(params.channel_names) @@ -86,14 +91,29 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps else: scale = torch.ones((1, len(params.out_channels), 1, 1), dtype=torch.float32) + # load PSD stats + try: + psd_means, psd_stds = get_psd_stats(params) + if psd_means is not None: + psd_means = torch.from_numpy(psd_means).to(torch.float32) + if psd_stds is not None: + psd_stds = torch.from_numpy(psd_stds).to(torch.float32) + except ValueError: + psd_means = None + psd_stds = None + # create module list self.loss_fn = nn.ModuleList([]) + self.loss_requires_input = [] # track which losses need input state channel_weights = [] for loss in losses: loss_type = loss["type"] + # check if this is a tendency loss (from explicit field, not string parsing) + requires_input = loss.get("tendency", False) + # get pole mask if it was specified pole_mask = loss.get("pole_mask", 0) @@ -110,6 +130,8 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps pole_mask=pole_mask, bias=bias, scale=scale, + psd_means=psd_means, + psd_stds=psd_stds, grid_type=params.model_grid_type, spatial_distributed=self.spatial_distributed, ensemble_distributed=self.ensemble_distributed, @@ -117,9 +139,8 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps ) # append to dict and compile before: - # TODO: fix the compile issue - # self.loss_fn[loss_type] = torch.compile(loss_fn) self.loss_fn.append(loss_fn) + self.loss_requires_input.append(requires_input) # determine channel weighting if "channel_weights" not in loss.keys(): @@ -127,32 +148,24 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps else: channel_weight_type = loss["channel_weights"] + # check if time difference weighting is required + if loss.get("temp_diff_normalization", False): + time_diff_scale = get_time_diff_stds(params).flatten() + time_diff_scale = torch.clamp(torch.from_numpy(time_diff_scale[params.out_channels]), min=1e-4) + time_diff_scale = scale.flatten() / time_diff_scale + else: + time_diff_scale = None + + # get channel weights either directly or through the compute routine if isinstance(channel_weight_type, List): - chw = torch.tensor(channel_weight_type, dtype=torch.float32).reshape(1, -1) + chw = torch.tensor(channel_weight_type, dtype=torch.float32) + if time_diff_scale is not None: + chw = chw * time_diff_scale assert chw.shape[1] == loss_fn.n_channels else: - chw = loss_fn.compute_channel_weighting(channel_weight_type) - - # the option to normalize outputs with stds of the time difference rather than th - if ("temp_diff_normalization" in loss.keys()) and loss["temp_diff_normalization"]: - - # extract relevant stds - time_diff_stds = torch.from_numpy(np.load(params.time_diff_stds_path)).reshape(1, -1)[:, params.out_channels] - # the time differences are computed between two consecutive datapoints, - # so we need to account for the number of timesteps used in the prediction - # this is now commebnted out as we expect the stats to be computed with the correct dt - # time_diff_stds *= np.sqrt(params.dt) - - # to avoid division by very small numbers, we clamp the time differences from below - time_diff_stds = torch.clamp(time_diff_stds, min=1e-4) - - time_var_weights = scale.reshape(1, -1) / time_diff_stds - - if hasattr(loss_fn, "squared") and loss_fn.squared: - time_var_weights = time_var_weights**2 - - chw = chw * time_var_weights + chw = 
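# --- Illustrative sketch (not part of the diff) -------------------------------
# With temp_diff_normalization enabled, channel weights are rescaled by the ratio
# of the field stds to the (clamped) stds of one-step time differences, so each
# channel's loss is measured relative to how much that channel typically changes
# per step. Toy numbers only:
import torch

scale = torch.tensor([2.0, 5.0, 1.0])              # stand-in for the per-channel data stds
time_diff_stds = torch.tensor([0.5, 1e-6, 0.2])    # stand-in for np.load(params.time_diff_stds_path)

time_diff_scale = scale / torch.clamp(time_diff_stds, min=1e-4)   # clamp avoids division by ~0
chw = torch.tensor([1.0, 1.0, 1.0]) * time_diff_scale             # rescaled channel weights
# ------------------------------------------------------------------------------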
loss_fn.compute_channel_weighting(channel_weight_type, time_diff_scale=time_diff_scale) + # reshape channel weights for propewr broadcasting chw = chw.reshape(1, -1) # check for a relative weight that weights the loss relative to other losses @@ -199,13 +212,18 @@ def _parse_loss_type(self, loss_type: str): loss_type = set(loss_type.split()) + # this can probably all be moved to the loss function itself relative = "relative" in loss_type squared = "squared" in loss_type jacobian = "s2" if "geometric" in loss_type else "flat" # decide which loss to use - if "l2" in loss_type: + if "spectral" in loss_type and "l2" in loss_type: + loss_handle = partial(SpectralLpLoss, p=2, relative=relative, squared=squared) + elif "spectral" in loss_type and "l1" in loss_type: + loss_handle = partial(SpectralLpLoss, p=1, relative=relative, squared=squared) + elif "l2" in loss_type: loss_handle = partial(GeometricLpLoss, p=2, relative=relative, squared=squared, jacobian=jacobian) elif "l1" in loss_type: loss_handle = partial(GeometricLpLoss, p=1, relative=relative, squared=squared, jacobian=jacobian) @@ -226,15 +244,21 @@ def _parse_loss_type(self, loss_type: str): p_max = int(x.replace("p_max=", "")) loss_handle = partial(HydrostaticBalanceLoss, p_min=p_min, p_max=p_max, use_moist_air_formula=use_moist_air_formula) elif "ensemble_crps" in loss_type: - loss_handle = partial(EnsembleCRPSLoss, crps_type="cdf") + loss_handle = partial(CRPSLoss) elif "ensemble_spectral_crps" in loss_type: - loss_handle = partial(EnsembleSpectralCRPSLoss, crps_type="cdf") - elif "gauss_crps" in loss_type: - loss_handle = partial(EnsembleCRPSLoss, crps_type="gauss") + loss_handle = partial(SpectralCRPSLoss) + elif "ensemble_vort_div_crps" in loss_type: + loss_handle = partial(VortDivCRPSLoss) + elif "ensemble_gradient_crps" in loss_type: + loss_handle = partial(GradientCRPSLoss) elif "ensemble_nll" in loss_type: loss_handle = EnsembleNLLLoss - elif "ensemble_mmd" in loss_type: - loss_handle = EnsembleMMDLoss + elif "gaussian_mmd" in loss_type: + loss_handle = GaussianMMDLoss + elif "energy_score" in loss_type: + loss_handle = partial(LpEnergyScoreLoss) + elif "sobolev_energy_score" in loss_type: + loss_handle = partial(SobolevEnergyScoreLoss) elif "drift_regularization" in loss_type: loss_handle = DriftRegularization else: @@ -293,7 +317,23 @@ def reset_running_stats(self): self.running_var.fill_(1) self.num_batches_tracked.zero_() - def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None): + def _extract_input_state(self, inp: torch.Tensor) -> torch.Tensor: + """ + Extract last timestep from flattened history input. 
+ + Args: + inp: Input tensor with shape (B, (n_history+1)*C, H, W) + + Returns: + Last timestep with shape (B, C, H, W) + """ + # inp shape: (B, (n_history+1)*C, H, W) + # we want: (B, C, H, W) - the last timestep + n_channels_per_step = inp.shape[1] // (self.n_history + 1) + inp_last = inp[..., -n_channels_per_step:, :, :] + return inp_last + + def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None, inp: Optional[torch.Tensor] = None): # we assume the following: # if prd is 5D, we assume that the dims are # batch, ensemble, channel, h, w @@ -331,15 +371,52 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens else: prdm = prd + # transform to tendency space if any loss requires it + if inp is not None and any(self.loss_requires_input): + inp_state = self._extract_input_state(inp) + + # validate channel counts for single-step predictions + if self.n_future == 0: + n_pred_channels = prdm.shape[1] + n_inp_channels = inp_state.shape[1] + assert n_pred_channels == n_inp_channels, \ + f"Channel mismatch: prediction has {n_pred_channels} channels but input has {n_inp_channels} channels" + + # transform predictions and targets to tendency space + # this allows ANY loss function to compute tendency-based metrics + prdm_tendency = prdm - inp_state + tar_tendency = tar - inp_state + + # also transform ensemble predictions if present + if prd.dim() == 5: + # expand inp_state to match ensemble dim + inp_state_expanded = inp_state.unsqueeze(1) + prd_tendency = prd - inp_state_expanded + else: + prd_tendency = prdm_tendency + else: + prdm_tendency = prdm + tar_tendency = tar + prd_tendency = prd + # compute loss contributions from each loss loss_vals = [] - for lfn in self.loss_fn: + for lfn, requires_inp in zip(self.loss_fn, self.loss_requires_input): if lfn.type == LossType.Deterministic: - loss_vals.append(lfn(prdm, tar, wgt)) + if requires_inp: + loss_vals.append(lfn(prdm_tendency, tar_tendency, wgt)) + else: + loss_vals.append(lfn(prdm, tar, wgt)) else: - loss_vals.append(lfn(prd, tar, wgt)) + # probabilistic losses: use tendency-transformed ensemble if needed + if requires_inp: + loss_vals.append(lfn(prd_tendency, tar_tendency, wgt)) + else: + loss_vals.append(lfn(prd, tar, wgt)) all_losses = torch.cat(loss_vals, dim=-1) + # print(all_losses) + if self.training and self.track_running_stats: self._update_running_stats(all_losses.clone()) @@ -347,7 +424,14 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens chw = self.channel_weights if self.uncertainty_weighting and self.training: var, _ = self.get_running_stats() + if self.num_batches_tracked.item() <= 100: + var = torch.ones_like(var) chw = chw / (torch.sqrt(2 * var) + self.eps) + elif self.balanced_weighting and self.training: + _, mean = self.get_running_stats() + if self.num_batches_tracked.item() <= 100: + mean = torch.ones_like(mean) + chw = chw / (mean + self.eps) if self.randomized_loss_weights: rmask = torch.zeros_like(chw) diff --git a/makani/utils/losses/__init__.py b/makani/utils/losses/__init__.py index 48d63f3..048481b 100644 --- a/makani/utils/losses/__init__.py +++ b/makani/utils/losses/__init__.py @@ -15,10 +15,11 @@ from .base_loss import LossType from .h1_loss import SpectralH1Loss -from .lp_loss import GeometricLpLoss, SpectralL2Loss +from .lp_loss import GeometricLpLoss, SpectralLpLoss from .amse_loss import SpectralAMSELoss from .hydrostatic_loss import HydrostaticBalanceLoss -from .crps_loss import EnsembleCRPSLoss, 
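# --- Illustrative sketch (not part of the diff) -------------------------------
# Losses flagged with tendency=True are evaluated on increments rather than full
# fields: the last timestep is sliced out of the channel-flattened history input
# (as in _extract_input_state above) and subtracted from both prediction and
# target. Toy shapes only:
import torch

B, C, H, W, n_history = 2, 3, 4, 8, 1
inp = torch.randn(B, (n_history + 1) * C, H, W)    # flattened history, oldest step first
prd = torch.randn(B, C, H, W)
tar = torch.randn(B, C, H, W)

inp_last = inp[..., -C:, :, :]                     # most recent input state

prd_tendency = prd - inp_last                      # handed to tendency losses
tar_tendency = tar - inp_last

# ensemble predictions (B, E, C, H, W) broadcast the input state over E
prd_ens = torch.randn(B, 5, C, H, W)
prd_ens_tendency = prd_ens - inp_last.unsqueeze(1)
# ------------------------------------------------------------------------------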
EnsembleSpectralCRPSLoss -from .mmd_loss import EnsembleMMDLoss +from .crps_loss import CRPSLoss, SpectralCRPSLoss, GradientCRPSLoss, VortDivCRPSLoss +from .energy_score import LpEnergyScoreLoss, SobolevEnergyScoreLoss +from .mmd_loss import GaussianMMDLoss from .likelihood_loss import EnsembleNLLLoss from .drift_regularization import DriftRegularization diff --git a/makani/utils/losses/amse_loss.py b/makani/utils/losses/amse_loss.py index 89db8fb..8ae392c 100644 --- a/makani/utils/losses/amse_loss.py +++ b/makani/utils/losses/amse_loss.py @@ -61,7 +61,7 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens with amp.autocast(device_type="cuda", enabled=False): xcoeffs = self.sht(prd) ycoeffs = self.sht(tar) - + # compute the SHT: xcoeffssq = torch.square(torch.abs(xcoeffs)) ycoeffssq = torch.square(torch.abs(ycoeffs)) @@ -100,5 +100,5 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens loss = torch.sum(loss, dim=-1) if self.spatial_distributed and (comm.get_size("h") > 1): loss = reduce_from_parallel_region(loss, "h") - + return loss diff --git a/makani/utils/losses/base_loss.py b/makani/utils/losses/base_loss.py index b71006c..171fb0a 100644 --- a/makani/utils/losses/base_loss.py +++ b/makani/utils/losses/base_loss.py @@ -17,6 +17,7 @@ from dataclasses import dataclass from abc import ABCMeta, abstractmethod +import math import torch import torch.nn as nn @@ -26,9 +27,10 @@ from makani.utils.grids import grid_to_quadrature_rule, GridQuadrature from makani.utils import comm +from makani.utils.features import get_wind_channels -def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_type: str) -> torch.Tensor: +def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: """ auxiliary routine for predetermining channel weighting """ @@ -43,7 +45,7 @@ def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_t elif channel_weight_type == "auto": for c, chn in enumerate(channel_names): - if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv"]: + if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv", "sst"]: channel_weights[c] = 0.1 elif chn in ["t2m", "2d"]: channel_weights[c] = 1.0 @@ -53,6 +55,32 @@ def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_t else: channel_weights[c] = 0.01 + elif channel_weight_type == "new auto": + + for c, chn in enumerate(channel_names): + if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv", "sst"]: + channel_weights[c] = 0.1 + elif chn in ["t2m", "2d"]: + channel_weights[c] = 2.0 + elif chn[0] in ["z", "u", "v", "t", "r", "q"]: + pressure_level = float(chn[1:]) + channel_weights[c] = max(0.2, 0.001 * pressure_level) + else: + channel_weights[c] = 0.01 + + elif channel_weight_type == "new auto 2": + + for c, chn in enumerate(channel_names): + if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv", "sst"]: + channel_weights[c] = 0.1 + elif chn in ["t2m", "2d"]: + channel_weights[c] = 2.0 + elif chn[0] in ["z", "u", "v", "t", "r", "q"]: + pressure_level = float(chn[1:]) + channel_weights[c] = max(0.3, 0.001 * pressure_level) + else: + channel_weights[c] = 0.01 + elif channel_weight_type == "custom": weight_dict = { @@ -216,6 +244,10 @@ def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_t # normalize channel_weights = channel_weights / 
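# --- Illustrative sketch (not part of the diff) -------------------------------
# The "new auto" channel weighting above keys the weight off the pressure level
# encoded in the channel name, with a floor of 0.2 (0.3 for "new auto 2"). A toy
# evaluation on a made-up channel list (weights shown before normalization):
channel_names = ["u10m", "t2m", "z500", "t850", "q50", "tcwv"]

weights = []
for chn in channel_names:
    if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv", "sst"]:
        weights.append(0.1)
    elif chn in ["t2m", "2d"]:
        weights.append(2.0)
    elif chn[0] in ["z", "u", "v", "t", "r", "q"]:
        pressure_level = float(chn[1:])
        weights.append(max(0.2, 0.001 * pressure_level))
    else:
        weights.append(0.01)

print([round(w, 3) for w in weights])   # [0.1, 2.0, 0.5, 0.85, 0.2, 0.1]
# ------------------------------------------------------------------------------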
torch.sum(channel_weights) + # get the time differences and weigh them additionally + if time_diff_scale is not None: + channel_weights = channel_weights * time_diff_scale + return channel_weights @@ -250,9 +282,8 @@ def __init__( self.pole_mask = pole_mask self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + # get the quadrature rule for the corresponding grid quadrature_rule = grid_to_quadrature_rule(grid_type) - - # get the quadrature self.quadrature = GridQuadrature( quadrature_rule, img_shape=self.img_shape, @@ -272,8 +303,8 @@ def n_channels(self): return len(self.channel_names) @torch.compiler.disable(recursive=False) - def compute_channel_weighting(self, channel_weight_type: str) -> torch.Tensor: - return _compute_channel_weighting_helper(self.channel_names, channel_weight_type) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: + return _compute_channel_weighting_helper(self.channel_names, channel_weight_type, time_diff_scale=time_diff_scale) @abstractmethod def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None) -> torch.Tensor: @@ -292,6 +323,7 @@ def __init__( crop_offset: Tuple[int, int], channel_names: List[str], grid_type: str, + lmax: Optional[int] = None, spatial_distributed: Optional[bool] = False, ): super().__init__() @@ -302,14 +334,15 @@ def __init__( self.channel_names = channel_names self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + # SHT for the computation of SH coefficients if self.spatial_distributed and (comm.get_size("spatial") > 1): if not thd.is_initialized(): polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h") azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w") thd.init(polar_group, azimuth_group) - self.sht = thd.DistributedRealSHT(*img_shape, grid=grid_type) + self.sht = thd.DistributedRealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type) else: - self.sht = th.RealSHT(*img_shape, grid=grid_type).float() + self.sht = th.RealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() @property def type(self): @@ -320,8 +353,150 @@ def n_channels(self): return len(self.n_channels) @torch.compiler.disable(recursive=False) - def compute_channel_weighting(self, channel_weight_type: str) -> torch.Tensor: - return _compute_channel_weighting_helper(self.channel_names, channel_weight_type) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: + return _compute_channel_weighting_helper(self.channel_names, channel_weight_type, time_diff_scale=time_diff_scale) + + @abstractmethod + def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None) -> torch.Tensor: + pass + + +class VortDivBaseLoss(nn.Module, metaclass=ABCMeta): + """ + Geometric base loss class used by all geometric losses + """ + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + pole_mask: int, + lmax: Optional[int] = None, + spatial_distributed: Optional[bool] = False, + ): + super().__init__() + + self.img_shape = img_shape + self.crop_shape = crop_shape + self.crop_offset = crop_offset + self.channel_names = channel_names + self.pole_mask = pole_mask + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + + # get the wind channels + wind_chans = 
get_wind_channels(self.channel_names) + self.register_buffer("wind_chans", torch.LongTensor(wind_chans)) + + if self.spatial_distributed and (comm.get_size("spatial") > 1): + if not thd.is_initialized(): + polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h") + azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w") + thd.init(polar_group, azimuth_group) + self.vsht = thd.DistributedRealVectorSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type) + self.isht = thd.DistributedInverseRealVectorSHT(nlat=self.vsht.nlat, nlon=self.vsht.nlon, lmax=lmax, mmax=lmax, grid=grid_type) + else: + self.vsht = th.RealVectorSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type) + self.isht = th.InverseRealVectorSHT(nlat=self.vsht.nlat, nlon=self.vsht.nlon, lmax=lmax, mmax=lmax, grid=grid_type) + + # get the quadrature rule for the corresponding grid + quadrature_rule = grid_to_quadrature_rule(grid_type) + self.quadrature = GridQuadrature( + quadrature_rule, + img_shape=self.img_shape, + crop_shape=self.crop_shape, + crop_offset=self.crop_offset, + normalize=True, + pole_mask=self.pole_mask, + distributed=self.spatial_distributed, + ) + + @property + def type(self): + return LossType.Deterministic + + @property + def n_channels(self): + return len(self.n_channels) + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: + + chw = _compute_channel_weighting_helper(self.channel_names, channel_weight_type, time_diff_scale=time_diff_scale) + chw = chw[self.wind_chans.to(chw.device)] + + # average u and v component weightings to weight vort and div equally + chw[1::2] = (chw[1::2] + chw[0::2]) / 2 + chw[0::2] = chw[1::2] + + return chw + + @abstractmethod + def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None) -> torch.Tensor: + pass + +class GradientBaseLoss(nn.Module, metaclass=ABCMeta): + """ + Gradient base loss class used by all gradient based losses + """ + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + pole_mask: int, + lmax: Optional[int] = None, + spatial_distributed: Optional[bool] = False, + ): + super().__init__() + + self.img_shape = img_shape + self.crop_shape = crop_shape + self.crop_offset = crop_offset + self.channel_names = channel_names + self.pole_mask = pole_mask + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + + if self.spatial_distributed and (comm.get_size("spatial") > 1): + if not thd.is_initialized(): + polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h") + azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w") + thd.init(polar_group, azimuth_group) + self.sht = thd.DistributedRealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type) + self.ivsht = thd.DistributedInverseRealVectorSHT(nlat=self.sht.nlat, nlon=self.sht.nlon, lmax=lmax, mmax=lmax, grid=grid_type) + else: + self.sht = th.RealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type) + self.ivsht = th.InverseRealVectorSHT(nlat=self.sht.nlat, nlon=self.sht.nlon, lmax=lmax, mmax=lmax, grid=grid_type) + + # get the quadrature rule for the corresponding grid + quadrature_rule = grid_to_quadrature_rule(grid_type) + self.quadrature = GridQuadrature( + quadrature_rule, + img_shape=self.img_shape, + crop_shape=self.crop_shape, + crop_offset=self.crop_offset, + 
normalize=True, + pole_mask=self.pole_mask, + distributed=self.spatial_distributed, + ) + + @property + def type(self): + return LossType.Deterministic + + @property + def n_channels(self): + return len(self.n_channels) + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: + return _compute_channel_weighting_helper(self.channel_names, channel_weight_type, time_diff_scale=time_diff_scale) + @abstractmethod def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None) -> torch.Tensor: diff --git a/makani/utils/losses/crps_loss.py b/makani/utils/losses/crps_loss.py index 90c6224..47af667 100644 --- a/makani/utils/losses/crps_loss.py +++ b/makani/utils/losses/crps_loss.py @@ -22,7 +22,7 @@ import torch.nn as nn from torch import amp -from makani.utils.losses.base_loss import GeometricBaseLoss, SpectralBaseLoss, LossType +from makani.utils.losses.base_loss import LossType, GeometricBaseLoss, SpectralBaseLoss, VortDivBaseLoss, GradientBaseLoss from makani.utils import comm # distributed stuff @@ -49,6 +49,9 @@ def _crps_ensemble_kernel(observation: torch.Tensor, forecasts: torch.Tensor, we CRPS ensemble score from integrating the PDF piecewise compare https://github.com/properscoring/properscoring/blob/master/properscoring/_gufuncs.py#L7 disabling torch compile for the moment due to very long startup times when training large ensembles with ensemble parallelism + + forecasts: [ensemble, ...], observation: [...], weights: [ensemble, ...] + Assumes forecasts are sorted along ensemble dimension 0. """ # beware: forecasts are assumed sorted in sorted order @@ -110,14 +113,13 @@ def _crps_ensemble_kernel(observation: torch.Tensor, forecasts: torch.Tensor, we def _crps_skillspread_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weights: torch.Tensor, alpha: float) -> torch.Tensor: """ - alternative CRPS variant that uses spread and skill + fair CRPS variant that uses spread and skill. Assumes pre-sorted ensemble """ observation = observation.unsqueeze(0) - # get nanmask - nanmasks = torch.logical_or(torch.isnan(forecasts), torch.isnan(weights)) - nanmask = torch.sum(nanmasks, dim=0).bool() + # get nanmask from the observarions + nanmasks = torch.logical_or(torch.isnan(observation), torch.isnan(weights)) # compute total weights nweights = torch.where(nanmasks, 0.0, weights) @@ -133,11 +135,73 @@ def _crps_skillspread_kernel(observation: torch.Tensor, forecasts: torch.Tensor, espread = 2 * torch.mean((2 * rank - num_ensemble - 1) * forecasts, dim=0) * (float(num_ensemble) - 1.0 + alpha) / float(num_ensemble * (num_ensemble - 1)) eskill = (observation - forecasts).abs().mean(dim=0) + crps = torch.where(nanmasks.sum(dim=0) != 0, 0.0, eskill - 0.5 * espread) + + return crps + + +def _crps_probability_weighted_moment_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + """ + CRPS estimator based on the probability weighted moment. see [1]. + + [1] Michael Zamo, Phillippe Naveau. Estimation of the Continuous Ranked Probability Score with Limited Information and Applications to Ensemble Weather Forecasts. Mathematical Geosciences. Volume 50 pp. 209-234. 2018. 
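# --- Illustrative sketch (not part of the diff) -------------------------------
# The probability-weighted-moment kernel gives the fair ensemble CRPS in
# O(M log M): with members sorted ascending and 0-based ranks,
#     CRPS = mean_i |x_i - y| + beta0 - 2 * beta1,
#     beta0 = mean_i x_i,   beta1 = sum_i rank_i * x_i / (M * (M - 1)),
# which matches the pairwise form mean_i|x_i - y| - sum_{i,j}|x_i - x_j| / (2 M (M - 1)).
# Tiny numerical check (toy numbers, no weights or NaN handling):
import torch

y = torch.tensor(0.3, dtype=torch.float64)
x = torch.tensor([0.1, 0.25, 0.4, 0.9], dtype=torch.float64)   # already sorted, M = 4
M = x.shape[0]

skill = (x - y).abs().mean()

rank = torch.arange(M, dtype=x.dtype)
beta0 = x.mean()
beta1 = (rank * x).sum() / (M * (M - 1))
crps_pwm = skill + beta0 - 2.0 * beta1                          # 0.025

spread = (x.unsqueeze(0) - x.unsqueeze(1)).abs().sum() / (M * (M - 1))
crps_pair = skill - 0.5 * spread                                # 0.025

assert torch.allclose(crps_pwm, crps_pair)
# ------------------------------------------------------------------------------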
+ """ + + observation = observation.unsqueeze(0) + + # get nanmask from the observarions + nanmasks = torch.logical_or(torch.isnan(observation), torch.isnan(weights)) + + # compute total weights + nweights = torch.where(nanmasks, 0.0, weights) + total_weight = torch.sum(nweights, dim=0, keepdim=True) + + # ensemble size + num_ensemble = forecasts.shape[0] + + # get the ranks for the pwm computation + rank = torch.arange(num_ensemble, device=forecasts.device).reshape((num_ensemble,) + (1,) * (forecasts.dim() - 1)) + + # get the ensemble spread (total_weight is ensemble size here) + beta0 = forecasts.mean(dim=0) + beta1 = (rank * forecasts).sum(dim=0) / float(num_ensemble * (num_ensemble - 1)) + eskill = (observation - forecasts).abs().mean(dim=0) + + # crps = torch.where(nanmasks.sum(dim=0) != 0, torch.nan, eskill - 0.5 * espread) + crps = eskill + beta0 - 2 * beta1 + + # set to nan for first forecasts nan + crps = torch.where(nanmasks.sum(dim=0) != 0, 0.0, crps) + + return crps + + +def _crps_naive_skillspread_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weights: torch.Tensor, alpha: float) -> torch.Tensor: + """ + alternative fair CRPS variant that uses spread and skill. Uses naive computation which is O(N^2) in the number of ensemble members. Useful for complex + """ + + observation = observation.unsqueeze(0) + + # get nanmask from the observarions + nanmasks = torch.logical_or(torch.isnan(observation), torch.isnan(weights)) + + # compute total weights + nweights = torch.where(nanmasks, 0.0, weights) + total_weight = torch.sum(nweights, dim=0, keepdim=True) + + # ensemble size + num_ensemble = forecasts.shape[0] + + # use broadcasting semantics to compute spread and skill + espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) + eskill = (observation - forecasts).abs().mean(dim=0) + # crps = torch.where(nanmasks.sum(dim=0) != 0, torch.nan, eskill - 0.5 * espread) crps = eskill - 0.5 * espread # set to nan for first forecasts nan - crps = torch.where(nanmask, torch.nan, crps) + crps = torch.where(nanmasks.sum(dim=0) != 0, 0.0, crps) return crps @@ -173,7 +237,7 @@ def _crps_gauss_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weigh return crps -class EnsembleCRPSLoss(GeometricBaseLoss): +class CRPSLoss(GeometricBaseLoss): def __init__( self, @@ -278,6 +342,17 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # compute score crps = _crps_ensemble_kernel(observations, forecasts, ensemble_weights) + elif self.crps_type == "probability weighted moment": + # now, E dimension is local and spatial dim is split further + # we need to sort the forecasts now + forecasts, idx = torch.sort(forecasts, dim=0) + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_probability_weighted_moment_kernel(observations, forecasts, ensemble_weights) elif self.crps_type == "skillspread": if self.ensemble_weights is not None: raise NotImplementedError("currently only constant ensemble weights are supported") @@ -286,6 +361,14 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # compute score crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) + elif self.crps_type == "naive skillspread": + if self.ensemble_weights is not 
None: + raise NotImplementedError("currently only constant ensemble weights are supported") + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_naive_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) elif self.crps_type == "gauss": if self.ensemble_weights is not None: ensemble_weights = self.ensemble_weights[idx] @@ -316,7 +399,7 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w return crps -class EnsembleSpectralCRPSLoss(SpectralBaseLoss): +class SpectralCRPSLoss(SpectralBaseLoss): def __init__( self, @@ -325,6 +408,7 @@ def __init__( crop_offset: Tuple[int, int], channel_names: List[str], grid_type: str, + lmax: Optional[int] = None, crps_type: str = "skillspread", spatial_distributed: Optional[bool] = False, ensemble_distributed: Optional[bool] = False, @@ -341,6 +425,7 @@ def __init__( crop_offset=crop_offset, channel_names=channel_names, grid_type=grid_type, + lmax=lmax, spatial_distributed=spatial_distributed, ) @@ -405,13 +490,8 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spectral_ forecasts = torch.abs(forecasts).to(dtype) observations = torch.abs(observations).to(dtype) else: - forecasts = torch.view_as_real(forecasts).to(dtype) - observations = torch.view_as_real(observations).to(dtype) - - # merge complex dimension after channel dimension and flatten - # this needs to be undone at the end - forecasts = torch.movedim(forecasts, 5, 3).flatten(2, 3) - observations = torch.movedim(observations, 4, 2).flatten(1, 2) + # since the other kernels require sorting, this approach only works with the naive CRPS kernel + assert self.crps_type == "skillspread" # we assume the following shapes: # forecasts: batch, ensemble, channels, mmax, lmax @@ -460,6 +540,17 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spectral_ # compute score crps = _crps_ensemble_kernel(observations, forecasts, ensemble_weights) + elif self.crps_type == "probability weighted moment": + # now, E dimension is local and spatial dim is split further + # we need to sort the forecasts now + forecasts, idx = torch.sort(forecasts, dim=0) + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_probability_weighted_moment_kernel(observations, forecasts, ensemble_weights) elif self.crps_type == "skillspread": if self.ensemble_weights is not None: raise NotImplementedError("currently only constant ensemble weights are supported") @@ -467,7 +558,10 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spectral_ ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) # compute score - crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) + if self.absolute: + crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) + else: + crps = _crps_naive_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) elif self.crps_type == "gauss": if self.ensemble_weights is not None: ensemble_weights = self.ensemble_weights[idx] @@ -488,9 +582,359 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spectral_ if self.ensemble_distributed: crps = reduce_from_parallel_region(crps, "ensemble") - # finally undo the folding of the complex dimension into the channel dimension - if not self.absolute: 
- crps = crps.reshape(B, -1, 2).sum(dim=-1) + # the resulting tensor should have dimension B, C, which is what we return + return crps + +class GradientCRPSLoss(GradientBaseLoss): + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + pole_mask: int, + lmax: Optional[int] = None, + crps_type: str = "skillspread", + spatial_distributed: Optional[bool] = False, + ensemble_distributed: Optional[bool] = False, + ensemble_weights: Optional[torch.Tensor] = None, + absolute: Optional[bool] = True, + alpha: Optional[float] = 1.0, + eps: Optional[float] = 1.0e-5, + **kwargs, + ): + + super().__init__( + img_shape=img_shape, + crop_shape=crop_shape, + crop_offset=crop_offset, + channel_names=channel_names, + grid_type=grid_type, + pole_mask=pole_mask, + lmax=lmax, + spatial_distributed=spatial_distributed, + ) + + # if absolute is true, the loss is computed only on the absolute value of the gradient + self.absolute = absolute + + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + self.crps_type = crps_type + self.alpha = alpha + self.eps = eps + + if (self.crps_type != "skillspread") and (self.alpha < 1.0): + raise NotImplementedError("The alpha parameter (almost fair CRPS factor) is only supported for the skillspread kernel.") + + # we also need a variant of the weights split in ensemble direction: + quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) + if self.ensemble_distributed: + quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")] + quad_weight_split = quad_weight_split.contiguous() + self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + + if ensemble_weights is not None: + self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) + else: + self.ensemble_weights = ensemble_weights + + @property + def type(self): + return LossType.Probabilistic + + @property + def n_channels(self): + if self.absolute: + return len(self.channel_names) + else: + return 2 * len(self.channel_names) + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: + chw = super().compute_channel_weighting(channel_weight_type, time_diff_scale=time_diff_scale) + + if self.absolute: + return chw + else: + return [weight for weight in chw for _ in range(2)] + + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + + # sanity checks + if forecasts.dim() != 5: + raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + + # we assume that spatial_weights have NO ensemble dim + if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): + spdim = spatial_weights.dim() + odim = observations.dim() + raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).") + + # we assume the following shapes: + # forecasts: batch, ensemble, channels, lat, lon + # observations: batch, channels, lat, lon + B, E, C, H, W = forecasts.shape + + # get the data type before stripping amp types + dtype = forecasts.dtype + + # before 
anything else compute the transform + # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same + with amp.autocast(device_type="cuda", enabled=False): + + # compute the SH coefficients of the forecasts and observations + forecasts = self.sht(forecasts.float()).unsqueeze(-3) + observations = self.sht(observations.float()).unsqueeze(-3) + + # append zeros, so that we can use the inverse vector SHT + forecasts = torch.cat([forecasts, torch.zeros_like(forecasts)], dim=-3) + observations = torch.cat([observations, torch.zeros_like(observations)], dim=-3) + + forecasts = self.ivsht(forecasts) + observations = self.ivsht(observations) + + forecasts = forecasts.to(dtype) + observations = observations.to(dtype) + + if self.absolute: + forecasts = forecasts.pow(2).sum(dim=-3).sqrt() + observations = observations.pow(2).sum(dim=-3).sqrt() + else: + C = 2 * C + + forecasts = forecasts.reshape(B, E, C, H, W) + observations = observations.reshape(B, C, H, W) + + # if ensemble dim is one dimensional then computing the score is quick: + if (not self.ensemble_distributed) and (E == 1): + # in this case, CRPS is straightforward + crps = torch.abs(observations - forecasts.squeeze(1)).reshape(B, C, H * W) + else: + # transpose forecasts: ensemble, batch, channels, lat, lon + forecasts = torch.moveaxis(forecasts, 1, 0) + + # now we need to transpose the forecasts into ensemble direction. + # ideally we split spatial dims + forecasts = forecasts.reshape(E, B, C, H * W) + if self.ensemble_distributed: + ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] + forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") + # observations does not need a transpose, but just a split + observations = observations.reshape(B, C, H * W) + if self.ensemble_distributed: + observations = scatter_to_parallel_region(observations, -1, "ensemble") + if spatial_weights is not None: + spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1) + spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + + # run appropriate crps kernel to compute it pointwise + if self.crps_type == "cdf": + # now, E dimension is local and spatial dim is split further + # we need to sort the forecasts now + forecasts, idx = torch.sort(forecasts, dim=0) + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_ensemble_kernel(observations, forecasts, ensemble_weights) + elif self.crps_type == "skillspread": + if self.ensemble_weights is not None: + raise NotImplementedError("currently only constant ensemble weights are supported") + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) + elif self.crps_type == "gauss": + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_gauss_kernel(observations, forecasts, ensemble_weights, self.eps) + else: + raise ValueError(f"Unknown CRPS crps_type {self.crps_type}") + + # perform ensemble and spatial average of crps score + if spatial_weights is not None: + crps = torch.sum(crps * self.quad_weight_split * 
spatial_weights_split, dim=-1) + else: + crps = torch.sum(crps * self.quad_weight_split, dim=-1) + + # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well + if self.ensemble_distributed: + crps = reduce_from_parallel_region(crps, "ensemble") + + # we need to do the spatial averaging manually since + # we are not calling he quadrature forward function + if self.spatial_distributed: + crps = reduce_from_parallel_region(crps, "spatial") + + # the resulting tensor should have dimension B, C, which is what we return + return crps + +class VortDivCRPSLoss(VortDivBaseLoss): + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + pole_mask: int, + crps_type: str = "skillspread", + spatial_distributed: Optional[bool] = False, + ensemble_distributed: Optional[bool] = False, + ensemble_weights: Optional[torch.Tensor] = None, + alpha: Optional[float] = 1.0, + eps: Optional[float] = 1.0e-5, + **kwargs, + ): + + super().__init__( + img_shape=img_shape, + crop_shape=crop_shape, + crop_offset=crop_offset, + channel_names=channel_names, + grid_type=grid_type, + pole_mask=pole_mask, + spatial_distributed=spatial_distributed, + ) + + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + self.crps_type = crps_type + self.alpha = alpha + self.eps = eps + + if (self.crps_type != "skillspread") and (self.alpha < 1.0): + raise NotImplementedError("The alpha parameter (almost fair CRPS factor) is only supported for the skillspread kernel.") + + # we also need a variant of the weights split in ensemble direction: + quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) + if self.ensemble_distributed: + quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")] + quad_weight_split = quad_weight_split.contiguous() + self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + + if ensemble_weights is not None: + self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) + else: + self.ensemble_weights = ensemble_weights + + @property + def type(self): + return LossType.Probabilistic + + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + + # sanity checks + if forecasts.dim() != 5: + raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + + # we assume that spatial_weights have NO ensemble dim + if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): + spdim = spatial_weights.dim() + odim = observations.dim() + raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).") + + # we assume the following shapes: + # forecasts: batch, ensemble, channels, lat, lon + # observations: batch, channels, lat, lon + B, E, _, H, W = forecasts.shape + C = self.wind_chans.shape[0] + + # extract wind channels + forecasts = forecasts[..., self.wind_chans, :, :].reshape(B, E, C//2, 2, H, W) + observations = observations[..., self.wind_chans, :, :].reshape(B, C//2, 2, H, W) + + # get the data type before stripping amp types + dtype = forecasts.dtype + + # before anything else compute the transform + # 
as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same + with amp.autocast(device_type="cuda", enabled=False): + forecasts = self.isht(self.vsht(forecasts.float())) + observations = self.isht(self.vsht(observations.float())) + + # extract wind channels + forecasts = forecasts.reshape(B, E, C, H, W) + observations = observations.reshape(B, C, H, W) + + # if ensemble dim is one dimensional then computing the score is quick: + if (not self.ensemble_distributed) and (E == 1): + # in this case, CRPS is straightforward + crps = torch.abs(observations - forecasts.squeeze(1)).reshape(B, C, H * W) + else: + # transpose forecasts: ensemble, batch, channels, lat, lon + forecasts = torch.moveaxis(forecasts, 1, 0) + + # now we need to transpose the forecasts into ensemble direction. + # ideally we split spatial dims + forecasts = forecasts.reshape(E, B, C, H * W) + if self.ensemble_distributed: + ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] + forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") + # observations does not need a transpose, but just a split + observations = observations.reshape(B, C, H * W) + if self.ensemble_distributed: + observations = scatter_to_parallel_region(observations, -1, "ensemble") + if spatial_weights is not None: + spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1) + spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + + # run appropriate crps kernel to compute it pointwise + if self.crps_type == "cdf": + # now, E dimension is local and spatial dim is split further + # we need to sort the forecasts now + forecasts, idx = torch.sort(forecasts, dim=0) + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_ensemble_kernel(observations, forecasts, ensemble_weights) + elif self.crps_type == "skillspread": + if self.ensemble_weights is not None: + raise NotImplementedError("currently only constant ensemble weights are supported") + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) + elif self.crps_type == "gauss": + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_gauss_kernel(observations, forecasts, ensemble_weights, self.eps) + else: + raise ValueError(f"Unknown CRPS crps_type {self.crps_type}") + + # perform ensemble and spatial average of crps score + if spatial_weights is not None: + crps = torch.sum(crps * self.quad_weight_split * spatial_weights_split, dim=-1) + else: + crps = torch.sum(crps * self.quad_weight_split, dim=-1) + + # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well + if self.ensemble_distributed: + crps = reduce_from_parallel_region(crps, "ensemble") + + # we need to do the spatial averaging manually since + # we are not calling he quadrature forward function + if self.spatial_distributed: + crps = reduce_from_parallel_region(crps, "spatial") # the resulting tensor should have dimension B, C, which is what we return return crps diff --git a/makani/utils/losses/energy_score.py 
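# --- Illustrative sketch (not part of the diff) -------------------------------
# VortDivCRPSLoss relies on get_wind_channels() pairing every u-channel with its
# matching v-channel (unpaired winds are now skipped), and then regroups the
# selected channels into (..., n_pairs, (u, v), H, W) before the vector
# transforms. Toy channel list (made up; u850 has no v850 partner on purpose):
import torch

channel_names = ["t2m", "u10m", "v10m", "u500", "v500", "u850"]

wind_chans = []
for c, ch in enumerate(channel_names):
    if ch[0] == "u" and ("v" + ch[1:]) in channel_names:
        wind_chans += [c, channel_names.index("v" + ch[1:])]
# wind_chans == [1, 2, 3, 4]

B, E, H, W = 2, 3, 4, 8
forecasts = torch.randn(B, E, len(channel_names), H, W)
wind = forecasts[..., wind_chans, :, :]                       # (B, E, 4, H, W)
wind = wind.reshape(B, E, len(wind_chans) // 2, 2, H, W)      # (B, E, 2, 2, H, W)
# ------------------------------------------------------------------------------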
b/makani/utils/losses/energy_score.py new file mode 100644 index 0000000..b63ae77 --- /dev/null +++ b/makani/utils/losses/energy_score.py @@ -0,0 +1,524 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, List + +import math +import torch +import torch.nn as nn +from torch import amp + +from makani.utils.losses.base_loss import GeometricBaseLoss, SpectralBaseLoss, LossType +from makani.utils import comm + +# distributed stuff +from physicsnemo.distributed.utils import compute_split_shapes, split_tensor_along_dim +from physicsnemo.distributed.mappings import scatter_to_parallel_region, reduce_from_parallel_region +from makani.mpu.mappings import distributed_transpose + + +class LpEnergyScoreLoss(GeometricBaseLoss): + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + pole_mask: int, + spatial_distributed: Optional[bool] = False, + ensemble_distributed: Optional[bool] = False, + ensemble_weights: Optional[torch.Tensor] = None, + alpha: Optional[float] = 1.0, + beta: Optional[float] = 2.0, + eps: Optional[float] = 1.0e-5, + **kwargs, + ): + + super().__init__( + img_shape=img_shape, + crop_shape=crop_shape, + crop_offset=crop_offset, + channel_names=channel_names, + grid_type=grid_type, + pole_mask=pole_mask, + spatial_distributed=spatial_distributed, + ) + + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + self.alpha = alpha + self.beta = beta + self.eps = eps + + # we also need a variant of the weights split in ensemble direction: + quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) + if self.ensemble_distributed: + quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")] + quad_weight_split = quad_weight_split.contiguous() + self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + + if ensemble_weights is not None: + self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) + else: + self.ensemble_weights = ensemble_weights + + @property + def type(self): + return LossType.Probabilistic + + @property + def n_channels(self): + return 1 + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor: + return torch.ones(1) + + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + + # sanity checks + if forecasts.dim() != 5: + raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + + # we assume that 
spatial_weights have NO ensemble dim + if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): + spdim = spatial_weights.dim() + odim = observations.dim() + raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).") + + # we assume the following shapes: + # forecasts: batch, ensemble, channels, lat, lon + # observations: batch, channels, lat, lon + B, E, C, H, W = forecasts.shape + + # get the data type before stripping amp types + dtype = forecasts.dtype + + # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. + # ideally we split spatial dims + forecasts = torch.moveaxis(forecasts, 1, 0) + forecasts = forecasts.reshape(E, B, C, H * W) + if self.ensemble_distributed: + ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] + forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") + + # observations does not need a transpose, but just a split + observations = observations.reshape(1, B, C, H * W) + if self.ensemble_distributed: + observations = scatter_to_parallel_region(observations, -1, "ensemble") + + # for correct spatial reduction we need to do the same with spatial weights + if spatial_weights is not None: + spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1) + spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + + if self.ensemble_weights is not None: + raise NotImplementedError("currently only constant ensemble weights are supported") + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # ensemble size + num_ensemble = forecasts.shape[0] + + # get nanmask from the observarions + nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(ensemble_weights)) + + # use broadcasting semantics to compute spread and skill and sum over channels (vector norm) + espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().pow(self.beta) + eskill = (observations - forecasts).abs().pow(self.beta) + + # perform masking before any reduction + espread = torch.where(nanmasks.sum(dim=0) != 0, 0.0, espread) + eskill = torch.where(nanmasks.sum(dim=0) != 0, 0.0, eskill) + + # do the spatial reduction + if spatial_weights is not None: + espread = torch.sum(espread * self.quad_weight_split * spatial_weights_split, dim=-1) + eskill = torch.sum(eskill * self.quad_weight_split * spatial_weights_split, dim=-1) + else: + espread = torch.sum(espread * self.quad_weight_split, dim=-1) + eskill = torch.sum(eskill * self.quad_weight_split, dim=-1) + + # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well + if self.ensemble_distributed: + espread = reduce_from_parallel_region(espread, "ensemble") + eskill = reduce_from_parallel_region(eskill, "ensemble") + + # we need to do the spatial averaging manually since + # we are not calling the quadrature forward function + if self.spatial_distributed: + espread = reduce_from_parallel_region(espread, "spatial") + eskill = reduce_from_parallel_region(eskill, "spatial") + + # do the channel reduction while ignoring NaNs + # if channel weights are required they should be added here to the reduction + espread = espread.sum(dim=-1, keepdim=True) + eskill = eskill.sum(dim=-1, keepdim=True) + + # now we have reduced everything and need to sum appropriately + espread = espread.sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + 
self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) + eskill = eskill.sum(dim=0) / float(num_ensemble) + + # the resulting tensor should have dimension B, C which is what we return + return eskill - 0.5 * espread + +class SobolevEnergyScoreLoss(SpectralBaseLoss): + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + lmax: Optional[int] = None, + spatial_distributed: Optional[bool] = False, + ensemble_distributed: Optional[bool] = False, + ensemble_weights: Optional[torch.Tensor] = None, + alpha: Optional[float] = 1.0, + beta: Optional[float] = 2.0, + fraction: Optional[float] = 1.0, + eps: Optional[float] = 1.0e-3, + psd_normalization: Optional[bool] = False, + bias: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + psd_means: Optional[torch.Tensor] = None, + **kwargs, + ): + + super().__init__( + img_shape=img_shape, + crop_shape=crop_shape, + crop_offset=crop_offset, + channel_names=channel_names, + grid_type=grid_type, + lmax=lmax, + spatial_distributed=spatial_distributed, + ) + + self.spatial_distributed = spatial_distributed and comm.is_distributed("spatial") + self.ensemble_distributed = ensemble_distributed and comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) + self.alpha = alpha + self.beta = beta + self.fraction = fraction + self.eps = eps + + if ensemble_weights is not None: + self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) + else: + self.ensemble_weights = ensemble_weights + + # get the local lm weights + if psd_normalization and psd_means is not None and scale is not None: + + # ensure shapes match for broadcasting + if psd_means.dim() == 3: + psd_means = psd_means.unsqueeze(-1) + + # normalize PSD + # Since PSD scales with square of signal, we divide by scale^2 + norm_psd = psd_means / (scale.reshape(1, -1, 1, 1)**2) + + # fix the 0-th component using the bias + normalized_bias_sq = (bias.reshape(1, -1) / scale.reshape(1, -1))**2 + norm_psd = norm_psd / (4.0 * math.pi) + norm_psd[..., 0, 0] = torch.clamp(norm_psd[..., 0, 0] - normalized_bias_sq, min=0.0) + + # Compute spectral weights: 1/sqrt(PSD) + # Add epsilon for numerical stability + psd_weights = 1.0 / (norm_psd + self.eps**2) + else: + psd_weights = torch.ones((1, 1, self.sht.lmax, 1)) + + ls = torch.arange(self.sht.lmax).reshape(-1, 1) + ms = torch.arange(self.sht.mmax).reshape(1, -1) + lm_weights = (1 + ls * (ls + 1)).pow(self.fraction).tile(1, self.sht.mmax) + # account for the 4 pi normalization coming from the SHT + lm_weights = psd_weights * lm_weights / 4.0 / math.pi + lm_weights[:, 1:] *= 2.0 + lm_weights = torch.where(ms > ls, 0.0, lm_weights) + if comm.get_size("h") > 1: + lm_weights = split_tensor_along_dim(lm_weights, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")] + if comm.get_size("w") > 1: + lm_weights = split_tensor_along_dim(lm_weights, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")] + self.register_buffer("lm_weights", lm_weights, persistent=False) + + @property + def type(self): + return LossType.Probabilistic + + @property + def n_channels(self): + return 1 + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor: + return torch.ones(1) + + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, ensemble_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + + # sanity 
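# --- Illustrative sketch, not part of this patch ------------------------------
# The Sobolev-type spectral weights built in SobolevEnergyScoreLoss.__init__
# follow the usual H^s pattern: degree l weighted by (1 + l*(l+1))**s, m != 0
# coefficients counted twice (conjugate symmetry of a real field), and modes
# with m > l masked out. A standalone construction, with lmax/mmax/fraction as
# assumed inputs, might look like this.
import math
import torch

def sobolev_lm_weights_sketch(lmax: int, mmax: int, fraction: float = 1.0) -> torch.Tensor:
    ls = torch.arange(lmax).reshape(-1, 1)
    ms = torch.arange(mmax).reshape(1, -1)
    w = (1 + ls * (ls + 1)).float().pow(fraction).tile(1, mmax) / (4.0 * math.pi)
    w[:, 1:] *= 2.0                                     # double-count m != 0 modes
    w = torch.where(ms > ls, torch.zeros_like(w), w)    # zero out modes above the diagonal
    return w                                            # shape (lmax, mmax)
# ------------------------------------------------------------------------------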
checks + if forecasts.dim() != 5: + raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + + # get the data type before stripping amp types + dtype = forecasts.dtype + + # before anything else compute the transform + # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same + with amp.autocast(device_type="cuda", enabled=False): + # TODO: check 4 pi normalization + forecasts = self.sht(forecasts.float()) + observations = self.sht(observations.float()) + + # cast back to original dtype + forecasts = forecasts.to(dtype=dtype) + observations = observations.to(dtype=dtype) + + # we assume the following shapes: + # forecasts: batch, ensemble, channels, mmax, lmax + # observations: batch, channels, mmax, lmax + B, E, C, H, W = forecasts.shape + + # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. + # ideally we split spatial dims + forecasts = torch.moveaxis(forecasts, 1, 0) + forecasts = forecasts.reshape(E, B, C, H * W) + if self.ensemble_distributed: + ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] + forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") # for correct spatial reduction we need to do the same with spatial weights + + lm_weights_split = self.lm_weights.flatten(start_dim=-2, end_dim=-1) + if self.ensemble_distributed: + lm_weights_split = scatter_to_parallel_region(lm_weights_split, -1, "ensemble") + + # observations does not need a transpose, but just a split + observations = observations.reshape(1, B, C, H * W) + if self.ensemble_distributed: + observations = scatter_to_parallel_region(observations, -1, "ensemble") + + num_ensemble = forecasts.shape[0] + + # get nanmask from the observarions + nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(forecasts)) + + # compute the individual distances + espread = lm_weights_split * (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().pow(self.beta) + eskill = lm_weights_split * (observations - forecasts).abs().pow(self.beta) + + # perform masking before any reduction + espread = torch.where(nanmasks.sum(dim=0) != 0, 0.0, espread) + eskill = torch.where(nanmasks.sum(dim=0) != 0, 0.0, eskill) + + # do the channel reduction first + espread = espread.sum(dim=-3, keepdim=True) + eskill = eskill.sum(dim=-3, keepdim=True) + + # do the spatial reduction + espread = espread.sum(dim=-1, keepdim=False) + eskill = eskill.sum(dim=-1, keepdim=False) + + # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well + if self.ensemble_distributed: + espread = reduce_from_parallel_region(espread, "ensemble") + eskill = reduce_from_parallel_region(eskill, "ensemble") + + # we need to do the spatial averaging manually since + if self.spatial_distributed: + espread = reduce_from_parallel_region(espread, "spatial") + eskill = reduce_from_parallel_region(eskill, "spatial") + + # now we have reduced everything and need to sum appropriately + espread = espread.sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) + eskill = eskill.sum(dim=0) / float(num_ensemble) + + return eskill - 0.5 * espread + +# class H1EnergyScoreLoss(GradientBaseLoss): + +# def __init__( +# self, +# img_shape: Tuple[int, int], +# crop_shape: Tuple[int, int], +# crop_offset: Tuple[int, int], +# channel_names: 
List[str], +# grid_type: str, +# pole_mask: int, +# lmax: Optional[int] = None, +# crps_type: str = "skillspread", +# spatial_distributed: Optional[bool] = False, +# ensemble_distributed: Optional[bool] = False, +# ensemble_weights: Optional[torch.Tensor] = None, +# absolute: Optional[bool] = True, +# alpha: Optional[float] = 1.0, +# beta: Optional[float] = 2.0, +# eps: Optional[float] = 1.0e-5, +# **kwargs, +# ): + +# super().__init__( +# img_shape=img_shape, +# crop_shape=crop_shape, +# crop_offset=crop_offset, +# channel_names=channel_names, +# grid_type=grid_type, +# pole_mask=pole_mask, +# lmax=lmax, +# spatial_distributed=spatial_distributed, +# ) + +# # if absolute is true, the loss is computed only on the absolute value of the gradient +# self.absolute = absolute + +# self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed +# self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed +# self.alpha = alpha +# self.beta = beta +# self.eps = eps + +# # we also need a variant of the weights split in ensemble direction: +# quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) +# if self.ensemble_distributed: +# quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")] +# quad_weight_split = quad_weight_split.contiguous() +# self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + +# if ensemble_weights is not None: +# self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) +# else: +# self.ensemble_weights = ensemble_weights + +# @property +# def type(self): +# return LossType.Probabilistic + +# @property +# def n_channels(self): +# return 1 + +# @torch.compiler.disable(recursive=False) +# def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor: +# return torch.ones(1) + +# def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + +# # sanity checks +# if forecasts.dim() != 5: +# raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + +# # we assume that spatial_weights have NO ensemble dim +# if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): +# spdim = spatial_weights.dim() +# odim = observations.dim() +# raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).") + +# # we assume the following shapes: +# # forecasts: batch, ensemble, channels, lat, lon +# # observations: batch, channels, lat, lon +# B, E, C, H, W = forecasts.shape + +# # get the data type before stripping amp types +# dtype = forecasts.dtype + +# # before anything else compute the transform +# # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same +# with amp.autocast(device_type="cuda", enabled=False): + +# # compute the SH coefficients of the forecasts and observations +# forecasts = self.sht(forecasts.float()).unsqueeze(-3) +# observations = self.sht(observations.float()).unsqueeze(-3) + +# # append zeros, so that we can use the inverse vector SHT +# forecasts = torch.cat([forecasts, torch.zeros_like(forecasts)], dim=-3) +# observations = torch.cat([observations, torch.zeros_like(observations)], dim=-3) + +# forecasts = self.ivsht(forecasts) +# 
observations = self.ivsht(observations) + +# forecasts = forecasts.to(dtype) +# observations = observations.to(dtype) + +# forecasts = forecasts.reshape(B, E, 2*C, H, W) +# observations = observations.reshape(B, 2*C, H, W) + +# # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. +# # ideally we split spatial dims +# forecasts = torch.moveaxis(forecasts, 1, 0) +# forecasts = forecasts.reshape(E, B, 2*C, H * W) +# if self.ensemble_distributed: +# ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] +# forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") + +# # observations does not need a transpose, but just a split +# observations = observations.reshape(1, B, 2*C, H * W) +# if self.ensemble_distributed: +# observations = scatter_to_parallel_region(observations, -1, "ensemble") + +# # for correct spatial reduction we need to do the same with spatial weights +# if spatial_weights is not None: +# spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1) +# spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + +# if self.ensemble_weights is not None: +# raise NotImplementedError("currently only constant ensemble weights are supported") +# else: +# ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + +# # ensemble size +# num_ensemble = forecasts.shape[0] + +# # get nanmask from the observarions +# nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(ensemble_weights)) + +# # use broadcasting semantics to compute spread and skill and sum over channels (vector norm) +# espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().pow(self.beta) +# eskill = (observations - forecasts).abs().pow(self.beta) + +# # perform masking before any reduction +# espread = torch.where(nanmasks.sum(dim=0) != 0, 0.0, espread) +# eskill = torch.where(nanmasks.sum(dim=0) != 0, 0.0, eskill) + +# # do the spatial reduction +# if spatial_weights is not None: +# espread = torch.sum(espread * self.quad_weight_split * spatial_weights_split, dim=-1) +# eskill = torch.sum(eskill * self.quad_weight_split * spatial_weights_split, dim=-1) +# else: +# espread = torch.sum(espread * self.quad_weight_split, dim=-1) +# eskill = torch.sum(eskill * self.quad_weight_split, dim=-1) + +# # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well +# if self.ensemble_distributed: +# espread = reduce_from_parallel_region(espread, "ensemble") +# eskill = reduce_from_parallel_region(eskill, "ensemble") + +# # we need to do the spatial averaging manually since +# # we are not calling the quadrature forward function +# if self.spatial_distributed: +# espread = reduce_from_parallel_region(espread, "spatial") +# eskill = reduce_from_parallel_region(eskill, "spatial") + +# # do the channel reduction while ignoring NaNs +# # if channel weights are required they should be added here to the reduction +# espread = espread.sum(dim=-1, keepdim=True) +# eskill = eskill.sum(dim=-1, keepdim=True) + +# # now we have reduced everything and need to sum appropriately +# espread = espread.pow(1/self.beta).sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) +# eskill = eskill.pow(1/self.beta).sum(dim=0) / float(num_ensemble) + +# # the resulting tensor should have dimension B, 1 which is what we return +# return eskill - 0.5 * espread \ No 
newline at end of file diff --git a/makani/utils/losses/h1_loss.py b/makani/utils/losses/h1_loss.py index 0cac68b..75acff4 100644 --- a/makani/utils/losses/h1_loss.py +++ b/makani/utils/losses/h1_loss.py @@ -134,7 +134,7 @@ def rel(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] tar_norm2 = 2 * torch.sum(tar_coeffssq, dim=-1) if self.spatial_distributed and (comm.get_size("w") > 1): tar_norm2 = reduce_from_parallel_region(tar_norm2, "w") - + # compute target norms tar_norm2 = tar_norm2.reshape(B, C, -1) tar_h1_norm2 = torch.sum(tar_norm2 * self.h1_weights, dim=-1) diff --git a/makani/utils/losses/lp_loss.py b/makani/utils/losses/lp_loss.py index b6aae83..5fda9b5 100644 --- a/makani/utils/losses/lp_loss.py +++ b/makani/utils/losses/lp_loss.py @@ -22,6 +22,8 @@ from makani.utils import comm +from physicsnemo.distributed.mappings import reduce_from_parallel_region + class GeometricLpLoss(GeometricBaseLoss): """ @@ -112,9 +114,9 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens return loss -class SpectralL2Loss(SpectralBaseLoss): +class SpectralLpLoss(SpectralBaseLoss): """ - Computes the geometric L2 loss but using the spherical Harmonic transform + Computes the Lp loss in spectral (SH coefficients) space """ def __init__( @@ -124,6 +126,7 @@ def __init__( crop_offset: Tuple[int, int], channel_names: List[str], grid_type: str, + p: Optional[float] = 2.0, relative: Optional[bool] = False, squared: Optional[bool] = False, spatial_distributed: Optional[bool] = False, @@ -138,6 +141,7 @@ def __init__( spatial_distributed=spatial_distributed, ) + self.p = p self.relative = relative self.squared = squared self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed @@ -145,80 +149,95 @@ def __init__( def abs(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None): B, C, H, W = prd.shape - coeffssq = torch.square(torch.abs(self.sht(prd - tar))) / torch.pi / 4.0 + # compute SH coefficients of the difference + coeffs = self.sht(prd - tar) + + # compute |coeffs|^p (orthonormal convention) + coeffsp = torch.abs(coeffs) ** self.p if wgt is not None: - coeffssq = coeffssq * wgt + coeffsp = coeffsp * wgt + # sum over m: m=0 contributes once, m!=0 contribute twice (due to conjugate symmetry) if comm.get_rank("w") == 0: - norm2 = coeffssq[..., 0] + 2 * torch.sum(coeffssq[..., 1:], dim=-1) + normp = coeffsp[..., 0] + 2 * torch.sum(coeffsp[..., 1:], dim=-1) else: - norm2 = 2 * torch.sum(coeffssq, dim=-1) + normp = 2 * torch.sum(coeffsp, dim=-1) + if self.spatial_distributed and (comm.get_size("w") > 1): - norm2 = reduce_from_parallel_region(norm2, "w") + normp = reduce_from_parallel_region(normp, "w") - # compute norms - norm2 = norm2.reshape(B, C, -1) - norm2 = torch.sum(norm2, dim=-1) + # sum over l (degrees) + normp = normp.reshape(B, C, -1) + normp = torch.sum(normp, dim=-1) if self.spatial_distributed and (comm.get_size("h") > 1): - norm2 = reduce_from_parallel_region(norm2, "h") + normp = reduce_from_parallel_region(normp, "h") + # take p-th root unless squared is True if not self.squared: - norm2 = torch.sqrt(norm2) + normp = normp ** (1.0 / self.p) - return norm2 + return normp def rel(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None): B, C, H, W = prd.shape - coeffssq = torch.square(torch.abs(self.sht(prd - tar))) / torch.pi / 4.0 + # compute SH coefficients of the difference + coeffs = self.sht(prd - tar) + coeffsp = torch.abs(coeffs) ** self.p if wgt is not None: - 
coeffssq = coeffssq * wgt + coeffsp = coeffsp * wgt - # sum m != 0 coeffs: + # sum m != 0 coeffs for numerator if comm.get_rank("w") == 0: - norm2 = coeffssq[..., 0] + 2 * torch.sum(coeffssq[..., 1:], dim=-1) + normp = coeffsp[..., 0] + 2 * torch.sum(coeffsp[..., 1:], dim=-1) else: - norm2 = 2 * torch.sum(coeffssq, dim=-1) + normp = 2 * torch.sum(coeffsp, dim=-1) + if self.spatial_distributed and (comm.get_size("w") > 1): - norm2 = reduce_from_parallel_region(norm2, "w") + normp = reduce_from_parallel_region(normp, "w") + + # sum over l + normp = normp.reshape(B, C, -1) + normp = torch.sum(normp, dim=-1) - # compute norms - norm2 = norm2.reshape(B, C, -1) - norm2 = torch.sum(norm2, dim=-1) if self.spatial_distributed and (comm.get_size("h") > 1): - norm2 = reduce_from_parallel_region(norm2, "h") + normp = reduce_from_parallel_region(normp, "h") - # target - tar_coeffssq = torch.square(torch.abs(self.sht(tar))) / torch.pi / 4.0 + # compute target norm + tar_coeffs = self.sht(tar) + tar_coeffsp = torch.abs(tar_coeffs) ** self.p if wgt is not None: - tar_coeffssq = tar_coeffssq * wgt + tar_coeffsp = tar_coeffsp * wgt - # sum m != 0 coeffs: + # sum m != 0 coeffs for denominator if comm.get_rank("w") == 0: - tar_norm2 = tar_coeffssq[..., 0] + 2 * torch.sum(tar_coeffssq[..., 1:], dim=-1) + tar_normp = tar_coeffsp[..., 0] + 2 * torch.sum(tar_coeffsp[..., 1:], dim=-1) else: - tar_norm2 = 2 * torch.sum(tar_coeffssq, dim=-1) + tar_normp = 2 * torch.sum(tar_coeffsp, dim=-1) + if self.spatial_distributed and (comm.get_size("w") > 1): - tar_norm2 = reduce_from_parallel_region(tar_norm2, "w") + tar_normp = reduce_from_parallel_region(tar_normp, "w") + + # sum over l + tar_normp = tar_normp.reshape(B, C, -1) + tar_normp = torch.sum(tar_normp, dim=-1) - # compute target norms - tar_norm2 = tar_norm2.reshape(B, C, -1) - tar_norm2 = torch.sum(tar_norm2, dim=-1) if self.spatial_distributed and (comm.get_size("h") > 1): - tar_norm2 = reduce_from_parallel_region(tar_norm2, "h") + tar_normp = reduce_from_parallel_region(tar_normp, "h") + # take p-th root unless squared is True if not self.squared: - diff_norms = torch.sqrt(norm2) - tar_norms = torch.sqrt(tar_norm2) + diff_norms = normp ** (1.0 / self.p) + tar_norms = tar_normp ** (1.0 / self.p) else: - diff_norms = norm2 - tar_norms = tar_norm2 + diff_norms = normp + tar_norms = tar_normp - # setup return value + # compute relative error retval = diff_norms / tar_norms return retval @@ -231,3 +250,4 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens loss = self.abs(prd, tar, wgt) return loss + diff --git a/makani/utils/losses/mmd_loss.py b/makani/utils/losses/mmd_loss.py index 21d2d8d..7f0912c 100644 --- a/makani/utils/losses/mmd_loss.py +++ b/makani/utils/losses/mmd_loss.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
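# --- Illustrative sketch, not part of this patch ------------------------------
# Both abs() and rel() in SpectralLpLoss reduce |a_lm|**p over the triangular
# array of SH coefficients, counting m = 0 once and m != 0 twice before taking
# the p-th root. Assuming complex coefficients of shape (B, C, lmax, mmax) are
# already available (e.g. from a real spherical harmonic transform), the
# single-process reduction is roughly:
import torch

def spectral_lp_norm_sketch(coeffs: torch.Tensor, p: float = 2.0) -> torch.Tensor:
    # coeffs: complex tensor of shape (B, C, lmax, mmax)
    cp = coeffs.abs().pow(p)
    # m = 0 contributes once, m > 0 twice (conjugate symmetry of a real signal)
    per_l = cp[..., 0] + 2.0 * cp[..., 1:].sum(dim=-1)   # (B, C, lmax)
    return per_l.sum(dim=-1).pow(1.0 / p)                # (B, C)
# ------------------------------------------------------------------------------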
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,58 +15,24 @@ from typing import Optional, Tuple, List -import numpy as np - +import math import torch import torch.nn as nn -from torch.cuda import amp +from torch import amp -from makani.utils.losses.base_loss import GeometricBaseLoss, SpectralBaseLoss, LossType +from makani.utils.losses.base_loss import GeometricBaseLoss, GradientBaseLoss, LossType from makani.utils import comm +import torch_harmonics as th +import torch_harmonics.distributed as thd + # distributed stuff from physicsnemo.distributed.utils import compute_split_shapes, split_tensor_along_dim from physicsnemo.distributed.mappings import scatter_to_parallel_region, reduce_from_parallel_region from makani.mpu.mappings import distributed_transpose -# @torch.compile -# def _mmd_rbf_kernel(x: torch.Tensor, y: torch.Tensor): -# return torch.abs(x - y) - - -@torch.compile -def _mmd_rbf_kernel(x: torch.Tensor, y: torch.Tensor, bandwidth: float = 1.0): - return torch.exp(-0.5 * torch.square(torch.abs(x - y)) / bandwidth) - - -# Computes the squared maximum mean discrepancy -# @torch.compile -def _mmd2_ensemble_kernel(observation: torch.Tensor, forecasts: torch.Tensor) -> torch.Tensor: - - # initial values - spread_term = torch.zeros_like(observation) - disc_term = torch.zeros_like(observation) - - num_forecasts = forecasts.shape[0] - - for m in range(num_forecasts): - # get the forecast - ym = forecasts[m] - - # account for contributions on the off-diasgonal assuming that the kernel is symmetric - spread_term = spread_term + 2.0 * torch.sum(_mmd_rbf_kernel(ym, forecasts[m:]), dim=0) - - # contributions to the discrepancy term - disc_term = disc_term + _mmd_rbf_kernel(ym, observation) - - # compute the squared mmd - mmd2 = spread_term / (num_forecasts - 1) / num_forecasts - 2.0 * disc_term / num_forecasts - - return mmd2 - - -class EnsembleMMDLoss(GeometricBaseLoss): +class GaussianMMDLoss(GeometricBaseLoss): r""" Computes the maximum mean discrepancy loss for a specific kernel. 
For details see [1] @@ -80,10 +46,15 @@ def __init__( crop_offset: Tuple[int, int], channel_names: List[str], grid_type: str, - squared: Optional[bool] = False, - pole_mask: Optional[int] = 0, + pole_mask: int, spatial_distributed: Optional[bool] = False, ensemble_distributed: Optional[bool] = False, + ensemble_weights: Optional[torch.Tensor] = None, + sigma: Optional[float] = 1.0, + alpha: Optional[float] = 1.0, + beta: Optional[float] = 2.0, + eps: Optional[float] = 1.0e-5, + channel_reduction: Optional[bool] = False, **kwargs, ): @@ -97,10 +68,13 @@ def __init__( spatial_distributed=spatial_distributed, ) - self.squared = squared - self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + self.alpha = alpha + self.beta = beta + self.eps = eps + self.channel_reduction = channel_reduction + self.sigma = sigma # we also need a variant of the weights split in ensemble direction: quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) @@ -109,10 +83,26 @@ def __init__( quad_weight_split = quad_weight_split.contiguous() self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + if ensemble_weights is not None: + self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) + else: + self.ensemble_weights = ensemble_weights + @property def type(self): return LossType.Probabilistic + @property + def n_channels(self): + if self.channel_reduction: + return 1 + else: + return len(self.channel_names) + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor: + return torch.ones(1) + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: # sanity checks @@ -121,51 +111,231 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # we assume that spatial_weights have NO ensemble dim if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): - raise ValueError("the weights have to have the same number of dimensions as observations") + spdim = spatial_weights.dim() + odim = observations.dim() + raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).") # we assume the following shapes: # forecasts: batch, ensemble, channels, lat, lon # observations: batch, channels, lat, lon B, E, C, H, W = forecasts.shape - # if ensemble dim is one dimensional then computing the score is quick: - if (not self.ensemble_distributed) and (forecasts.shape[1] == 1): - # in this case, CRPS is straightforward - mmd = _mmd_rbf_kernel(observations, forecasts.squeeze(1)).reshape(B, C, H * W) + # get the data type before stripping amp types + dtype = forecasts.dtype + + # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. 
+ # ideally we split spatial dims + forecasts = torch.moveaxis(forecasts, 1, 0) + forecasts = forecasts.reshape(E, B, C, H * W) + if self.ensemble_distributed: + ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] + forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") + + # observations does not need a transpose, but just a split + observations = observations.reshape(1, B, C, H * W) + if self.ensemble_distributed: + observations = scatter_to_parallel_region(observations, -1, "ensemble") + + # for correct spatial reduction we need to do the same with spatial weights + if spatial_weights is not None: + spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1) + spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + + if self.ensemble_weights is not None: + raise NotImplementedError("currently only constant ensemble weights are supported") else: - # transpose forecasts: ensemble, batch, channels, lat, lon - forecasts = torch.moveaxis(forecasts, 1, 0) - - # now we need to transpose the forecasts into ensemble direction. - # ideally we split spatial dims - forecasts = forecasts.reshape(E, B, C, H * W) - if self.ensemble_distributed: - ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] - forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") - # observations does not need a transpose, but just a split - observations = observations.reshape(B, C, H * W) - if self.ensemble_distributed: - observations = scatter_to_parallel_region(observations, -1, "ensemble") - if spatial_weights is not None: - spatial_weights_split = spatial_weights.flatten(-2, -1) - spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") - - # now, E dimension is local and spatial dim is split further. 
Compute the mmd - mmd = _mmd2_ensemble_kernel(observations, forecasts) - - # perform spatial average of crps score + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # ensemble size + num_ensemble = forecasts.shape[0] + + # get nanmask from the observarions + nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(ensemble_weights)) + + # use broadcasting semantics to compute spread and skill and sum over channels (vector norm) + espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().pow(self.beta) + eskill = (observations - forecasts).abs().pow(self.beta) + + # perform masking before any reduction + espread = torch.where(nanmasks.sum(dim=0) != 0, 0.0, espread) + eskill = torch.where(nanmasks.sum(dim=0) != 0, 0.0, eskill) + + # do the spatial reduction if spatial_weights is not None: - mmd = torch.sum(mmd * self.quad_weight_split * spatial_weights_split, dim=-1) + espread = torch.sum(espread * self.quad_weight_split * spatial_weights_split, dim=-1) + eskill = torch.sum(eskill * self.quad_weight_split * spatial_weights_split, dim=-1) else: - mmd = torch.sum(mmd * self.quad_weight_split, dim=-1) + espread = torch.sum(espread * self.quad_weight_split, dim=-1) + eskill = torch.sum(eskill * self.quad_weight_split, dim=-1) + + # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well if self.ensemble_distributed: - mmd = reduce_from_parallel_region(mmd, "ensemble") + espread = reduce_from_parallel_region(espread, "ensemble") + eskill = reduce_from_parallel_region(eskill, "ensemble") + # we need to do the spatial averaging manually since + # we are not calling the quadrature forward function if self.spatial_distributed: - mmd = reduce_from_parallel_region(mmd, "spatial") + espread = reduce_from_parallel_region(espread, "spatial") + eskill = reduce_from_parallel_region(eskill, "spatial") + + # do the channel reduction while ignoring NaNs + # if channel weights are required they should be added here to the reduction + if self.channel_reduction: + espread = espread.sum(dim=-2, keepdim=True) + eskill = eskill.sum(dim=-2, keepdim=True) + + # apply the Gaussian kernel + espread = torch.exp(-0.5 * torch.square(espread) / self.sigma) + eskill = torch.exp(-0.5 * torch.square(eskill) / self.sigma) + + # mask out the diagonal elements in the spread term + espread = torch.where(torch.eye(num_ensemble, device=espread.device).bool().reshape(num_ensemble, num_ensemble, 1, 1), 0.0, espread) + + # now we have reduced everything and need to sum appropriately + espread = espread.sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) + eskill = eskill.sum(dim=0) / float(num_ensemble) + + # the resulting tensor should have dimension B, C which is what we return + return eskill - 0.5 * espread - if not self.squared: - mmd = torch.sqrt(mmd) +# @torch.compile +# def _mmd_rbf_kernel(x: torch.Tensor, y: torch.Tensor): +# return torch.abs(x - y) + + +# @torch.compile +# def _mmd_rbf_kernel(x: torch.Tensor, y: torch.Tensor, bandwidth: float = 1.0): +# return torch.exp(-0.5 * torch.square(torch.abs(x - y)) / bandwidth) - # the resulting tensor should have dimension B, C, which is what we return - return mmd + +# # Computes the squared maximum mean discrepancy +# # @torch.compile +# def _mmd2_ensemble_kernel(observation: torch.Tensor, forecasts: torch.Tensor) -> torch.Tensor: + +# # initial values +# spread_term = torch.zeros_like(observation) +# disc_term = torch.zeros_like(observation) + 
+# num_forecasts = forecasts.shape[0] + +# for m in range(num_forecasts): + +# # get the forecast +# ym = forecasts[m] + +# # account for contributions on the off-diasgonal assuming that the kernel is symmetric +# spread_term = spread_term + 2.0 * torch.sum(_mmd_rbf_kernel(ym, forecasts[m:]), dim=0) + +# # contributions to the discrepancy term +# disc_term = disc_term + _mmd_rbf_kernel(ym, observation) + +# # compute the squared mmd +# mmd2 = spread_term / (num_forecasts - 1) / num_forecasts - 2.0 * disc_term / num_forecasts + +# return mmd2 + + +# class EnsembleMMDLoss(GeometricBaseLoss): +# r""" +# Computes the maximum mean discrepancy loss for a specific kernel. For details see [1] + +# [1] Dziugaite, Gintare Karolina; Roy, Daniel M.; Ghahramani, Zhoubin; Training generative neural networks via Maximum Mean Discrepancy optimization; arXiv:1505.03906 +# """ + +# def __init__( +# self, +# img_shape: Tuple[int, int], +# crop_shape: Tuple[int, int], +# crop_offset: Tuple[int, int], +# channel_names: List[str], +# grid_type: str, +# squared: Optional[bool] = False, +# pole_mask: Optional[int] = 0, +# spatial_distributed: Optional[bool] = False, +# ensemble_distributed: Optional[bool] = False, +# **kwargs, +# ): + +# super().__init__( +# img_shape=img_shape, +# crop_shape=crop_shape, +# crop_offset=crop_offset, +# channel_names=channel_names, +# grid_type=grid_type, +# pole_mask=pole_mask, +# spatial_distributed=spatial_distributed, +# ) + +# self.squared = squared + +# self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed +# self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + +# # we also need a variant of the weights split in ensemble direction: +# quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) +# if self.ensemble_distributed: +# quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")] +# quad_weight_split = quad_weight_split.contiguous() +# self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + +# @property +# def type(self): +# return LossType.Probabilistic + +# def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + +# # sanity checks +# if forecasts.dim() != 5: +# raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + +# # we assume that spatial_weights have NO ensemble dim +# if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): +# raise ValueError("the weights have to have the same number of dimensions as observations") + +# # we assume the following shapes: +# # forecasts: batch, ensemble, channels, lat, lon +# # observations: batch, channels, lat, lon +# B, E, C, H, W = forecasts.shape + +# # if ensemble dim is one dimensional then computing the score is quick: +# if (not self.ensemble_distributed) and (forecasts.shape[1] == 1): +# # in this case, CRPS is straightforward +# mmd = _mmd_rbf_kernel(observations, forecasts.squeeze(1)).reshape(B, C, H * W) +# else: +# # transpose forecasts: ensemble, batch, channels, lat, lon +# forecasts = torch.moveaxis(forecasts, 1, 0) + +# # now we need to transpose the forecasts into ensemble direction. 
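# --- Illustrative sketch, not part of this patch ------------------------------
# Single-process reading of the combination used by GaussianMMDLoss above:
# pairwise L_beta distances are reduced over the grid with quadrature weights
# first, then passed through a Gaussian kernel; the diagonal is removed from
# the spread term, and skill/spread are combined with the same "almost fair"
# normalization as the energy score. Assumes E > 1; shapes are illustrative.
import torch

def gaussian_kernel_score_sketch(obs, ens, quad_w, sigma=1.0, alpha=1.0, beta=2.0):
    # ens: (E, B, C, S), obs: (1, B, C, S), quad_w: (1, 1, S) quadrature weights
    E = ens.shape[0]
    d_skill = torch.sum((obs - ens).abs().pow(beta) * quad_w, dim=-1)                             # (E, B, C)
    d_spread = torch.sum((ens.unsqueeze(1) - ens.unsqueeze(0)).abs().pow(beta) * quad_w, dim=-1)  # (E, E, B, C)
    k_skill = torch.exp(-0.5 * d_skill.square() / sigma)
    k_spread = torch.exp(-0.5 * d_spread.square() / sigma)
    eye = torch.eye(E, dtype=torch.bool, device=ens.device).reshape(E, E, 1, 1)
    k_spread = k_spread.masked_fill(eye, 0.0)                   # drop self-pairs from the spread
    spread = k_spread.sum(dim=(0, 1)) * (E - 1.0 + alpha) / (E * E * (E - 1.0))
    skill = k_skill.sum(dim=0) / E
    return skill - 0.5 * spread                                 # shape (B, C)
# ------------------------------------------------------------------------------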
+# # ideally we split spatial dims +# forecasts = forecasts.reshape(E, B, C, H * W) +# if self.ensemble_distributed: +# ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] +# forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") +# # observations does not need a transpose, but just a split +# observations = observations.reshape(B, C, H * W) +# if self.ensemble_distributed: +# observations = scatter_to_parallel_region(observations, -1, "ensemble") +# if spatial_weights is not None: +# spatial_weights_split = spatial_weights.flatten(-2, -1) +# spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + +# # now, E dimension is local and spatial dim is split further. Compute the mmd +# mmd = _mmd2_ensemble_kernel(observations, forecasts) + +# # perform spatial average of crps score +# if spatial_weights is not None: +# mmd = torch.sum(mmd * self.quad_weight_split * spatial_weights_split, dim=-1) +# else: +# mmd = torch.sum(mmd * self.quad_weight_split, dim=-1) +# if self.ensemble_distributed: +# mmd = reduce_from_parallel_region(mmd, "ensemble") + +# if self.spatial_distributed: +# mmd = reduce_from_parallel_region(mmd, "spatial") + +# if not self.squared: +# mmd = torch.sqrt(mmd) + +# # the resulting tensor should have dimension B, C, which is what we return +# return mmd diff --git a/makani/utils/metric.py b/makani/utils/metric.py index 8e06260..fd8118e 100644 --- a/makani/utils/metric.py +++ b/makani/utils/metric.py @@ -73,14 +73,14 @@ def __init__(self, metric_name, metric_channels, metric_handle, channel_names, n # CPU buffers pin_memory = self.device.type == "cuda" - + if self.aux_shape_finalized is None: data_shape_finalized = (self.num_rollout_steps, self.num_channels) integral_shape = (self.num_channels) else: data_shape_finalized = (self.num_rollout_steps, self.num_channels, *self.aux_shape_finalized) integral_shape = (self.num_channels, *self.aux_shape_finalized) - + self.rollout_curve_cpu = torch.zeros(data_shape_finalized, dtype=torch.float32, device="cpu", pin_memory=pin_memory) if self.integrate: @@ -213,12 +213,12 @@ def __init__( climatology, num_rollout_steps, device, - l1_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], - rmse_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], - acc_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], - crps_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], - spread_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], - ssr_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], + l1_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + rmse_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + acc_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + crps_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + spread_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + ssr_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], rh_var_names=[], wb2_compatible=False, ): diff --git a/makani/utils/metrics/functions.py b/makani/utils/metrics/functions.py index ea4ae51..d99038c 100644 --- a/makani/utils/metrics/functions.py +++ b/makani/utils/metrics/functions.py @@ -22,7 +22,7 @@ from physicsnemo.distributed.utils import split_tensor_along_dim from makani.mpu.mappings import distributed_transpose -from makani.utils.losses import EnsembleCRPSLoss, LossType +from makani.utils.losses import CRPSLoss, LossType from 
makani.utils.metrics.base_metric import _sanitize_shapes, _welford_reduction_helper, GeometricBaseMetric class GeometricL1(GeometricBaseMetric): @@ -197,7 +197,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tenso # stack along dim -1: # we form the ratio in the finalization step acc = torch.stack([cov_xy, var_x, var_y], dim=-1) - + # reduce if self.channel_reduction == "mean": acc = torch.mean(acc, dim=1) @@ -252,12 +252,12 @@ def __init__( def combine(self, vals, counts, dim=0): # sanitize shapes vals, counts = _sanitize_shapes(vals, counts, dim=dim) - + # extract parameters covs = vals[..., 0].unsqueeze(-1) m2s = vals[..., 1:3] means = vals[..., 3:5] - + # counts are: n = sum_k n_k counts_agg = torch.sum(counts, dim=0) # means are: mu = sum_i n_i * mu_i / n @@ -280,7 +280,7 @@ def finalize(self, vals, counts): return vals[..., 0] / torch.sqrt(vals[..., 1] * vals[..., 2]) def forward(self, x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor: - + if hasattr(self, "bias"): x = x - self.bias y = y - self.bias @@ -349,7 +349,7 @@ def __init__( @property def type(self): return LossType.Probabilistic - + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor: # sanity checks @@ -498,7 +498,7 @@ def __init__( ): super().__init__() - self.metric_func = EnsembleCRPSLoss( + self.metric_func = CRPSLoss( img_shape=img_shape, crop_shape=crop_shape, crop_offset=crop_offset, diff --git a/makani/utils/training/deterministic_trainer.py b/makani/utils/training/deterministic_trainer.py index 1ed27d1..f69761b 100644 --- a/makani/utils/training/deterministic_trainer.py +++ b/makani/utils/training/deterministic_trainer.py @@ -179,11 +179,11 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = self.loss_obj = self.loss_obj.to(self.device) self.timers["loss handler init"] = timer.time - # channel weights: - if self.log_to_screen: - chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist() - chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)} - self.logger.info(f"Channel weights: {chw_output}") + # # channel weights: + # if self.log_to_screen: + # chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist() + # chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)} + # self.logger.info(f"Channel weights: {chw_output}") # optimizer and scheduler setup with Timer() as timer: @@ -243,7 +243,12 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = # visualization wrapper: with Timer() as timer: - plot_list = [{"name": "windspeed_uv10", "functor": "lambda x: np.sqrt(np.square(x[0, ...]) + np.square(x[1, ...]))", "diverging": False}] + plot_channel = "z500" + # plot_channel = "q50" + # plot_index = self.params.channel_names.index(plot_channel) + plot_index = 0 + print(self.params.channel_names) + plot_list = [{"name": plot_channel, "functor": f"lambda x: x[{plot_index}, ...]", "diverging": False}] out_bias, out_scale = self.train_dataloader.get_output_normalization() self.visualizer = visualize.VisualizationWrapper( self.params.log_to_wandb, @@ -529,12 +534,12 @@ def train_one_epoch(self, profiler=None): if do_update: # regular forward pass including DDP pred = self.model_train(inp) - loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) else: # disable sync step with self.model_train.no_sync(): pred = self.model_train(inp) 
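# --- Illustrative sketch, not part of this patch ------------------------------
# The combine() shown in the GeometricACC hunk above merges per-batch running
# statistics. The standard two-sample (parallel Welford / Chan et al.) merge of
# counts, means and sums of squared deviations is:
import torch

def welford_merge_sketch(n_a, mean_a, m2_a, n_b, mean_b, m2_b):
    # n_*: sample counts; mean_*, m2_*: running mean and sum of squared deviations
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta.square() * n_a * n_b / n
    return n, mean, m2
# The cross-covariance term merges the same way, with delta_x * delta_y in place
# of delta**2.
# ------------------------------------------------------------------------------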
- loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) loss = loss * loss_scaling_fact # backward pass diff --git a/makani/utils/training/ensemble_trainer.py b/makani/utils/training/ensemble_trainer.py index e6e7a86..3b60440 100644 --- a/makani/utils/training/ensemble_trainer.py +++ b/makani/utils/training/ensemble_trainer.py @@ -182,11 +182,11 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = self.loss_obj = self.loss_obj.to(self.device) self.timers["loss handler init"] = timer.time - # channel weights: - if self.log_to_screen: - chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist() - chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)} - self.logger.info(f"Channel weights: {chw_output}") + # # channel weights: + # if self.log_to_screen: + # chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist() + # chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)} + # self.logger.info(f"Channel weights: {chw_output}") # optimizer and scheduler setup # model @@ -248,7 +248,10 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = # visualization wrapper: with Timer() as timer: - plot_list = [{"name": "windspeed_uv10", "functor": "lambda x: np.sqrt(np.square(x[0, ...]) + np.square(x[1, ...]))", "diverging": False}] + plot_channel = "sst" + # plot_index = self.params.channel_names.index(plot_channel) + plot_index = 0 + plot_list = [{"name": plot_channel, "functor": f"lambda x: x[{plot_index}, ...]", "diverging": False}] out_bias, out_scale = self.train_dataloader.get_output_normalization() self.visualizer = visualize.VisualizationWrapper( self.params.log_to_wandb, @@ -491,7 +494,7 @@ def _ensemble_step(self, inp: torch.Tensor, tar: torch.Tensor): # stack predictions along new dim (ensemble dim): pred = torch.stack(predlist, dim=1) # compute loss - loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) return pred, loss diff --git a/makani/utils/training/stochastic_trainer.py b/makani/utils/training/stochastic_trainer.py index 9d0ae1e..519cd37 100644 --- a/makani/utils/training/stochastic_trainer.py +++ b/makani/utils/training/stochastic_trainer.py @@ -506,11 +506,11 @@ def train_one_epoch(self): with amp.autocast(device_type="cuda", enabled=self.amp_enabled, dtype=self.amp_dtype): if do_update: pred, tar = self.model_train(inp, tar, n_samples=self.params.stochastic_size) - loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) else: with self.model_train.no_sync(): pred, tar = self.model_train(inp, tar, n_samples=self.params.stochastic_size) - loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) loss = loss * loss_scaling_fact self.gscaler.scale(loss).backward() diff --git a/tests/distributed/distributed_helpers.py b/tests/distributed/distributed_helpers.py index 334f269..303383a 100644 --- a/tests/distributed/distributed_helpers.py +++ b/tests/distributed/distributed_helpers.py @@ -61,7 +61,6 @@ def get_default_parameters(): params.N_in_channels = len(params.in_channels) params.N_out_channels = len(params.out_channels) - params.target = "default" params.batch_size = 1 params.valid_autoreg_steps = 0 params.num_data_workers = 1 @@ -93,7 +92,7 @@ def split_helper(tensor, dim=None, group=None): tensor_local = tensor_list_local[grank] else: tensor_local = tensor.clone() - + return tensor_local @@ -117,5 +116,5 @@ def gather_helper(tensor, dim=None, group=None): tensor_gather = 
torch.cat(tens_gather, dim=dim) else: tensor_gather = tensor.clone() - + return tensor_gather diff --git a/tests/distributed/tests_distributed_losses.py b/tests/distributed/tests_distributed_losses.py index ac9a8aa..ba69485 100644 --- a/tests/distributed/tests_distributed_losses.py +++ b/tests/distributed/tests_distributed_losses.py @@ -31,7 +31,7 @@ from physicsnemo.distributed.mappings import gather_from_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region from makani.utils.grids import GridQuadrature -from makani.utils.losses import EnsembleCRPSLoss, EnsembleNLLLoss +from makani.utils.losses import CRPSLoss, EnsembleNLLLoss from distributed_helpers import split_helper, gather_helper @@ -116,7 +116,7 @@ def _gather_helper_bwd(self, tensor, ensemble=False): return tensor_gather - + @parameterized.expand( [ [128, 256, 32, 8, 1e-6, "naive", False], @@ -149,7 +149,7 @@ def test_distributed_quadrature(self, nlat, nlon, batch_size, num_chan, tol, qua ograd_full = torch.randn_like(out_full) out_full.backward(ograd_full) igrad_full = inp_full.grad.clone() - + # distributed out_local = quad_dist(inp_local) out_local.backward(ograd_full) @@ -172,7 +172,7 @@ def test_distributed_quadrature(self, nlat, nlon, batch_size, num_chan, tol, qua ############################################################# with torch.no_grad(): igrad_gather_full = self._gather_helper_bwd(igrad_local, False) - + # compute errors err = fn.relative_error(igrad_gather_full, igrad_full) if verbose and (self.world_rank == 0): @@ -204,7 +204,7 @@ def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, tol, if loss_type == "ensemble_crps": # local loss - loss_fn_local = EnsembleCRPSLoss( + loss_fn_local = CRPSLoss( img_shape=(H, W), crop_shape=None, crop_offset=(0, 0), @@ -218,7 +218,7 @@ def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, tol, ).to(self.device) # distributed loss - loss_fn_dist = EnsembleCRPSLoss( + loss_fn_dist = CRPSLoss( img_shape=(H, W), crop_shape=None, crop_offset=(0, 0), @@ -232,7 +232,7 @@ def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, tol, ).to(self.device) elif loss_type == "gauss_crps": # local loss - loss_fn_local = EnsembleCRPSLoss( + loss_fn_local = CRPSLoss( img_shape=(H, W), crop_shape=None, crop_offset=(0, 0), @@ -246,7 +246,7 @@ def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, tol, ).to(self.device) # distributed loss - loss_fn_dist = EnsembleCRPSLoss( + loss_fn_dist = CRPSLoss( img_shape=(H, W), crop_shape=None, crop_offset=(0, 0), @@ -260,7 +260,7 @@ def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, tol, ).to(self.device) elif loss_type == "skillspread_crps": # local loss - loss_fn_local = EnsembleCRPSLoss( + loss_fn_local = CRPSLoss( img_shape=(H, W), crop_shape=None, crop_offset=(0, 0), @@ -274,7 +274,7 @@ def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, tol, ).to(self.device) # distributed loss - loss_fn_dist = EnsembleCRPSLoss( + loss_fn_dist = CRPSLoss( img_shape=(H, W), crop_shape=None, crop_offset=(0, 0), diff --git a/tests/test_losses.py b/tests/test_losses.py index 0062333..9c3ffdc 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -24,7 +24,7 @@ from makani.models import model_registry from makani.utils import LossHandler -from makani.utils.losses import EnsembleCRPSLoss +from makani.utils.losses import CRPSLoss from testutils import get_default_parameters @@ -138,7 +138,7 @@ def 
test_loss_batchsize_independence(self, losses, uncertainty_weighting=False): # test initialization of loss object loss_obj = LossHandler(self.params) - + shape = (self.params.batch_size, self.params.N_out_channels, self.params.img_shape_x, self.params.img_shape_y) inp = torch.randn(*shape) @@ -150,14 +150,14 @@ def test_loss_batchsize_independence(self, losses, uncertainty_weighting=False): out2 = loss_obj(tar2, inp2) self.assertTrue(torch.allclose(out, out2)) - + @parameterized.expand(_loss_weighted_params) def test_loss_weighted(self, losses, uncertainty_weighting=False): """ Tests initialization of loss, as well as the forward and backward pass """ - + self.params.losses = losses self.params.uncertainty_weighting = uncertainty_weighting @@ -179,7 +179,7 @@ def test_loss_weighted(self, losses, uncertainty_weighting=False): # compute weighted loss out_weighted = loss_obj(tar, inp, wgt) - + self.assertTrue(torch.allclose(out, out_weighted)) @parameterized.expand(_loss_weighted_params) @@ -211,44 +211,44 @@ def test_loss_multistep(self, losses, uncertainty_weighting=False): # compute weighted loss out_weighted = loss_obj(tar, inp, wgt) - self.assertTrue(torch.allclose(out, out_weighted)) + self.assertTrue(torch.allclose(out, out_weighted)) def test_running_stats(self): """ Tests computation of the running stats """ - + self.params.losses = [{"type": "l2"}] - + # test initialization of loss object loss_obj = LossHandler(self.params, track_running_stats=True) loss_obj.train() - + shape = (self.params.batch_size, self.params.N_out_channels, self.params.img_shape_x, self.params.img_shape_y) - + # this needs to be sufficiently large to mitigarte the bias due to the initialization of the running stats num_samples = 100 for i in range(num_samples): - + inp = i * torch.ones(*shape) inp.requires_grad = True tar = torch.zeros(*shape) tar.requires_grad = True - + # forward pass and check shapes out = loss_obj(tar, inp) - + # generate simulated dataset data = torch.arange(num_samples).float().reshape(1, 1, -1).repeat(self.params.batch_size, self.params.N_out_channels, 1) expected_var, expected_mean = torch.var_mean(data, correction=0, dim=(0, -1)) - + var, mean = loss_obj.get_running_stats() - + self.assertTrue(torch.allclose(mean, expected_mean)) self.assertTrue(torch.allclose(var, expected_var)) def test_ensemble_crps(self): - crps_func = EnsembleCRPSLoss( + crps_func = CRPSLoss( img_shape=(self.params.img_shape_x, self.params.img_shape_y), crop_shape=(self.params.img_shape_x, self.params.img_shape_y), crop_offset=(0, 0), @@ -260,24 +260,24 @@ def test_ensemble_crps(self): ensemble_distributed=False, ensemble_weights=None, ) - + for ensemble_size in [1, 10]: with self.subTest(f"{ensemble_size}"): # generate input tensor inp = torch.empty((self.params.batch_size, ensemble_size, self.params.N_in_channels, self.params.img_shape_x, self.params.img_shape_y), dtype=torch.float32) with torch.no_grad(): inp.normal_(1.0, 1.0) - + # target tensor tar = torch.ones((self.params.batch_size, self.params.N_in_channels, self.params.img_shape_x, self.params.img_shape_y), dtype=torch.float32) - + # torch result result = crps_func(inp, tar).cpu().numpy() - + # properscoring result tar_arr = tar.cpu().numpy() inp_arr = inp.cpu().numpy() - + # I think this is a bug with the axis index in properscoring # for the degenerate case: if ensemble_size == 1: @@ -285,19 +285,19 @@ def test_ensemble_crps(self): inp_arr = np.squeeze(inp_arr, axis=1) else: axis = 1 - + result_proper = crps_ensemble(tar_arr, inp_arr, weights=None, 
issorted=False, axis=axis) quad_weight_arr = crps_func.quadrature.quad_weight.cpu().numpy() result_proper = np.sum(result_proper * quad_weight_arr, axis=(2, 3)) - + self.assertTrue(np.allclose(result, result_proper, rtol=1e-5, atol=0)) def test_gauss_crps(self): - + # protext against sigma=0 eps = 1.0e-5 - - crps_func = EnsembleCRPSLoss( + + crps_func = CRPSLoss( img_shape=(self.params.img_shape_x, self.params.img_shape_y), crop_shape=(self.params.img_shape_x, self.params.img_shape_y), crop_offset=(0, 0), @@ -309,32 +309,32 @@ def test_gauss_crps(self): ensemble_distributed=False, eps=eps, ) - + for ensemble_size in [1, 10]: with self.subTest(f"{ensemble_size}"): # generate input tensor inp = torch.empty((self.params.batch_size, ensemble_size, self.params.N_in_channels, self.params.img_shape_x, self.params.img_shape_y), dtype=torch.float32) with torch.no_grad(): inp.normal_(1.0, 1.0) - + # target tensor tar = torch.ones((self.params.batch_size, self.params.N_in_channels, self.params.img_shape_x, self.params.img_shape_y), dtype=torch.float32) - + # torch result result = crps_func(inp, tar).cpu().numpy() - + # properscoring result tar_arr = tar.cpu().numpy() inp_arr = inp.cpu().numpy() - + # compute mu, sigma, guard against underflows mu = np.mean(inp_arr, axis=1) sigma = np.maximum(np.sqrt(np.var(inp_arr, axis=1)), eps) - + result_proper = crps_gaussian(tar_arr, mu, sigma, grad=False) quad_weight_arr = crps_func.quadrature.quad_weight.cpu().numpy() result_proper = np.sum(result_proper * quad_weight_arr, axis=(2, 3)) - + self.assertTrue(np.allclose(result, result_proper, rtol=1e-5, atol=0)) diff --git a/tests/testutils.py b/tests/testutils.py index dd78b08..86f6d85 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -62,7 +62,6 @@ def get_default_parameters(): params.N_in_channels = len(params.in_channels) params.N_out_channels = len(params.out_channels) - params.target = "default" params.batch_size = 1 params.valid_autoreg_steps = 0 params.num_data_workers = 1
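# --- Illustrative usage sketch, not part of this patch ------------------------
# The tests above validate CRPSLoss against properscoring. A minimal standalone
# cross-check of a plain (unweighted, flat-grid) ensemble CRPS estimate follows
# the same pattern; shapes here are illustrative only.
import numpy as np
import torch
from properscoring import crps_ensemble

B, E, C, H, W = 1, 10, 2, 8, 16
ens = torch.randn(B, E, C, H, W)
obs = torch.zeros(B, C, H, W)

# skill/spread form of the ensemble CRPS, averaged over the grid
skill = (ens - obs.unsqueeze(1)).abs().mean(dim=1)
spread = (ens.unsqueeze(1) - ens.unsqueeze(2)).abs().mean(dim=(1, 2))
crps_torch = (skill - 0.5 * spread).mean(dim=(-2, -1))                  # (B, C)

# reference from properscoring, with the ensemble dimension on axis 1
crps_ref = crps_ensemble(obs.numpy(), ens.numpy(), axis=1).mean(axis=(-2, -1))
print(np.allclose(crps_torch.numpy(), crps_ref, rtol=1e-5, atol=1e-6))
# ------------------------------------------------------------------------------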