Draft
Changes from all commits

32 commits
8da4607
adding index exclusion list support in dali loaders, making index han…
azrael417 Sep 17, 2025
7bd3f42
some small fixes for non-combined dali loader and unannotated files
azrael417 Sep 17, 2025
0c875db
distinguishing between read_shape and return_shape. read_shape will …
azrael417 Sep 18, 2025
3773f84
fixing a bug where samples per year were not computed correctly
azrael417 Sep 18, 2025
e2ec9e0
making nicer prints
azrael417 Sep 18, 2025
12093d7
fixed dataloader indexing
azrael417 Sep 18, 2025
cca3a58
fixed history with noise
azrael417 Sep 18, 2025
555e930
setting explicit noise seeds for ensemble serial settings. without th…
azrael417 Sep 19, 2025
38780ee
adding accessor functions for some RNG features when noise is enabled
azrael417 Sep 19, 2025
b68733e
removing unnecessary imports
azrael417 Sep 19, 2025
7656ad4
removing more imports
azrael417 Sep 19, 2025
228e7d9
removing more imports
azrael417 Sep 19, 2025
02b6add
adapting changes from internal branch
bonevbs Sep 19, 2025
ae27deb
some changes to the model and losses
bonevbs Oct 30, 2025
9bc000c
restoring logic from main regarding noise states
bonevbs Oct 30, 2025
ff71af3
fixing model package
bonevbs Nov 1, 2025
45a4c2e
some changes to deterministic trainer
bonevbs Nov 12, 2025
25862ae
removing the residual training option and instead extending loss han…
bonevbs Nov 24, 2025
756a51b
updating other configs
bonevbs Nov 24, 2025
092e7a5
added routine to compute spherical bandlimit
bonevbs Nov 24, 2025
df70e1b
changes to registry to pass normalization parameters
bonevbs Dec 2, 2025
2e79a17
adding energy score loss
bonevbs Dec 2, 2025
08b8f2e
removing the random CRPS
bonevbs Dec 2, 2025
7bca7d1
fixing another bug
bonevbs Dec 2, 2025
7f5a59a
more fixes for energy score loss
bonevbs Dec 2, 2025
e29a588
another fix
bonevbs Dec 2, 2025
207fa6d
fixing another bug
bonevbs Dec 2, 2025
af59f91
fixing energy score
bonevbs Dec 2, 2025
ef4efc4
cleaning up some losses
bonevbs Dec 10, 2025
ea87f87
implemented improved Sobolev energy score
bonevbs Dec 21, 2025
790ee03
added powerspectrum to stats computation
bonevbs Dec 23, 2025
15f06a5
added psd normalization to energy score
bonevbs Dec 24, 2025
3 changes: 0 additions & 3 deletions config/afnonet.yaml
@@ -57,9 +57,6 @@ full_field: &BASE_CONFIG
nested_skip_fno: !!bool True # whether to nest the inner skip connection or have it be sequential, inside the AFNO block
verbose: False

#options default, residual
target: "default"

channel_names: ["u10m", "v10m", "t2m", "sp", "msl", "t850", "u1000", "v1000", "z1000", "u850", "v850", "z850", "u500", "v500", "z500", "t500", "z50", "r500", "r850", "tcwv", "u100m", "v100m", "u250", "v250", "z250", "t250", "u100", "v100", "z100", "t100", "u900", "v900", "z900", "t900"]
normalization: "zscore" #options zscore or minmax
hard_thresholding_fraction: 1.0
3 changes: 0 additions & 3 deletions config/debug.yaml
@@ -90,9 +90,6 @@ base_config: &BASE_CONFIG
add_noise: !!bool False
noise_std: 0.

target: "default" # options default, residual
normalize_residual: false

# define channels to be read from data
channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"]
# normalization mode zscore but for q
3 changes: 0 additions & 3 deletions config/fourcastnet3.yaml
@@ -87,9 +87,6 @@ base_config: &BASE_CONFIG
add_noise: !!bool False
noise_std: 0.

target: "default" # options default, residual
normalize_residual: false

# define channels to be read from data. sp has been removed here
channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"]
# normalization mode zscore but for q
3 changes: 0 additions & 3 deletions config/icml_models.yaml
@@ -66,9 +66,6 @@ base_config: &BASE_CONFIG
N_grid_channels: 0
gridtype: "sinusoidal" #options "sinusoidal" or "linear"

#options default, residual
target: "default"

channel_names: ["u10m", "v10m", "t2m", "sp", "msl", "t850", "u1000", "v1000", "z1000", "u850", "v850", "z850", "u500", "v500", "z500", "t500", "z50", "r500", "r850", "tcwv", "u100m", "v100m", "u250", "v250", "z250", "t250", "u100", "v100", "z100", "t100", "u900", "v900", "z900", "t900"]
normalization: "zscore" #options zscore or minmax or none

7 changes: 2 additions & 5 deletions config/pangu.yaml
@@ -78,9 +78,6 @@ base_config: &BASE_CONFIG
add_noise: !!bool False
noise_std: 0.

target: "default" # options default, residual
normalize_residual: !!bool False

# define channels to be read from data
channel_names: ["u10m", "v10m", "t2m", "msl", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"]
normalization: "zscore" # options zscore or minmax or none
@@ -131,10 +128,10 @@ base_onnx: &BASE_ONNX
# ONNX wrapper related overwrite
nettype: "/makani/makani/makani/models/networks/pangu_onnx.py:PanguOnnx"
onnx_file: '/model/pangu_weather_6.onnx'

amp_mode: "none"
disable_ddp: True

# Set Pangu ONNX channel order
channel_names: ["msl", "u10m", "v10m", "t2m", "z1000", "z925", "z850", "z700", "z600", "z500", "z400", "z300", "z250", "z200", "z150", "z100", "z50", "q1000", "q925", "q850", "q700", "q600", "q500", "q400", "q300", "q250", "q200", "q150", "q100", "q50", "t1000", "t925", "t850", "t700", "t600", "t500", "t400", "t300", "t250", "t200", "t150", "t100", "t50", "u1000", "u925", "u850", "u700", "u600", "u500", "u400", "u300", "u250", "u200", "u150", "u100", "u50", "v1000", "v925", "v850", "v700", "v600", "v500", "v400", "v300", "v250", "v200", "v150", "v100", "v50"]
# Remove input/output normalization
3 changes: 0 additions & 3 deletions config/sfnonet.yaml
@@ -86,9 +86,6 @@ base_config: &BASE_CONFIG
add_noise: !!bool False
noise_std: 0.

target: "default" # options default, residual
normalize_residual: !!bool False

# define channels to be read from data
channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"]
normalization: "zscore" # options zscore or minmax or none
3 changes: 0 additions & 3 deletions config/vit.yaml
@@ -65,9 +65,6 @@ full_field: &BASE_CONFIG
embed_dim: 384
normalization_layer: "layer_norm"

#options default, residual
target: "default"

channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"]
normalization: "zscore" #options zscore or minmax
hard_thresholding_fraction: 1.0
109 changes: 80 additions & 29 deletions data_process/get_stats.py
@@ -32,6 +32,7 @@

import torch
import torch.distributed as dist
from torch_harmonics import RealSHT
from makani.utils.grids import GridQuadrature

from wb2_helpers import DistributedProgressBar
@@ -50,7 +51,7 @@ def allgather_dict(stats, group):
for _ in range(dist.get_world_size(group)):
substats = {varname: {} for varname in stats.keys()}
stats_gather.append(substats)

# iterate over full dict
for varname, substats in stats.items():
for k,v in substats.items():
@@ -179,6 +180,14 @@
return stats


def compute_powerspectrum(x, sht):
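# squared magnitudes |a_lm|^2 of the spherical harmonic coefficients, shape (..., lmax, mmax)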
coeffs = sht(x).abs().pow(2)
# account for Hermitian symmetry (each m > 0 coefficient stands for a ±m pair)
coeffs[..., 1:] *= 2.0
power_spectrum = coeffs.sum(dim=-1)
return power_spectrum
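
For reference, a minimal sketch of how this helper behaves; the grid size and channel counts below are illustrative, and torch_harmonics' RealSHT is assumed as imported above:

import torch
from torch_harmonics import RealSHT

x = torch.randn(2, 3, 91, 180)              # (batch, channels, nlat, nlon), hypothetical shapes
sht = RealSHT(91, 180, grid="equiangular")  # forward transform yields (..., lmax, mmax) coefficients
psd = compute_powerspectrum(x, sht)         # (2, 3, lmax): |a_lm|^2 summed over m, with m > 0 doubled
assert psd.shape == (2, 3, sht.lmax)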


def get_file_stats(filename,
file_slice,
wind_indices,
@@ -187,7 +196,8 @@ def get_file_stats(filename,
dt=1,
batch_size=16,
device=torch.device("cpu"),
progress=None):
progress=None,
sht=None):

stats = None
with h5.File(filename, 'r') as f:
@@ -203,7 +213,7 @@

if batch_size is None:
batch_size = slc_stop - slc_start

for batch_start in range(slc_start, slc_stop, batch_size):
batch_stop = min(batch_start+batch_size, slc_stop)
sub_slc = slice(batch_start, batch_stop)
@@ -227,14 +237,33 @@
counts_time = tdata.shape[0]
valid_count = torch.sum(quadrature(valid_mask), dim=0)
counts_time_space = valid_count

# Basic observables
# compute mean and variance
# the mean needs to be divided by number of valid samples:
tmean = torch.sum(quadrature(tdata_masked), dim=0, keepdim=False).reshape(1, -1, 1, 1) / valid_count[None, :, None, None]
# we compute m2 directly, so we do not need to divide by number of valid samples:
tm2 = torch.sum(quadrature(torch.square(tdata_masked - tmean)), dim=0, keepdim=False).reshape(1, -1, 1, 1)

# compute PSD stats
psd_stats = None
if sht is not None:
# compute psd (B, C, L)
psd = compute_powerspectrum(tdata_masked, sht)
# compute mean and m2 over batch
psd_mean = torch.mean(psd, dim=0, keepdim=True) # (1, C, L)
psd_m2 = torch.sum(torch.square(psd - psd_mean), dim=0, keepdim=True) # (1, C, L)
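# note: plain batch moments (no quadrature weighting), since the SHT already integrates over the sphere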

# reshape to (1, C, L, 1) for welford compatibility
psd_mean = psd_mean.unsqueeze(-1)
psd_m2 = psd_m2.unsqueeze(-1)

psd_stats = {
"type": "meanvar",
"counts": float(counts_time) * torch.ones((data.shape[1]), dtype=torch.float64, device=device),
"values": torch.stack([psd_mean, psd_m2], dim=0).contiguous()
}

# fill the dict
tmpstats = dict(
maxs = {
Expand All @@ -261,6 +290,9 @@ def get_file_stats(filename,
}
)

if psd_stats is not None:
tmpstats["psd_meanvar"] = psd_stats

# time diffs: read one more sample for these, if possible
# TODO: tile it for dt < batch_size
if batch_start >= dt:
@@ -288,13 +320,13 @@
# we need the shapes
tshape = tmean.shape
tmpstats["time_diff_meanvar"] = {
"type": "meanvar",
"type": "meanvar",
"counts": torch.zeros(data.shape[1], dtype=torch.float64, device=device),
"values": torch.stack(
[
torch.zeros(tshape, dtype=torch.float64, device=device),
torch.zeros(tshape, dtype=torch.float64, device=device),
torch.zeros(tshape, dtype=torch.float64, device=device)
],
],
dim=0
).contiguous(),
}
@@ -333,7 +365,7 @@
"counts": torch.zeros(wdiffshape[1], dtype=torch.float64, device=device),
"values": torch.stack(
[
torch.zeros(wdiffshape, dtype=torch.float64, device=device),
torch.zeros(wdiffshape, dtype=torch.float64, device=device),
torch.zeros(wdiffshape, dtype=torch.float64, device=device)
],
dim=0
@@ -379,7 +411,7 @@ def get_stats(input_path: str, output_path: str, metadata_file: str,
dt: int, quadrature_rule: str, wind_angle_aware: bool, fail_on_nan: bool=False,
batch_size: Optional[int]=16, reduction_group_size: Optional[int]=8):

"""Function to compute various statistics of all variables of a makani HDF5 dataset.
"""Function to compute various statistics of all variables of a makani HDF5 dataset.

This function reads data from input_path and computes minimum, maximum, mean and standard deviation
for all variables in the dataset. This is done globally, meaning averaged over space and time.
@@ -410,7 +442,7 @@ def get_stats(input_path: str, output_path: str, metadata_file: str,
coords: this is a dictionary which contains two lists, latitude and longitude coordinates in degrees as well as channel names.
Example: coords = dict(lat=[-90.0, ..., 90.], lon=[0, ..., 360], channel=["t2m", "u500", "v500", ...])
Note that the number of entries in coords["lat"] has to match dimension -2 of the dataset, and coords["lon"] dimension -1.
The length of the channel names has to match dimension -3 (or dimension 1, which is the same) of the dataset.
The length of the channel names has to match dimension -3 (or dimension 1, which is the same) of the dataset.
dt : int
The temporal difference for which the temporal means and standard deviations should be computed. Note that this is in units of dhours (see metadata file),
quadrature_rule : str
@@ -456,14 +488,14 @@ def get_stats(input_path: str, output_path: str, metadata_file: str,
# init torch distributed
device = torch.device(f"cuda:{comm_local_rank}") if torch.cuda.is_available() else torch.device("cpu")
dist.init_process_group(
backend="nccl" if torch.cuda.is_available() else "gloo",
backend="nccl" if torch.cuda.is_available() else "gloo",
init_method="env://",
world_size=comm_size,
rank=comm_rank,
device_id=device,
)
mesh = dist.init_device_mesh(
device_type=device.type,
device_type=device.type,
mesh_shape=[reduction_group_size, comm_size // reduction_group_size],
mesh_dim_names=["reduction", "tree"],
)
@@ -520,10 +552,21 @@
height, width = (data_shape[2], data_shape[3])

# quadrature:
# we normalize the quadrature rule to 4pi
quadrature = GridQuadrature(quadrature_rule, (height, width),
crop_shape=None, crop_offset=(0, 0),
normalize=False, pole_mask=None).to(device)

# Initialize SHT
grid_type_map = {
"naive": "equiangular",
"clenshaw-curtiss": "clenshaw-curtiss",
"gauss-legendre": "legendre-gauss"
}
grid_type = grid_type_map[quadrature_rule]
sht = RealSHT(height, width, grid=grid_type).to(device)
lmax = sht.lmax
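# lmax is the spherical bandlimit implied by the grid (torch_harmonics defaults to lmax = nlat)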

if comm_rank == 0:
print(f"Found {len(filelist)} files with a total of {num_samples_total} samples. Each sample has the shape {num_channels}x{height}x{width} (CxHxW).")

@@ -564,50 +607,55 @@
# initialize arrays
stats = dict(
global_meanvar = {
"type": "meanvar",
"counts": torch.zeros((num_channels), dtype=torch.float64, device=device),
"type": "meanvar",
"counts": torch.zeros((num_channels), dtype=torch.float64, device=device),
"values": torch.zeros((2, 1, num_channels, 1, 1), dtype=torch.float64, device=device),
},
mins = {
"type": "min",
"counts": torch.zeros((num_channels), dtype=torch.float64, device=device),
"type": "min",
"counts": torch.zeros((num_channels), dtype=torch.float64, device=device),
"values": torch.full((1, num_channels, 1, 1), torch.inf, dtype=torch.float64, device=device)
},
maxs = {
"type": "max",
"counts": torch.zeros((num_channels), dtype=torch.float64, device=device),
"type": "max",
"counts": torch.zeros((num_channels), dtype=torch.float64, device=device),
"values": torch.full((1, num_channels, 1, 1), -torch.inf, dtype=torch.float64, device=device)
},
time_means = {
"type": "mean",
"counts": torch.zeros((num_channels), dtype=torch.float64, device=device),
"type": "mean",
"counts": torch.zeros((num_channels), dtype=torch.float64, device=device),
"values": torch.zeros((1, num_channels, height, width), dtype=torch.float64, device=device)
},
time_diff_meanvar = {
"type": "meanvar",
"counts": torch.zeros((num_channels), dtype=torch.float64, device=device),
"values": torch.zeros((2, 1, num_channels, 1, 1), dtype=torch.float64, device=device),
"type": "meanvar",
"counts": torch.zeros((num_channels), dtype=torch.float64, device=device),
"values": torch.zeros((2, 1, num_channels, 1, 1), dtype=torch.float64, device=device),
},
psd_meanvar = {
"type": "meanvar",
"counts": torch.zeros((num_channels), dtype=torch.float64, device=device),
"values": torch.zeros((2, 1, num_channels, lmax, 1), dtype=torch.float64, device=device),
}
)

if wind_channels is not None:
num_wind_channels = len(wind_channels[0])
stats["wind_meanvar"] = {
"type": "meanvar",
"counts": torch.zeros((num_wind_channels), dtype=torch.float64, device=device),
"values": torch.zeros((2, 1, num_wind_channels, 1, 1), dtype=torch.float64, device=device),
"type": "meanvar",
"counts": torch.zeros((num_wind_channels), dtype=torch.float64, device=device),
"values": torch.zeros((2, 1, num_wind_channels, 1, 1), dtype=torch.float64, device=device),
}
stats["winddiff_meanvar"] = {
"type": "meanvar",
"type": "meanvar",
"counts": torch.zeros((num_wind_channels), dtype=torch.float64, device=device),
"values": torch.zeros((2, 1, num_wind_channels, 1, 1), dtype=torch.float64, device=device),
"values": torch.zeros((2, 1, num_wind_channels, 1, 1), dtype=torch.float64, device=device),
}

# compute local stats
progress = DistributedProgressBar(num_samples_total, comm)
start = time.time()
for filename, index_bounds in mapping.items():
tmpstats = get_file_stats(filename, slice(index_bounds[0], index_bounds[1]+1), wind_channels, quadrature, fail_on_nan, dt, batch_size, device, progress)
tmpstats = get_file_stats(filename, slice(index_bounds[0], index_bounds[1]+1), wind_channels, quadrature, fail_on_nan, dt, batch_size, device, progress, sht)
stats = welford_combine(stats, tmpstats)

# wait for everybody else
@@ -648,6 +696,7 @@ def get_stats(input_path: str, output_path: str, metadata_file: str,
# compute global stds:
stats["global_meanvar"]["values"][1, ...] = np.sqrt(stats["global_meanvar"]["values"][1, ...] / stats["global_meanvar"]["counts"][None, :, None, None])
stats["time_diff_meanvar"]["values"][1, ...] = np.sqrt(stats["time_diff_meanvar"]["values"][1, ...] / stats["time_diff_meanvar"]["counts"][None, :, None, None])
stats["psd_meanvar"]["values"][1, ...] = np.sqrt(stats["psd_meanvar"]["values"][1, ...] / stats["psd_meanvar"]["counts"][None, :, None, None])

# overwrite the wind channels
if wind_channels is not None:
@@ -672,6 +721,8 @@ def get_stats(input_path: str, output_path: str, metadata_file: str,
np.save(os.path.join(output_path, 'time_means.npy'), stats["time_means"]["values"].astype(np.float32))
np.save(os.path.join(output_path, f'time_diff_means_dt{dt}.npy'), stats["time_diff_meanvar"]["values"][0, ...].astype(np.float32))
np.save(os.path.join(output_path, f'time_diff_stds_dt{dt}.npy'), stats["time_diff_meanvar"]["values"][1, ...].astype(np.float32))
np.save(os.path.join(output_path, 'psd_means.npy'), stats["psd_meanvar"]["values"][0, ..., 0].astype(np.float32))
np.save(os.path.join(output_path, 'psd_stds.npy'), stats["psd_meanvar"]["values"][1, ..., 0].astype(np.float32))

duration = time.time() - start
print(f"Saving stats done. Duration: {duration:.2f}s", flush=True)