Conversation
Pull request overview
This PR updates ACE data loading to behave correctly under spatial/model parallelism by scattering spatial tensor chunks and by switching several data-parallel calculations from world_size/rank to the data-parallel group equivalents.
Changes:
- Add spatial scattering of `BatchData.data` (and slicing of some dataset properties) in `GriddedData`/`InferenceGriddedData`.
- Update inference sampling logic and batch-size validation to use `total_data_parallel_ranks`/`data_parallel_rank`.
- Add a unit test covering dataset-properties scattering behavior.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| fme/ace/data_loading/gridded_data.py | Introduces `_scatter_properties` and scatters spatial chunks in loaders to support spatial parallel backends. |
| fme/ace/data_loading/inference.py | Uses data-parallel rank/count (vs. global rank/count) for initial-condition sharding and assertions. |
| fme/ace/data_loading/getters.py | Adjusts sample-with-replacement sizing to be per data-parallel rank count. |
| fme/ace/data_loading/config.py | Validates batch-size divisibility against data-parallel ranks and updates the error message. |
| fme/ace/data_loading/test_data_loader.py | Removes a prior xfail and adds a new parallel test for `_scatter_properties`. |
fme/ace/data_loading/gridded_data.py (outdated)
```python
def _scatter_properties(properties: DatasetProperties) -> DatasetProperties:
    """Slice horizontal coordinates and masks to the local spatial chunk."""
    dist = Distributed.get_instance()
    coords = properties.horizontal_coordinates
    if isinstance(coords, LatLonCoordinates):
        h_slice, w_slice = dist.get_local_slices(coords.shape)
        coords = LatLonCoordinates(
            lat=coords.lat[h_slice],
            lon=coords.lon[w_slice],
        )
    mask_provider = properties.mask_provider
```
_scatter_properties only slices LatLonCoordinates, but GriddedData/InferenceGriddedData now always scatter batch.data spatially under ModelTorchDistributed. For other HorizontalCoordinates types (e.g., HEALPixCoordinates), this will leave DatasetInfo.coords/img_shape inconsistent with the locally-sliced tensors and can break downstream consumers (e.g., writers/aggregators). Consider extending _scatter_properties to slice other coordinate types (at least HEALPix height/width) or explicitly raising a NotImplementedError when spatial parallelism is enabled for unsupported coordinate grids.
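The fail-fast guard suggested above can be sketched as follows. The classes and the `scatter_coordinates` helper below are simplified stand-ins for the PR's types, not the real API; they only illustrate the shape of the check.

```python
# Minimal sketch of the fail-fast guard suggested above. The classes here
# are simplified stand-ins for the PR's coordinate types, not the real API.

class LatLonCoordinates:
    def __init__(self, lat, lon):
        self.lat = lat
        self.lon = lon


class HEALPixCoordinates:
    """Placeholder for a grid type _scatter_properties cannot slice yet."""


def scatter_coordinates(coords, h_slice, w_slice):
    """Slice supported coordinate types; raise for unsupported ones so local
    tensors never drift out of sync with still-global coordinates."""
    if isinstance(coords, LatLonCoordinates):
        return LatLonCoordinates(
            lat=coords.lat[h_slice],
            lon=coords.lon[w_slice],
        )
    raise NotImplementedError(
        f"Spatial parallelism is not implemented for {type(coords).__name__}; "
        "batch tensors would be sliced while coordinates stayed global."
    )
```

Raising eagerly keeps downstream consumers (writers, aggregators) from receiving a `DatasetInfo` whose `coords`/`img_shape` disagree with the locally sliced tensors.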
```python
import datetime
import math
import os
```
PR description still contains the template placeholders (e.g., generic bullets / unchecked checklist). Please update it with the actual purpose of the change, key symbols impacted, and whether tests/dependency steps were done so reviewers can validate intent and scope.
Force-pushed 2e00f53 to 3937218.
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated no new comments.
Comments suppressed due to low confidence (1)
fme/ace/data_loading/getters.py:59
- When sample_with_replacement is set, this uses RandomSampler, which is not distributed-aware: each process will sample independently, so data-parallel ranks won’t get disjoint samples, and under spatial model-parallelism spatial ranks may even see different samples (breaking the expectation that spatial ranks share the same batch). Consider using a distributed sampler-with-replacement approach (e.g., seed a torch.Generator by data_parallel_rank/epoch or implement a DistributedSampler variant with replacement) so sampling is consistent across spatial ranks but distinct across data-parallel ranks.
```python
dist = Distributed.get_instance()
if sample_with_replacement_dataset_size is not None:
    local_sample_with_replacement_dataset_size = (
        sample_with_replacement_dataset_size // dist.total_data_parallel_ranks
    )
    sampler = torch.utils.data.RandomSampler(
        dataset,
        num_samples=local_sample_with_replacement_dataset_size,
        replacement=True,
    )
```
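One way to realize the distributed-aware sampling the comment asks for is to derive the sampler's seed from the epoch and the data-parallel rank only: spatial co-ranks share a data-parallel rank, so they reproduce the same stream, while distinct data-parallel ranks get distinct streams. This sketch uses the stdlib `random` module in place of a `torch.Generator`; the function name and seed-mixing scheme are illustrative assumptions, not the PR's API.

```python
import random


def sample_indices_with_replacement(
    dataset_len, num_samples, base_seed, epoch, data_parallel_rank
):
    # Mix only (base_seed, epoch, data_parallel_rank) into the seed: spatial
    # co-ranks reproduce the same stream, distinct data-parallel ranks do not.
    seed = (base_seed * 1_000_003 + epoch) * 1_000_003 + data_parallel_rank
    rng = random.Random(seed)
    return [rng.randrange(dataset_len) for _ in range(num_samples)]
```

A torch-based version would feed the same mixed seed to `torch.Generator().manual_seed(...)` and pass that generator to `RandomSampler`, refreshing it each epoch the way `DistributedSampler.set_epoch` does.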
Force-pushed 3937218 to d6ce445.
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
```python
assert result.time.shape[0] == (
    self._n_initial_conditions // dist.total_data_parallel_ranks
)
```
__getitem__ uses an assert to enforce that n_initial_conditions is evenly divisible by dist.total_data_parallel_ranks. Assertions can be disabled with Python optimization flags, and this will currently fail late (or be skipped) instead of raising a clear, actionable exception. Please replace this with an explicit validation (e.g., raise ValueError with a message that includes n_initial_conditions and total_data_parallel_ranks), ideally performed once during initialization rather than per-sample.
Suggested change:

```diff
-        assert result.time.shape[0] == (
-            self._n_initial_conditions // dist.total_data_parallel_ranks
-        )
+        total_data_parallel_ranks = dist.total_data_parallel_ranks
+        n_initial_conditions = self._n_initial_conditions
+        if total_data_parallel_ranks <= 0:
+            raise ValueError(
+                "total_data_parallel_ranks must be a positive integer, "
+                f"got {total_data_parallel_ranks}."
+            )
+        if n_initial_conditions % total_data_parallel_ranks != 0:
+            raise ValueError(
+                "n_initial_conditions must be evenly divisible by "
+                "total_data_parallel_ranks, but got "
+                f"n_initial_conditions={n_initial_conditions} and "
+                f"total_data_parallel_ranks={total_data_parallel_ranks}."
+            )
+        expected_time_dim = n_initial_conditions // total_data_parallel_ranks
+        if result.time.shape[0] != expected_time_dim:
+            raise ValueError(
+                "Unexpected time dimension in inference batch: "
+                f"expected {expected_time_dim} based on "
+                f"n_initial_conditions={n_initial_conditions} and "
+                f"total_data_parallel_ranks={total_data_parallel_ranks}, "
+                f"but got {result.time.shape[0]}."
+            )
```
```python
    "sample_with_replacement is not supported with spatial "
    "parallelism. Spatial co-ranks would draw different samples, "
    "producing corrupted data after scatter_spatial reassembly."
)
```
sample_with_replacement_dataset_size is split across ranks using integer floor division. If the requested size is not divisible by dist.total_data_parallel_ranks, the total number of samples drawn across all ranks will be smaller than the configured value. Consider validating divisibility and raising a ValueError (or otherwise documenting/handling the remainder) so sample_with_replacement behaves predictably in distributed runs.
Suggested change:

```diff
             )
         )
+        if (
+            sample_with_replacement_dataset_size
+            % dist.total_data_parallel_ranks
+            != 0
+        ):
+            raise ValueError(
+                "sample_with_replacement_dataset_size "
+                f"({sample_with_replacement_dataset_size}) must be divisible "
+                "by the total number of data-parallel ranks "
+                f"({dist.total_data_parallel_ranks}) when using "
+                "sample_with_replacement."
+            )
```
```python
    "batch_size must be divisible by the number of data-parallel "
    f"workers, got {self.batch_size} and "
    f"{dist.total_data_parallel_ranks}"
)
```
DataLoaderConfig validates that batch_size is divisible by total_data_parallel_ranks, but there is no analogous validation for sample_with_replacement. Since _get_sampler uses floor division when splitting sample_with_replacement across ranks, a non-divisible value will result in fewer total samples than requested. Consider adding a __post_init__ check that sample_with_replacement (when set) is divisible by dist.total_data_parallel_ranks, and raise a clear ValueError if not.
Suggested change:

```diff
             )
         )
+        if (
+            self.sample_with_replacement is not None
+            and self.sample_with_replacement % dist.total_data_parallel_ranks != 0
+        ):
+            raise ValueError(
+                "sample_with_replacement must be divisible by the number of "
+                "data-parallel workers, got "
+                f"{self.sample_with_replacement} and "
+                f"{dist.total_data_parallel_ranks}"
+            )
```
Force-pushed 21e75a0 to 02c7c3b.
Pull request overview
Copilot reviewed 21 out of 21 changed files in this pull request and generated 3 comments.
```python
# Spatial co-ranks (same data_parallel_rank) should have distinct
# lat/lon slices whose union covers the global coordinates exactly.
combined_lat = sorted(set(v for rank_lat in all_lats for v in rank_lat))
combined_lon = sorted(set(v for rank_lon in all_lons for v in rank_lon))
assert combined_lat == lat.tolist(), "lat coverage incomplete or has gaps"
assert combined_lon == lon.tolist(), "lon coverage incomplete or has gaps"
```
This test claims it verifies “full, non-overlapping coverage”, but it only checks set-union equality. Overlapping slices (duplicates) would still pass. Consider also asserting that, within each data-parallel group, slices are disjoint (or that the sum of local lengths equals the global length) to actually detect overlaps.
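A sketch of the stronger check: pair the existing union-equality assertion with a total-length comparison, which duplicated slices cannot pass. The `check_partition` helper and its inputs are illustrative, not the test's actual names.

```python
def check_partition(per_rank_values, global_values):
    """True iff the per-rank lists form an exact, disjoint partition of
    global_values (no gaps and no overlaps)."""
    union = sorted(set(v for rank_vals in per_rank_values for v in rank_vals))
    total_len = sum(len(rank_vals) for rank_vals in per_rank_values)
    # Union equality catches gaps; the length comparison catches duplicates,
    # which a set union alone would silently absorb.
    return union == sorted(global_values) and total_len == len(global_values)
```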
```python
shape = properties.horizontal_coordinates.shape
self._global_img_shape: tuple[int, int] = (shape[-2], shape[-1])
self._properties = properties.to_device().localize()
```
The PR description in metadata appears to still be the template (placeholders like “Resolves #” and unchecked boxes). Please replace it with an accurate summary of the motivation, concrete change list, and test status so reviewers can validate scope and risk.
```python
    ``SpatialParallelismNotImplemented`` when the distributed layout
    requires slicing.
```
The docstring references SpatialParallelismNotImplemented, but that symbol isn’t imported in this module and isn’t re-exported from fme.core.distributed (only Distributed is). This makes the guidance hard to follow and can confuse users/type checkers. Consider importing the exception here (or fully-qualifying it in the docstring) and/or re-exporting it from fme.core.distributed if it’s intended to be part of the public API.
Suggested change:

```diff
-    ``SpatialParallelismNotImplemented`` when the distributed layout
-    requires slicing.
+    an appropriate exception when the distributed layout requires slicing.
```
Co-authored-by: Jeremy McGibbon <mcgibbon@uw.edu>