diff --git a/fme/ace/data_loading/batch_data.py b/fme/ace/data_loading/batch_data.py index f94a02a05..ea3e13348 100644 --- a/fme/ace/data_loading/batch_data.py +++ b/fme/ace/data_loading/batch_data.py @@ -12,6 +12,7 @@ from fme.core.dataset.dataset import DatasetItem from fme.core.device import get_device +from fme.core.distributed import Distributed from fme.core.labels import BatchLabels, LabelEncoding from fme.core.tensors import repeat_interleave_batch_dim, unfold_ensemble_dim from fme.core.typing_ import EnsembleTensorDict, TensorDict, TensorMapping @@ -170,6 +171,18 @@ def to_device(self) -> "BatchData": labels=self.labels.to(device) if self.labels is not None else None, ) + def scatter_spatial(self, global_img_shape: tuple[int, int]) -> "BatchData": + """Slice data tensors to the local spatial chunk.""" + dist = Distributed.get_instance() + return self.__class__( + data=dist.scatter_spatial(dict(self.data), global_img_shape), + time=self.time, + horizontal_dims=self.horizontal_dims, + epoch=self.epoch, + labels=self.labels, + n_ensemble=self.n_ensemble, + ) + def to_cpu(self) -> "BatchData": return self.__class__( data={k: v.cpu() for k, v in self.data.items()}, diff --git a/fme/ace/data_loading/config.py b/fme/ace/data_loading/config.py index 289971186..2a4bac517 100644 --- a/fme/ace/data_loading/config.py +++ b/fme/ace/data_loading/config.py @@ -78,10 +78,11 @@ def available_labels(self) -> set[str] | None: def __post_init__(self): dist = Distributed.get_instance() - if self.batch_size % dist.world_size != 0: + if self.batch_size % dist.total_data_parallel_ranks != 0: raise ValueError( - "batch_size must be divisible by the number of parallel " - f"workers, got {self.batch_size} and {dist.world_size}" + "batch_size must be divisible by the number of data-parallel " + f"workers, got {self.batch_size} and " + f"{dist.total_data_parallel_ranks}" ) self._zarr_engine_used = self.dataset.zarr_engine_used if self.time_buffer < 0: diff --git 
a/fme/ace/data_loading/getters.py b/fme/ace/data_loading/getters.py index bb060117b..d1552a999 100644 --- a/fme/ace/data_loading/getters.py +++ b/fme/ace/data_loading/getters.py @@ -49,8 +49,13 @@ def _get_sampler( ) -> torch.utils.data.Sampler: dist = Distributed.get_instance() if sample_with_replacement_dataset_size is not None: + dist.require_no_spatial_parallelism( + "sample_with_replacement is not supported with spatial " + "parallelism. Spatial co-ranks would draw different samples, " + "producing corrupted data after scatter_spatial reassembly." + ) local_sample_with_replacement_dataset_size = ( - sample_with_replacement_dataset_size // dist.world_size + sample_with_replacement_dataset_size // dist.total_data_parallel_ranks ) sampler = torch.utils.data.RandomSampler( dataset, diff --git a/fme/ace/data_loading/gridded_data.py b/fme/ace/data_loading/gridded_data.py index bd7fe86b6..1fd4d0484 100644 --- a/fme/ace/data_loading/gridded_data.py +++ b/fme/ace/data_loading/gridded_data.py @@ -50,7 +50,9 @@ def __init__( will be on the current device. """ self._loader = loader - self._properties = properties.to_device() + shape = properties.horizontal_coordinates.shape + self._global_img_shape: tuple[int, int] = (shape[-2], shape[-1]) + self._properties = properties.to_device().localize() self._timestep = self._properties.timestep self._vertical_coordinate = self._properties.vertical_coordinate self._mask_provider = self._properties.mask_provider @@ -72,7 +74,8 @@ def _get_gpu_loader( self, base_loader: DataLoader[BatchData] ) -> DataLoader[BatchData]: def modify_and_on_device(batch: BatchData) -> BatchData: - return self._modifier(batch).to_device() + batch = self._modifier(batch) + return batch.to_device().scatter_spatial(self._global_img_shape) return SizedMap(modify_and_on_device, base_loader) @@ -174,21 +177,23 @@ def __init__( will be on the current device. 
""" self._loader = loader - self._properties = properties.to_device() + shape = properties.horizontal_coordinates.shape + self._global_img_shape: tuple[int, int] = (shape[-2], shape[-1]) + self._properties = properties.to_device().localize() self._n_initial_conditions: int | None = None if isinstance(initial_condition, PrognosticStateDataRequirements): self._initial_condition: PrognosticState = get_initial_condition( - loader, initial_condition + self.loader, initial_condition ) else: self._initial_condition = initial_condition.to_device() @property def loader(self) -> DataLoader[BatchData]: - def on_device(batch: BatchData) -> BatchData: - return batch.to_device() + def scatter_and_on_device(batch: BatchData) -> BatchData: + return batch.to_device().scatter_spatial(self._global_img_shape) - return SizedMap(on_device, self._loader) + return SizedMap(scatter_and_on_device, self._loader) @property def variable_metadata(self) -> dict[str, VariableMetadata]: diff --git a/fme/ace/data_loading/inference.py b/fme/ace/data_loading/inference.py index 7f9455e12..1bdeea445 100644 --- a/fme/ace/data_loading/inference.py +++ b/fme/ace/data_loading/inference.py @@ -286,7 +286,7 @@ def _get_batch_data(self, index) -> BatchData: sample_tuples = [] for i_member in range(self._n_initial_conditions): # check if sample is one this local rank should process - if i_member % dist.world_size != dist.rank: + if i_member % dist.total_data_parallel_ranks != dist.data_parallel_rank: continue i_window_start = i_start + self._start_indices[i_member] i_window_end = i_window_start + self._forward_steps_in_memory + 1 @@ -334,7 +334,9 @@ def __getitem__(self, index) -> BatchData: for key, value in self._persistence_data.data.items(): updated_data[key] = value.expand_as(result.data[key]) result.data = {**result.data, **updated_data} - assert result.time.shape[0] == self._n_initial_conditions // dist.world_size + assert result.time.shape[0] == ( + self._n_initial_conditions // 
dist.total_data_parallel_ranks + ) return result def __len__(self) -> int: diff --git a/fme/ace/data_loading/test_data_loader.py b/fme/ace/data_loading/test_data_loader.py index 99463838f..5c6bc6282 100644 --- a/fme/ace/data_loading/test_data_loader.py +++ b/fme/ace/data_loading/test_data_loader.py @@ -1,6 +1,7 @@ """This file contains unit tests related to creating torch Datasets from climate data (e.g. netCDF files).""" +import datetime import math import os import pathlib @@ -30,14 +31,20 @@ ) from fme.ace.data_loading.perturbation import PerturbationSelector, SSTPerturbation from fme.ace.requirements import DataRequirements, PrognosticStateDataRequirements -from fme.core.coordinates import HybridSigmaPressureCoordinate +from fme.core.coordinates import ( + HybridSigmaPressureCoordinate, + LatLonCoordinates, + NullVerticalCoordinate, +) from fme.core.dataset.concat import ConcatDatasetConfig +from fme.core.dataset.data_typing import VariableMetadata from fme.core.dataset.merged import MergeDatasetConfig, MergeNoConcatDatasetConfig +from fme.core.dataset.properties import DatasetProperties from fme.core.dataset.schedule import IntMilestone, IntSchedule from fme.core.dataset.xarray import XarrayDataConfig from fme.core.device import using_gpu from fme.core.distributed.distributed import Distributed -from fme.core.distributed.model_torch_distributed import ModelTorchDistributed +from fme.core.mask_provider import MaskProvider from fme.core.testing.regression import validate_tensor_dict from fme.core.typing_ import Slice @@ -203,10 +210,6 @@ def test_ensemble_loader_n_samples(tmp_path, num_ensemble_members=3, n_samples=1 def test_xarray_loader(tmp_path): """Checks that vertical coordinates are present.""" dist = Distributed.get_instance() - if isinstance(dist._distributed, ModelTorchDistributed): - pytest.xfail( - "ModelTorchDistributed slicing along spatial dimensions is not implemented." 
- ) tmp_path = dist.scatter_object(tmp_path) # get the root value global_batch_size = 24 if dist.is_root(): @@ -1173,3 +1176,67 @@ def test_pinned_memory(tmp_path, time_buffer: int): for batch in loader: tensor = next(iter(batch.data.values())) assert tensor.is_pinned() is using_gpu() + + +@pytest.mark.parallel +def test_localize_properties(): + """Verify DatasetProperties.localize() partitions coords and masks across ranks.""" + dist = Distributed.get_instance() + n_lat, n_lon = N_LAT, N_LON + lat = torch.linspace(-90.0, 90.0, n_lat) + lon = torch.linspace(0.0, 360.0, n_lon) + coords = LatLonCoordinates(lat=lat, lon=lon) + mask_tensor = torch.arange(n_lat * n_lon, dtype=torch.float32).reshape( + 1, n_lat, n_lon + ) + mask_provider = MaskProvider(masks={"mask_test": mask_tensor}) + timestep = datetime.timedelta(hours=6) + metadata = {"temp": VariableMetadata(units="K", long_name="Temperature")} + vertical = NullVerticalCoordinate() + props = DatasetProperties( + variable_metadata=metadata, + vertical_coordinate=vertical, + horizontal_coordinates=coords, + mask_provider=mask_provider, + timestep=timestep, + is_remote=False, + all_labels=None, + ) + + local = props.localize() + + # Unchanged fields + assert local.variable_metadata is metadata + assert local.vertical_coordinate is vertical + assert local.timestep == timestep + + # Gather local coordinates to root and verify full, non-overlapping coverage. + assert isinstance(local.horizontal_coordinates, LatLonCoordinates) + local_lat = local.horizontal_coordinates.lat.tolist() + local_lon = local.horizontal_coordinates.lon.tolist() + all_lats = dist.gather_object(local_lat) + all_lons = dist.gather_object(local_lon) + if dist.is_root(): + assert all_lats is not None and all_lons is not None + # Spatial co-ranks (same data_parallel_rank) should have distinct + # lat/lon slices whose union covers the global coordinates exactly. 
+ combined_lat = sorted(set(v for rank_lat in all_lats for v in rank_lat)) + combined_lon = sorted(set(v for rank_lon in all_lons for v in rank_lon)) + assert combined_lat == lat.tolist(), "lat coverage incomplete or has gaps" + assert combined_lon == lon.tolist(), "lon coverage incomplete or has gaps" + + # Gather local masks to root and verify they tile to the global mask. + assert isinstance(local.mask_provider, MaskProvider) + local_mask = local.mask_provider.masks["mask_test"] + h_slice, w_slice = dist.get_local_slices((n_lat, n_lon)) + all_slices = dist.gather_object((h_slice, w_slice)) + all_masks = dist.gather_object(local_mask.tolist()) + if dist.is_root(): + assert all_slices is not None and all_masks is not None + canvas = torch.zeros_like(mask_tensor) + for (hs, ws), mask_data in zip(all_slices, all_masks): + canvas[..., hs, ws] += torch.tensor(mask_data) + # Each cell should be covered once per data-parallel rank (spatial + # co-ranks have distinct slices, data-parallel ranks have identical ones). 
+ expected_count = dist.total_data_parallel_ranks + torch.testing.assert_close(canvas, mask_tensor * expected_count) diff --git a/fme/core/coordinates.py b/fme/core/coordinates.py index 552dabf33..68c2fbb29 100644 --- a/fme/core/coordinates.py +++ b/fme/core/coordinates.py @@ -17,6 +17,7 @@ from fme.core.corrector.registry import CorrectorABC from fme.core.derived_variables import compute_derived_quantities from fme.core.device import get_device +from fme.core.distributed import Distributed from fme.core.gridded_ops import GriddedOperations, HEALPixOperations, LatLonOperations from fme.core.mask_provider import MaskProvider, MaskProviderABC, NullMaskProvider from fme.core.ocean_derived_variables import compute_ocean_derived_quantities @@ -693,6 +694,17 @@ def meshgrid(self) -> tuple[torch.Tensor, torch.Tensor]: def shape(self) -> tuple[int, ...]: pass + @abc.abstractmethod + def localize(self: HC) -> HC: + """Return a copy with coordinates sliced to the local spatial chunk. + + Uses ``Distributed.get_instance()`` to determine the local slices. + Coordinate types that do not support spatial parallelism should raise + ``SpatialParallelismNotImplemented`` when the distributed layout + requires slicing. 
+ """ + pass + @abc.abstractmethod def get_state(self) -> TensorMapping: pass @@ -788,6 +800,14 @@ def meshgrid(self) -> tuple[torch.Tensor, torch.Tensor]: def shape(self) -> tuple[int, int]: return (len(self.lat), len(self.lon)) + def localize(self) -> "LatLonCoordinates": + dist = Distributed.get_instance() + h_slice, w_slice = dist.get_local_slices(self.shape) + return LatLonCoordinates( + lat=self.lat[h_slice], + lon=self.lon[w_slice], + ) + def get_state(self) -> TensorMapping: return {"lat": self.lat, "lon": self.lon} @@ -939,6 +959,12 @@ def meshgrid(self) -> tuple[torch.Tensor, torch.Tensor]: def shape(self) -> tuple[int, int, int]: return (len(self.face), len(self.width), len(self.height)) + def localize(self) -> "HEALPixCoordinates": + Distributed.get_instance().require_no_spatial_parallelism( + "HEALPixCoordinates does not support spatial parallelism." + ) + return self + def get_state(self) -> TensorMapping: return {"face": self.face, "height": self.height, "width": self.width} diff --git a/fme/core/dataset/properties.py b/fme/core/dataset/properties.py index be17f8cbe..5399492dc 100644 --- a/fme/core/dataset/properties.py +++ b/fme/core/dataset/properties.py @@ -38,6 +38,20 @@ def to_device(self) -> "DatasetProperties": self.all_labels, ) + def localize(self) -> "DatasetProperties": + """Return a copy with coordinates and masks sliced to the local + spatial chunk based on the current distributed layout. 
+ """ + return DatasetProperties( + self.variable_metadata, + self.vertical_coordinate, + self.horizontal_coordinates.localize(), + self.mask_provider.localize(), + self.timestep, + self.is_remote, + self.all_labels, + ) + def update(self, other: "DatasetProperties", strict: bool = True): try: if self.timestep != other.timestep: diff --git a/fme/core/distributed/distributed.py b/fme/core/distributed/distributed.py index 9f3d842e9..d4ff33eb7 100644 --- a/fme/core/distributed/distributed.py +++ b/fme/core/distributed/distributed.py @@ -16,6 +16,12 @@ T = TypeVar("T") +class SpatialParallelismNotImplemented(NotImplementedError): + """Raised when a code path is incompatible with spatial parallelism.""" + + pass + + class Distributed: """ A class to represent the distributed concerns for FME training. @@ -183,6 +189,15 @@ def world_size(self) -> int: """ return self._distributed.total_ranks + def require_no_spatial_parallelism(self, msg: str) -> None: + """Raise if spatial parallelism is active. + + Use this to guard code paths that are known to be incorrect + when spatial co-ranks exist (world_size > total_data_parallel_ranks). + """ + if self.world_size != self.total_data_parallel_ranks: + raise SpatialParallelismNotImplemented(msg) + def get_sampler( self, dataset: torch.utils.data.Dataset, diff --git a/fme/core/mask_provider.py b/fme/core/mask_provider.py index 43b0587df..b5598a693 100644 --- a/fme/core/mask_provider.py +++ b/fme/core/mask_provider.py @@ -6,6 +6,7 @@ import torch +from fme.core.distributed import Distributed from fme.core.masking import NullMasking, StaticMasking from fme.core.typing_ import TensorDict, TensorMapping @@ -27,6 +28,11 @@ def update(self: SelfType, other: SelfType) -> None: ... @abc.abstractmethod def build_output_masker(self) -> Callable[[TensorMapping], TensorDict]: ... + @abc.abstractmethod + def localize(self: SelfType) -> SelfType: + """Return a copy with masks sliced to the local spatial chunk.""" + ... 
+
     @abc.abstractmethod
     def get_state(self) -> dict[str, Any]: ...

@@ -38,6 +44,9 @@ def get_mask_tensor_for(self, name: str) -> torch.Tensor | None:
     def to(self, device: str) -> "_NullMaskProvider":
         return self
 
+    def localize(self) -> "_NullMaskProvider":
+        return self
+
     def update(self, other: MaskProviderABC) -> None:
         if not isinstance(other, _NullMaskProvider):
             raise ValueError(
@@ -124,6 +133,18 @@ def to(self, device: str) -> "MaskProvider":
             {name: tensor.to(device) for name, tensor in self.masks.items()}
         )
 
+    def localize(self) -> "MaskProvider":
+        """Return a copy with masks sliced to the local spatial chunk."""
+        if not self._masks:
+            return self
+        dist = Distributed.get_instance()
+        example_mask = next(iter(self._masks.values()))
+        img_shape = example_mask.shape[-2:]
+        h_slice, w_slice = dist.get_local_slices(img_shape)
+        return MaskProvider(
+            {k: v[..., h_slice, w_slice].contiguous() for k, v in self._masks.items()}
+        )
+
     def update(self, other: "MaskProvider") -> None:
         """Update the MaskProvider's masks with masks from another
         MaskProvider.