test dataloading#10

Open
mahf708 wants to merge 5 commits into main from feature/model-parallel-load-data

Conversation

@mahf708 (Collaborator) commented Mar 10, 2026

Short description of why the PR is needed and how it satisfies those requirements, in sentence form.

Changes:

  • symbol (e.g. fme.core.my_function) or script and concise description of changes or added feature

  • Can group multiple related symbols on a single bullet

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Resolves # (delete if none)

Copilot AI left a comment


Pull request overview

This PR updates ACE data loading to behave correctly under spatial/model parallelism by scattering spatial tensor chunks and by switching several data-parallel calculations from world_size/rank to the data-parallel group equivalents.

Changes:

  • Add spatial scattering of BatchData.data (and slicing of some dataset properties) in GriddedData / InferenceGriddedData.
  • Update inference sampling logic and batch-size validation to use total_data_parallel_ranks / data_parallel_rank.
  • Add a unit test covering dataset-properties scattering behavior.
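To illustrate what "scattering spatial tensor chunks" means here, the following is a minimal, hypothetical sketch of mapping a global (H, W) grid to per-rank slices along the height dimension; the real logic lives behind Distributed.get_local_slices, and the names below are illustrative only:

```python
# Hypothetical sketch of spatial scattering: split the global (H, W) grid
# into contiguous per-rank chunks along the height dimension.
def local_slices(global_shape, spatial_rank, n_spatial_ranks):
    h, w = global_shape
    if h % n_spatial_ranks != 0:
        raise ValueError(f"height {h} not divisible by {n_spatial_ranks} spatial ranks")
    chunk = h // n_spatial_ranks
    h_slice = slice(spatial_rank * chunk, (spatial_rank + 1) * chunk)
    return h_slice, slice(0, w)  # every rank keeps the full width

print(local_slices((8, 16), 1, 4))  # second rank's slice of an 8x16 grid
```

Each data loader would then index its tensors with these slices before batching, so every spatial rank holds only its local chunk.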

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.

  • fme/ace/data_loading/gridded_data.py: Introduces _scatter_properties and scatters spatial chunks in loaders to support spatial parallel backends.

  • fme/ace/data_loading/inference.py: Uses data-parallel rank/count (vs global rank/count) for initial-condition sharding and assertions.

  • fme/ace/data_loading/getters.py: Adjusts sample-with-replacement sizing to be per data-parallel rank count.

  • fme/ace/data_loading/config.py: Validates batch-size divisibility against data-parallel ranks and updates the error message.

  • fme/ace/data_loading/test_data_loader.py: Removes a prior xfail and adds a new parallel test for _scatter_properties.

Comment on lines +30 to +40
def _scatter_properties(properties: DatasetProperties) -> DatasetProperties:
    """Slice horizontal coordinates and masks to the local spatial chunk."""
    dist = Distributed.get_instance()
    coords = properties.horizontal_coordinates
    if isinstance(coords, LatLonCoordinates):
        h_slice, w_slice = dist.get_local_slices(coords.shape)
        coords = LatLonCoordinates(
            lat=coords.lat[h_slice],
            lon=coords.lon[w_slice],
        )
    mask_provider = properties.mask_provider

Copilot AI Mar 10, 2026


_scatter_properties only slices LatLonCoordinates, but GriddedData/InferenceGriddedData now always scatter batch.data spatially under ModelTorchDistributed. For other HorizontalCoordinates types (e.g., HEALPixCoordinates), this will leave DatasetInfo.coords/img_shape inconsistent with the locally-sliced tensors and can break downstream consumers (e.g., writers/aggregators). Consider extending _scatter_properties to slice other coordinate types (at least HEALPix height/width) or explicitly raising a NotImplementedError when spatial parallelism is enabled for unsupported coordinate grids.

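The explicit-failure option the review suggests could look like the sketch below. The dataclasses are stand-ins for the fme coordinate types, and scatter_coords is a hypothetical helper, not the actual _scatter_properties implementation:

```python
from dataclasses import dataclass

# Stand-ins for the fme coordinate classes; illustrative only.
@dataclass
class LatLonCoordinates:
    lat: list
    lon: list

@dataclass
class HEALPixCoordinates:
    nside: int

def scatter_coords(coords, h_slice, w_slice):
    """Slice supported coordinate types; fail loudly on unsupported grids."""
    if isinstance(coords, LatLonCoordinates):
        return LatLonCoordinates(lat=coords.lat[h_slice], lon=coords.lon[w_slice])
    # Refusing to proceed keeps coordinates consistent with the locally
    # sliced tensors, instead of silently returning global coordinates.
    raise NotImplementedError(
        f"Spatial parallelism is not implemented for {type(coords).__name__}"
    )
```

An unsupported grid then fails at setup time rather than producing inconsistent shapes downstream.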
Comment on lines +4 to 6
import datetime
import math
import os

Copilot AI Mar 10, 2026


PR description still contains the template placeholders (e.g., generic bullets / unchecked checklist). Please update it with the actual purpose of the change, key symbols impacted, and whether tests/dependency steps were done so reviewers can validate intent and scope.

@mahf708 force-pushed the feature/model-parallel-load-data branch 2 times, most recently from 2e00f53 to 3937218 (March 11, 2026 01:43)
@mahf708 requested a review from Copilot (March 11, 2026 01:44)

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated no new comments.

Comments suppressed due to low confidence (1)

fme/ace/data_loading/getters.py:59

  • When sample_with_replacement is set, this uses RandomSampler, which is not distributed-aware: each process will sample independently, so data-parallel ranks won’t get disjoint samples, and under spatial model-parallelism spatial ranks may even see different samples (breaking the expectation that spatial ranks share the same batch). Consider using a distributed sampler-with-replacement approach (e.g., seed a torch.Generator by data_parallel_rank/epoch or implement a DistributedSampler variant with replacement) so sampling is consistent across spatial ranks but distinct across data-parallel ranks.
    dist = Distributed.get_instance()
    if sample_with_replacement_dataset_size is not None:
        local_sample_with_replacement_dataset_size = (
            sample_with_replacement_dataset_size // dist.total_data_parallel_ranks
        )
        sampler = torch.utils.data.RandomSampler(
            dataset,
            num_samples=local_sample_with_replacement_dataset_size,
            replacement=True,
        )
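The distributed-aware alternative the comment proposes can be sketched in pure Python; the function name and the seeding scheme below are hypothetical, standing in for a torch.Generator seeded per (epoch, data-parallel rank):

```python
import random

def replacement_sample_indices(dataset_len, num_samples, epoch, data_parallel_rank):
    # Spatial co-ranks share the same data_parallel_rank, so they share a seed
    # and draw identical indices; distinct data-parallel ranks draw from
    # different streams.
    rng = random.Random(100_003 * epoch + data_parallel_rank)
    return [rng.randrange(dataset_len) for _ in range(num_samples)]
```

Because the seed depends only on epoch and data-parallel rank, every spatial rank in a group sees the same batch, which is the invariant the scatter step relies on.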


@mahf708 force-pushed the feature/model-parallel-load-data branch from 3937218 to d6ce445 (March 11, 2026 01:56)
@mahf708 requested a review from Copilot (March 11, 2026 14:30)

Copilot AI left a comment


Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.



Comment on lines +337 to +339
assert result.time.shape[0] == (
    self._n_initial_conditions // dist.total_data_parallel_ranks
)

Copilot AI Mar 11, 2026


__getitem__ uses an assert to enforce that n_initial_conditions is evenly divisible by dist.total_data_parallel_ranks. Assertions can be disabled with Python optimization flags, and this will currently fail late (or be skipped) instead of raising a clear, actionable exception. Please replace this with an explicit validation (e.g., raise ValueError with a message that includes n_initial_conditions and total_data_parallel_ranks), ideally performed once during initialization rather than per-sample.

Suggested change
assert result.time.shape[0] == (
    self._n_initial_conditions // dist.total_data_parallel_ranks
)
total_data_parallel_ranks = dist.total_data_parallel_ranks
n_initial_conditions = self._n_initial_conditions
if total_data_parallel_ranks <= 0:
    raise ValueError(
        "total_data_parallel_ranks must be a positive integer, "
        f"got {total_data_parallel_ranks}."
    )
if n_initial_conditions % total_data_parallel_ranks != 0:
    raise ValueError(
        "n_initial_conditions must be evenly divisible by "
        "total_data_parallel_ranks, but got "
        f"n_initial_conditions={n_initial_conditions} and "
        f"total_data_parallel_ranks={total_data_parallel_ranks}."
    )
expected_time_dim = n_initial_conditions // total_data_parallel_ranks
if result.time.shape[0] != expected_time_dim:
    raise ValueError(
        "Unexpected time dimension in inference batch: "
        f"expected {expected_time_dim} based on "
        f"n_initial_conditions={n_initial_conditions} and "
        f"total_data_parallel_ranks={total_data_parallel_ranks}, "
        f"but got {result.time.shape[0]}."
    )

"sample_with_replacement is not supported with spatial "
"parallelism. Spatial co-ranks would draw different samples, "
"producing corrupted data after scatter_spatial reassembly."
)

Copilot AI Mar 11, 2026


sample_with_replacement_dataset_size is split across ranks using integer floor division. If the requested size is not divisible by dist.total_data_parallel_ranks, the total number of samples drawn across all ranks will be smaller than the configured value. Consider validating divisibility and raising a ValueError (or otherwise documenting/handling the remainder) so sample_with_replacement behaves predictably in distributed runs.

Suggested change
)
)
if (
    sample_with_replacement_dataset_size
    % dist.total_data_parallel_ranks
    != 0
):
    raise ValueError(
        "sample_with_replacement_dataset_size "
        f"({sample_with_replacement_dataset_size}) must be divisible "
        "by the total number of data-parallel ranks "
        f"({dist.total_data_parallel_ranks}) when using "
        "sample_with_replacement."
    )

"batch_size must be divisible by the number of data-parallel "
f"workers, got {self.batch_size} and "
f"{dist.total_data_parallel_ranks}"
)

Copilot AI Mar 11, 2026


DataLoaderConfig validates that batch_size is divisible by total_data_parallel_ranks, but there is no analogous validation for sample_with_replacement. Since _get_sampler uses floor division when splitting sample_with_replacement across ranks, a non-divisible value will result in fewer total samples than requested. Consider adding a __post_init__ check that sample_with_replacement (when set) is divisible by dist.total_data_parallel_ranks, and raise a clear ValueError if not.

Suggested change
)
)
if (
    self.sample_with_replacement is not None
    and self.sample_with_replacement % dist.total_data_parallel_ranks != 0
):
    raise ValueError(
        "sample_with_replacement must be divisible by the number of "
        "data-parallel workers, got "
        f"{self.sample_with_replacement} and "
        f"{dist.total_data_parallel_ranks}"
    )

@mahf708 force-pushed the feature/model-parallel-load-data branch from 21e75a0 to 02c7c3b (March 11, 2026 16:17)

Copilot AI left a comment


Pull request overview

Copilot reviewed 21 out of 21 changed files in this pull request and generated 3 comments.



Comment on lines +1221 to +1226
# Spatial co-ranks (same data_parallel_rank) should have distinct
# lat/lon slices whose union covers the global coordinates exactly.
combined_lat = sorted(set(v for rank_lat in all_lats for v in rank_lat))
combined_lon = sorted(set(v for rank_lon in all_lons for v in rank_lon))
assert combined_lat == lat.tolist(), "lat coverage incomplete or has gaps"
assert combined_lon == lon.tolist(), "lon coverage incomplete or has gaps"

Copilot AI Mar 11, 2026


This test claims it verifies “full, non-overlapping coverage”, but it only checks set-union equality. Overlapping slices (duplicates) would still pass. Consider also asserting that, within each data-parallel group, slices are disjoint (or that the sum of local lengths equals the global length) to actually detect overlaps.

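The stronger assertion the comment asks for can be sketched as a small helper (hypothetical; not part of the test file): set-union equality catches gaps, and comparing the sum of per-rank lengths against the global length catches overlaps.

```python
def covers_exactly(global_vals, per_rank_slices):
    """True iff the per-rank slices cover global_vals with no gaps or overlap."""
    union = sorted({v for s in per_rank_slices for v in s})
    total = sum(len(s) for s in per_rank_slices)
    # union equality alone would pass for overlapping slices; the length
    # check rules that out.
    return union == sorted(global_vals) and total == len(global_vals)

print(covers_exactly([0, 1, 2, 3], [[0, 1], [2, 3]]))     # True
print(covers_exactly([0, 1, 2, 3], [[0, 1], [1, 2, 3]]))  # False: overlap at 1
```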
Comment on lines +53 to +55
shape = properties.horizontal_coordinates.shape
self._global_img_shape: tuple[int, int] = (shape[-2], shape[-1])
self._properties = properties.to_device().localize()

Copilot AI Mar 11, 2026


The PR description in metadata appears to still be the template (placeholders like “Resolves #” and unchecked boxes). Please replace it with an accurate summary of the motivation, concrete change list, and test status so reviewers can validate scope and risk.

Comment on lines +701 to +702
``SpatialParallelismNotImplemented`` when the distributed layout
requires slicing.

Copilot AI Mar 11, 2026


The docstring references SpatialParallelismNotImplemented, but that symbol isn’t imported in this module and isn’t re-exported from fme.core.distributed (only Distributed is). This makes the guidance hard to follow and can confuse users/type checkers. Consider importing the exception here (or fully-qualifying it in the docstring) and/or re-exporting it from fme.core.distributed if it’s intended to be part of the public API.

Suggested change
``SpatialParallelismNotImplemented`` when the distributed layout
requires slicing.
an appropriate exception when the distributed layout requires slicing.

3 participants