Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions fme/ace/data_loading/batch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from fme.core.dataset.dataset import DatasetItem
from fme.core.device import get_device
from fme.core.distributed import Distributed
from fme.core.labels import BatchLabels, LabelEncoding
from fme.core.tensors import repeat_interleave_batch_dim, unfold_ensemble_dim
from fme.core.typing_ import EnsembleTensorDict, TensorDict, TensorMapping
Expand Down Expand Up @@ -170,6 +171,18 @@ def to_device(self) -> "BatchData":
labels=self.labels.to(device) if self.labels is not None else None,
)

def scatter_spatial(self, global_img_shape: tuple[int, int]) -> "BatchData":
    """Slice data tensors to the local spatial chunk.

    Args:
        global_img_shape: Global (height, width) image shape used by the
            distributed layer to determine this rank's local slice.

    Returns:
        A new instance whose ``data`` tensors hold only this rank's spatial
        chunk; all other fields are carried over unchanged.
    """
    dist = Distributed.get_instance()
    local_data = dist.scatter_spatial(dict(self.data), global_img_shape)
    return self.__class__(
        data=local_data,
        time=self.time,
        horizontal_dims=self.horizontal_dims,
        epoch=self.epoch,
        labels=self.labels,
        n_ensemble=self.n_ensemble,
    )

def to_cpu(self) -> "BatchData":
return self.__class__(
data={k: v.cpu() for k, v in self.data.items()},
Expand Down
7 changes: 4 additions & 3 deletions fme/ace/data_loading/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,11 @@ def available_labels(self) -> set[str] | None:

def __post_init__(self):
dist = Distributed.get_instance()
if self.batch_size % dist.world_size != 0:
if self.batch_size % dist.total_data_parallel_ranks != 0:
raise ValueError(
"batch_size must be divisible by the number of parallel "
f"workers, got {self.batch_size} and {dist.world_size}"
"batch_size must be divisible by the number of data-parallel "
f"workers, got {self.batch_size} and "
f"{dist.total_data_parallel_ranks}"
)
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DataLoaderConfig validates that batch_size is divisible by total_data_parallel_ranks, but there is no analogous validation for sample_with_replacement. Since _get_sampler uses floor division when splitting sample_with_replacement across ranks, a non-divisible value will result in fewer total samples than requested. Consider adding a __post_init__ check that sample_with_replacement (when set) is divisible by dist.total_data_parallel_ranks, and raise a clear ValueError if not.

Suggested change
)
)
if (
self.sample_with_replacement is not None
and self.sample_with_replacement % dist.total_data_parallel_ranks != 0
):
raise ValueError(
"sample_with_replacement must be divisible by the number of "
"data-parallel workers, got "
f"{self.sample_with_replacement} and "
f"{dist.total_data_parallel_ranks}"
)

Copilot uses AI. Check for mistakes.
self._zarr_engine_used = self.dataset.zarr_engine_used
if self.time_buffer < 0:
Expand Down
7 changes: 6 additions & 1 deletion fme/ace/data_loading/getters.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,13 @@ def _get_sampler(
) -> torch.utils.data.Sampler:
dist = Distributed.get_instance()
if sample_with_replacement_dataset_size is not None:
dist.require_no_spatial_parallelism(
"sample_with_replacement is not supported with spatial "
"parallelism. Spatial co-ranks would draw different samples, "
"producing corrupted data after scatter_spatial reassembly."
)
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sample_with_replacement_dataset_size is split across ranks using integer floor division. If the requested size is not divisible by dist.total_data_parallel_ranks, the total number of samples drawn across all ranks will be smaller than the configured value. Consider validating divisibility and raising a ValueError (or otherwise documenting/handling the remainder) so sample_with_replacement behaves predictably in distributed runs.

Suggested change
)
)
if (
sample_with_replacement_dataset_size
% dist.total_data_parallel_ranks
!= 0
):
raise ValueError(
"sample_with_replacement_dataset_size "
f"({sample_with_replacement_dataset_size}) must be divisible "
"by the total number of data-parallel ranks "
f"({dist.total_data_parallel_ranks}) when using "
"sample_with_replacement."
)

Copilot uses AI. Check for mistakes.
local_sample_with_replacement_dataset_size = (
sample_with_replacement_dataset_size // dist.world_size
sample_with_replacement_dataset_size // dist.total_data_parallel_ranks
)
sampler = torch.utils.data.RandomSampler(
dataset,
Expand Down
19 changes: 12 additions & 7 deletions fme/ace/data_loading/gridded_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def __init__(
will be on the current device.
"""
self._loader = loader
self._properties = properties.to_device()
shape = properties.horizontal_coordinates.shape
self._global_img_shape: tuple[int, int] = (shape[-2], shape[-1])
self._properties = properties.to_device().localize()
Comment on lines +53 to +55
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description in metadata appears to still be the template (placeholders like “Resolves #” and unchecked boxes). Please replace it with an accurate summary of the motivation, concrete change list, and test status so reviewers can validate scope and risk.

Copilot uses AI. Check for mistakes.
self._timestep = self._properties.timestep
self._vertical_coordinate = self._properties.vertical_coordinate
self._mask_provider = self._properties.mask_provider
Expand All @@ -72,7 +74,8 @@ def _get_gpu_loader(
self, base_loader: DataLoader[BatchData]
) -> DataLoader[BatchData]:
def modify_and_on_device(batch: BatchData) -> BatchData:
return self._modifier(batch).to_device()
batch = self._modifier(batch)
return batch.to_device().scatter_spatial(self._global_img_shape)

return SizedMap(modify_and_on_device, base_loader)

Expand Down Expand Up @@ -174,21 +177,23 @@ def __init__(
will be on the current device.
"""
self._loader = loader
self._properties = properties.to_device()
shape = properties.horizontal_coordinates.shape
self._global_img_shape: tuple[int, int] = (shape[-2], shape[-1])
self._properties = properties.to_device().localize()
self._n_initial_conditions: int | None = None
if isinstance(initial_condition, PrognosticStateDataRequirements):
self._initial_condition: PrognosticState = get_initial_condition(
loader, initial_condition
self.loader, initial_condition
)
else:
self._initial_condition = initial_condition.to_device()

@property
def loader(self) -> DataLoader[BatchData]:
def on_device(batch: BatchData) -> BatchData:
return batch.to_device()
def scatter_and_on_device(batch: BatchData) -> BatchData:
return batch.to_device().scatter_spatial(self._global_img_shape)

return SizedMap(on_device, self._loader)
return SizedMap(scatter_and_on_device, self._loader)

@property
def variable_metadata(self) -> dict[str, VariableMetadata]:
Expand Down
6 changes: 4 additions & 2 deletions fme/ace/data_loading/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def _get_batch_data(self, index) -> BatchData:
sample_tuples = []
for i_member in range(self._n_initial_conditions):
# check if sample is one this local rank should process
if i_member % dist.world_size != dist.rank:
if i_member % dist.total_data_parallel_ranks != dist.data_parallel_rank:
continue
i_window_start = i_start + self._start_indices[i_member]
i_window_end = i_window_start + self._forward_steps_in_memory + 1
Expand Down Expand Up @@ -334,7 +334,9 @@ def __getitem__(self, index) -> BatchData:
for key, value in self._persistence_data.data.items():
updated_data[key] = value.expand_as(result.data[key])
result.data = {**result.data, **updated_data}
assert result.time.shape[0] == self._n_initial_conditions // dist.world_size
assert result.time.shape[0] == (
self._n_initial_conditions // dist.total_data_parallel_ranks
)
Comment on lines +337 to +339
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__getitem__ uses an assert to enforce that n_initial_conditions is evenly divisible by dist.total_data_parallel_ranks. Assertions can be disabled with Python optimization flags, and this will currently fail late (or be skipped) instead of raising a clear, actionable exception. Please replace this with an explicit validation (e.g., raise ValueError with a message that includes n_initial_conditions and total_data_parallel_ranks), ideally performed once during initialization rather than per-sample.

Suggested change
assert result.time.shape[0] == (
self._n_initial_conditions // dist.total_data_parallel_ranks
)
total_data_parallel_ranks = dist.total_data_parallel_ranks
n_initial_conditions = self._n_initial_conditions
if total_data_parallel_ranks <= 0:
raise ValueError(
"total_data_parallel_ranks must be a positive integer, "
f"got {total_data_parallel_ranks}."
)
if n_initial_conditions % total_data_parallel_ranks != 0:
raise ValueError(
"n_initial_conditions must be evenly divisible by "
"total_data_parallel_ranks, but got "
f"n_initial_conditions={n_initial_conditions} and "
f"total_data_parallel_ranks={total_data_parallel_ranks}."
)
expected_time_dim = n_initial_conditions // total_data_parallel_ranks
if result.time.shape[0] != expected_time_dim:
raise ValueError(
"Unexpected time dimension in inference batch: "
f"expected {expected_time_dim} based on "
f"n_initial_conditions={n_initial_conditions} and "
f"total_data_parallel_ranks={total_data_parallel_ranks}, "
f"but got {result.time.shape[0]}."
)

Copilot uses AI. Check for mistakes.
return result

def __len__(self) -> int:
Expand Down
79 changes: 73 additions & 6 deletions fme/ace/data_loading/test_data_loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""This file contains unit tests related to creating torch Datasets from climate
data (e.g. netCDF files)."""

import datetime
import math
import os
Comment on lines +4 to 6
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR description still contains the template placeholders (e.g., generic bullets / unchecked checklist). Please update it with the actual purpose of the change, key symbols impacted, and whether tests/dependency steps were done so reviewers can validate intent and scope.

Copilot uses AI. Check for mistakes.
import pathlib
Expand Down Expand Up @@ -30,14 +31,20 @@
)
from fme.ace.data_loading.perturbation import PerturbationSelector, SSTPerturbation
from fme.ace.requirements import DataRequirements, PrognosticStateDataRequirements
from fme.core.coordinates import HybridSigmaPressureCoordinate
from fme.core.coordinates import (
HybridSigmaPressureCoordinate,
LatLonCoordinates,
NullVerticalCoordinate,
)
from fme.core.dataset.concat import ConcatDatasetConfig
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.dataset.merged import MergeDatasetConfig, MergeNoConcatDatasetConfig
from fme.core.dataset.properties import DatasetProperties
from fme.core.dataset.schedule import IntMilestone, IntSchedule
from fme.core.dataset.xarray import XarrayDataConfig
from fme.core.device import using_gpu
from fme.core.distributed.distributed import Distributed
from fme.core.distributed.model_torch_distributed import ModelTorchDistributed
from fme.core.mask_provider import MaskProvider
from fme.core.testing.regression import validate_tensor_dict
from fme.core.typing_ import Slice

Expand Down Expand Up @@ -203,10 +210,6 @@ def test_ensemble_loader_n_samples(tmp_path, num_ensemble_members=3, n_samples=1
def test_xarray_loader(tmp_path):
"""Checks that vertical coordinates are present."""
dist = Distributed.get_instance()
if isinstance(dist._distributed, ModelTorchDistributed):
pytest.xfail(
"ModelTorchDistributed slicing along spatial dimensions is not implemented."
)
tmp_path = dist.scatter_object(tmp_path) # get the root value
global_batch_size = 24
if dist.is_root():
Expand Down Expand Up @@ -1173,3 +1176,67 @@ def test_pinned_memory(tmp_path, time_buffer: int):
for batch in loader:
tensor = next(iter(batch.data.values()))
assert tensor.is_pinned() is using_gpu()


@pytest.mark.parallel
def test_localize_properties():
    """Verify DatasetProperties.localize() partitions coords and masks across ranks."""
    dist = Distributed.get_instance()
    n_lat, n_lon = N_LAT, N_LON
    lat = torch.linspace(-90.0, 90.0, n_lat)
    lon = torch.linspace(0.0, 360.0, n_lon)
    coords = LatLonCoordinates(lat=lat, lon=lon)
    # Distinct values per cell so any mis-slicing shows up in the tiling check.
    mask_tensor = torch.arange(n_lat * n_lon, dtype=torch.float32).reshape(
        1, n_lat, n_lon
    )
    mask_provider = MaskProvider(masks={"mask_test": mask_tensor})
    timestep = datetime.timedelta(hours=6)
    metadata = {"temp": VariableMetadata(units="K", long_name="Temperature")}
    vertical = NullVerticalCoordinate()
    props = DatasetProperties(
        variable_metadata=metadata,
        vertical_coordinate=vertical,
        horizontal_coordinates=coords,
        mask_provider=mask_provider,
        timestep=timestep,
        is_remote=False,
        all_labels=None,
    )

    local = props.localize()

    # Unchanged fields
    assert local.variable_metadata is metadata
    assert local.vertical_coordinate is vertical
    assert local.timestep == timestep

    # Gather local coordinates to root and verify full coverage. NOTE: the
    # set-union check alone cannot detect overlapping slices; the exact mask
    # tiling check below is what rules out overlaps.
    assert isinstance(local.horizontal_coordinates, LatLonCoordinates)
    local_lat = local.horizontal_coordinates.lat.tolist()
    local_lon = local.horizontal_coordinates.lon.tolist()
    all_lats = dist.gather_object(local_lat)
    all_lons = dist.gather_object(local_lon)
    if dist.is_root():
        assert all_lats is not None and all_lons is not None
        # Spatial co-ranks (same data_parallel_rank) should have distinct
        # lat/lon slices whose union covers the global coordinates exactly.
        combined_lat = sorted(set(v for rank_lat in all_lats for v in rank_lat))
        combined_lon = sorted(set(v for rank_lon in all_lons for v in rank_lon))
        assert combined_lat == lat.tolist(), "lat coverage incomplete or has gaps"
        assert combined_lon == lon.tolist(), "lon coverage incomplete or has gaps"

    # Gather local masks to root and verify they tile to the global mask.
    assert isinstance(local.mask_provider, MaskProvider)
    local_mask = local.mask_provider.masks["mask_test"]
    h_slice, w_slice = dist.get_local_slices((n_lat, n_lon))
    all_slices = dist.gather_object((h_slice, w_slice))
    all_masks = dist.gather_object(local_mask.tolist())
    if dist.is_root():
        assert all_slices is not None and all_masks is not None
        canvas = torch.zeros_like(mask_tensor)
        for (hs, ws), mask_data in zip(all_slices, all_masks):
            canvas[..., hs, ws] += torch.tensor(mask_data)
        # Each cell should be covered exactly once per data-parallel rank
        # (spatial co-ranks have distinct slices, data-parallel ranks have
        # identical ones). Exact equality here also detects overlapping or
        # duplicated slices, which the set-union check above would miss.
        expected_count = dist.total_data_parallel_ranks
        torch.testing.assert_close(canvas, mask_tensor * expected_count)
26 changes: 26 additions & 0 deletions fme/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from fme.core.corrector.registry import CorrectorABC
from fme.core.derived_variables import compute_derived_quantities
from fme.core.device import get_device
from fme.core.distributed import Distributed
from fme.core.gridded_ops import GriddedOperations, HEALPixOperations, LatLonOperations
from fme.core.mask_provider import MaskProvider, MaskProviderABC, NullMaskProvider
from fme.core.ocean_derived_variables import compute_ocean_derived_quantities
Expand Down Expand Up @@ -693,6 +694,17 @@ def meshgrid(self) -> tuple[torch.Tensor, torch.Tensor]:
def shape(self) -> tuple[int, ...]:
pass

@abc.abstractmethod
def localize(self: HC) -> HC:
"""Return a copy with coordinates sliced to the local spatial chunk.

Uses ``Distributed.get_instance()`` to determine the local slices.
Coordinate types that do not support spatial parallelism should raise
``SpatialParallelismNotImplemented`` when the distributed layout
requires slicing.
Comment on lines +703 to +704
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring references SpatialParallelismNotImplemented, but that symbol isn’t imported in this module and isn’t re-exported from fme.core.distributed (only Distributed is). This makes the guidance hard to follow and can confuse users/type checkers. Consider importing the exception here (or fully-qualifying it in the docstring) and/or re-exporting it from fme.core.distributed if it’s intended to be part of the public API.

Suggested change
``SpatialParallelismNotImplemented`` when the distributed layout
requires slicing.
an appropriate exception when the distributed layout requires slicing.

Copilot uses AI. Check for mistakes.
"""
pass

@abc.abstractmethod
def get_state(self) -> TensorMapping:
pass
Expand Down Expand Up @@ -788,6 +800,14 @@ def meshgrid(self) -> tuple[torch.Tensor, torch.Tensor]:
def shape(self) -> tuple[int, int]:
return (len(self.lat), len(self.lon))

def localize(self) -> "LatLonCoordinates":
    """Return a copy with lat/lon restricted to this rank's spatial chunk.

    The local slices are obtained from ``Distributed.get_instance()`` based
    on the global (n_lat, n_lon) shape of these coordinates.
    """
    dist = Distributed.get_instance()
    lat_slice, lon_slice = dist.get_local_slices(self.shape)
    return LatLonCoordinates(lat=self.lat[lat_slice], lon=self.lon[lon_slice])

def get_state(self) -> TensorMapping:
return {"lat": self.lat, "lon": self.lon}

Expand Down Expand Up @@ -939,6 +959,12 @@ def meshgrid(self) -> tuple[torch.Tensor, torch.Tensor]:
def shape(self) -> tuple[int, int, int]:
return (len(self.face), len(self.width), len(self.height))

def localize(self) -> "HEALPixCoordinates":
    """Return self unchanged; raises if spatial parallelism is active.

    HEALPix grids are not decomposed spatially, so localization is only
    valid when there are no spatial co-ranks.
    """
    dist = Distributed.get_instance()
    dist.require_no_spatial_parallelism(
        "HEALPixCoordinates does not support spatial parallelism."
    )
    return self

def get_state(self) -> TensorMapping:
return {"face": self.face, "height": self.height, "width": self.width}

Expand Down
14 changes: 14 additions & 0 deletions fme/core/dataset/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ def to_device(self) -> "DatasetProperties":
self.all_labels,
)

def localize(self) -> "DatasetProperties":
    """Return a copy with coordinates and masks sliced to the local
    spatial chunk based on the current distributed layout.

    Only the horizontal coordinates and mask provider are localized; all
    other fields are passed through unchanged.
    """
    local_coordinates = self.horizontal_coordinates.localize()
    local_masks = self.mask_provider.localize()
    return DatasetProperties(
        self.variable_metadata,
        self.vertical_coordinate,
        local_coordinates,
        local_masks,
        self.timestep,
        self.is_remote,
        self.all_labels,
    )

def update(self, other: "DatasetProperties", strict: bool = True):
try:
if self.timestep != other.timestep:
Expand Down
15 changes: 15 additions & 0 deletions fme/core/distributed/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
T = TypeVar("T")


class SpatialParallelismNotImplemented(NotImplementedError):
    """Raised when a code path is incompatible with spatial parallelism.

    Subclasses ``NotImplementedError`` so existing handlers that catch the
    broader type continue to work.
    """


class Distributed:
"""
A class to represent the distributed concerns for FME training.
Expand Down Expand Up @@ -183,6 +189,15 @@ def world_size(self) -> int:
"""
return self._distributed.total_ranks

def require_no_spatial_parallelism(self, msg: str) -> None:
    """Raise if spatial parallelism is active.

    Use this to guard code paths that are known to be incorrect when
    spatial co-ranks exist (world_size > total_data_parallel_ranks).

    Args:
        msg: Message attached to the raised exception.

    Raises:
        SpatialParallelismNotImplemented: If any spatial co-ranks exist.
    """
    has_spatial_ranks = self.world_size != self.total_data_parallel_ranks
    if has_spatial_ranks:
        raise SpatialParallelismNotImplemented(msg)

def get_sampler(
self,
dataset: torch.utils.data.Dataset,
Expand Down
18 changes: 18 additions & 0 deletions fme/core/mask_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch

from fme.core.distributed import Distributed
from fme.core.masking import NullMasking, StaticMasking
from fme.core.typing_ import TensorDict, TensorMapping

Expand All @@ -27,6 +28,11 @@ def update(self: SelfType, other: SelfType) -> None: ...
@abc.abstractmethod
def build_output_masker(self) -> Callable[[TensorMapping], TensorDict]: ...

@abc.abstractmethod
def localize(self: SelfType) -> SelfType:
    """Return a copy with masks sliced to the local spatial chunk.

    Implementations determine the local slices from the current
    distributed layout; a provider holding no masks may return ``self``.
    """
    ...

@abc.abstractmethod
def get_state(self) -> dict[str, Any]: ...

Expand All @@ -38,6 +44,9 @@ def get_mask_tensor_for(self, name: str) -> torch.Tensor | None:
def to(self, device: str) -> "_NullMaskProvider":
return self

def localize(self) -> "_NullMaskProvider":
    # The null provider holds no masks, so there is nothing to slice;
    # return self unchanged regardless of the distributed layout.
    return self

def update(self, other: MaskProviderABC) -> None:
if not isinstance(other, _NullMaskProvider):
raise ValueError(
Expand Down Expand Up @@ -124,6 +133,15 @@ def to(self, device: str) -> "MaskProvider":
{name: tensor.to(device) for name, tensor in self.masks.items()}
)

def localize(self) -> "MaskProvider":
    """Return a copy with each mask sliced to the local spatial chunk.

    The local slices are obtained from ``Distributed.get_instance()``
    using each mask's trailing (height, width) dimensions, so masks with
    leading dimensions (e.g. a channel dim) are supported. Sliced tensors
    are made contiguous.

    Returns:
        A new MaskProvider with localized masks, or ``self`` when there
        are no masks to slice.
    """
    if not self._masks:
        return self
    dist = Distributed.get_instance()
    localized: dict[str, torch.Tensor] = {}
    for name, mask in self._masks.items():
        # get_local_slices expects the 2D (height, width) image shape, not
        # the full tensor shape, which may carry leading dimensions.
        h_slice, w_slice = dist.get_local_slices(tuple(mask.shape[-2:]))
        localized[name] = mask[..., h_slice, w_slice].contiguous()
    return MaskProvider(localized)

def update(self, other: "MaskProvider") -> None:
"""Update the MaskProvider's masks with masks from another MaskProvider.

Expand Down
Loading