-
Notifications
You must be signed in to change notification settings - Fork 0
test dataloading #10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
test dataloading #10
Changes from all commits
8e5d89f
82b1bd4
02c7c3b
51be7e1
b28067d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -49,8 +49,13 @@ def _get_sampler( | |||||||||||||||||||||||||||||
| ) -> torch.utils.data.Sampler: | ||||||||||||||||||||||||||||||
| dist = Distributed.get_instance() | ||||||||||||||||||||||||||||||
| if sample_with_replacement_dataset_size is not None: | ||||||||||||||||||||||||||||||
| dist.require_no_spatial_parallelism( | ||||||||||||||||||||||||||||||
| "sample_with_replacement is not supported with spatial " | ||||||||||||||||||||||||||||||
| "parallelism. Spatial co-ranks would draw different samples, " | ||||||||||||||||||||||||||||||
| "producing corrupted data after scatter_spatial reassembly." | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
| ) | |
| ) | |
| if ( | |
| sample_with_replacement_dataset_size | |
| % dist.total_data_parallel_ranks | |
| != 0 | |
| ): | |
| raise ValueError( | |
| "sample_with_replacement_dataset_size " | |
| f"({sample_with_replacement_dataset_size}) must be divisible " | |
| "by the total number of data-parallel ranks " | |
| f"({dist.total_data_parallel_ranks}) when using " | |
| "sample_with_replacement." | |
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,7 +50,9 @@ def __init__( | |
| will be on the current device. | ||
| """ | ||
| self._loader = loader | ||
| self._properties = properties.to_device() | ||
| shape = properties.horizontal_coordinates.shape | ||
| self._global_img_shape: tuple[int, int] = (shape[-2], shape[-1]) | ||
| self._properties = properties.to_device().localize() | ||
|
Comment on lines
+53
to
+55
|
||
| self._timestep = self._properties.timestep | ||
| self._vertical_coordinate = self._properties.vertical_coordinate | ||
| self._mask_provider = self._properties.mask_provider | ||
|
|
@@ -72,7 +74,8 @@ def _get_gpu_loader( | |
| self, base_loader: DataLoader[BatchData] | ||
| ) -> DataLoader[BatchData]: | ||
| def modify_and_on_device(batch: BatchData) -> BatchData: | ||
| return self._modifier(batch).to_device() | ||
| batch = self._modifier(batch) | ||
| return batch.to_device().scatter_spatial(self._global_img_shape) | ||
|
|
||
| return SizedMap(modify_and_on_device, base_loader) | ||
|
|
||
|
|
@@ -174,21 +177,23 @@ def __init__( | |
| will be on the current device. | ||
| """ | ||
| self._loader = loader | ||
| self._properties = properties.to_device() | ||
| shape = properties.horizontal_coordinates.shape | ||
| self._global_img_shape: tuple[int, int] = (shape[-2], shape[-1]) | ||
| self._properties = properties.to_device().localize() | ||
| self._n_initial_conditions: int | None = None | ||
| if isinstance(initial_condition, PrognosticStateDataRequirements): | ||
| self._initial_condition: PrognosticState = get_initial_condition( | ||
| loader, initial_condition | ||
| self.loader, initial_condition | ||
| ) | ||
| else: | ||
| self._initial_condition = initial_condition.to_device() | ||
|
|
||
| @property | ||
| def loader(self) -> DataLoader[BatchData]: | ||
| def on_device(batch: BatchData) -> BatchData: | ||
| return batch.to_device() | ||
| def scatter_and_on_device(batch: BatchData) -> BatchData: | ||
| return batch.to_device().scatter_spatial(self._global_img_shape) | ||
|
|
||
| return SizedMap(on_device, self._loader) | ||
| return SizedMap(scatter_and_on_device, self._loader) | ||
|
|
||
| @property | ||
| def variable_metadata(self) -> dict[str, VariableMetadata]: | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -286,7 +286,7 @@ def _get_batch_data(self, index) -> BatchData: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sample_tuples = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i_member in range(self._n_initial_conditions): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # check if sample is one this local rank should process | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if i_member % dist.world_size != dist.rank: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if i_member % dist.total_data_parallel_ranks != dist.data_parallel_rank: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| i_window_start = i_start + self._start_indices[i_member] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| i_window_end = i_window_start + self._forward_steps_in_memory + 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -334,7 +334,9 @@ def __getitem__(self, index) -> BatchData: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for key, value in self._persistence_data.data.items(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| updated_data[key] = value.expand_as(result.data[key]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result.data = {**result.data, **updated_data} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert result.time.shape[0] == self._n_initial_conditions // dist.world_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert result.time.shape[0] == ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._n_initial_conditions // dist.total_data_parallel_ranks | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+337
to
+339
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert result.time.shape[0] == ( | |
| self._n_initial_conditions // dist.total_data_parallel_ranks | |
| ) | |
| total_data_parallel_ranks = dist.total_data_parallel_ranks | |
| n_initial_conditions = self._n_initial_conditions | |
| if total_data_parallel_ranks <= 0: | |
| raise ValueError( | |
| "total_data_parallel_ranks must be a positive integer, " | |
| f"got {total_data_parallel_ranks}." | |
| ) | |
| if n_initial_conditions % total_data_parallel_ranks != 0: | |
| raise ValueError( | |
| "n_initial_conditions must be evenly divisible by " | |
| "total_data_parallel_ranks, but got " | |
| f"n_initial_conditions={n_initial_conditions} and " | |
| f"total_data_parallel_ranks={total_data_parallel_ranks}." | |
| ) | |
| expected_time_dim = n_initial_conditions // total_data_parallel_ranks | |
| if result.time.shape[0] != expected_time_dim: | |
| raise ValueError( | |
| "Unexpected time dimension in inference batch: " | |
| f"expected {expected_time_dim} based on " | |
| f"n_initial_conditions={n_initial_conditions} and " | |
| f"total_data_parallel_ranks={total_data_parallel_ranks}, " | |
| f"but got {result.time.shape[0]}." | |
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| """This file contains unit tests related to creating torch Datasets from climate | ||
| data (e.g. netCDF files).""" | ||
|
|
||
| import datetime | ||
| import math | ||
| import os | ||
|
Comment on lines
+4
to
6
|
||
| import pathlib | ||
|
|
@@ -30,14 +31,20 @@ | |
| ) | ||
| from fme.ace.data_loading.perturbation import PerturbationSelector, SSTPerturbation | ||
| from fme.ace.requirements import DataRequirements, PrognosticStateDataRequirements | ||
| from fme.core.coordinates import HybridSigmaPressureCoordinate | ||
| from fme.core.coordinates import ( | ||
| HybridSigmaPressureCoordinate, | ||
| LatLonCoordinates, | ||
| NullVerticalCoordinate, | ||
| ) | ||
| from fme.core.dataset.concat import ConcatDatasetConfig | ||
| from fme.core.dataset.data_typing import VariableMetadata | ||
| from fme.core.dataset.merged import MergeDatasetConfig, MergeNoConcatDatasetConfig | ||
| from fme.core.dataset.properties import DatasetProperties | ||
| from fme.core.dataset.schedule import IntMilestone, IntSchedule | ||
| from fme.core.dataset.xarray import XarrayDataConfig | ||
| from fme.core.device import using_gpu | ||
| from fme.core.distributed.distributed import Distributed | ||
| from fme.core.distributed.model_torch_distributed import ModelTorchDistributed | ||
| from fme.core.mask_provider import MaskProvider | ||
| from fme.core.testing.regression import validate_tensor_dict | ||
| from fme.core.typing_ import Slice | ||
|
|
||
|
|
@@ -203,10 +210,6 @@ def test_ensemble_loader_n_samples(tmp_path, num_ensemble_members=3, n_samples=1 | |
| def test_xarray_loader(tmp_path): | ||
| """Checks that vertical coordinates are present.""" | ||
| dist = Distributed.get_instance() | ||
| if isinstance(dist._distributed, ModelTorchDistributed): | ||
| pytest.xfail( | ||
| "ModelTorchDistributed slicing along spatial dimensions is not implemented." | ||
| ) | ||
| tmp_path = dist.scatter_object(tmp_path) # get the root value | ||
| global_batch_size = 24 | ||
| if dist.is_root(): | ||
|
|
@@ -1173,3 +1176,67 @@ def test_pinned_memory(tmp_path, time_buffer: int): | |
| for batch in loader: | ||
| tensor = next(iter(batch.data.values())) | ||
| assert tensor.is_pinned() is using_gpu() | ||
|
|
||
|
|
||
| @pytest.mark.parallel | ||
| def test_localize_properties(): | ||
| """Verify DatasetProperties.localize() partitions coords and masks across ranks.""" | ||
| dist = Distributed.get_instance() | ||
| n_lat, n_lon = N_LAT, N_LON | ||
| lat = torch.linspace(-90.0, 90.0, n_lat) | ||
| lon = torch.linspace(0.0, 360.0, n_lon) | ||
| coords = LatLonCoordinates(lat=lat, lon=lon) | ||
| mask_tensor = torch.arange(n_lat * n_lon, dtype=torch.float32).reshape( | ||
| 1, n_lat, n_lon | ||
| ) | ||
| mask_provider = MaskProvider(masks={"mask_test": mask_tensor}) | ||
| timestep = datetime.timedelta(hours=6) | ||
| metadata = {"temp": VariableMetadata(units="K", long_name="Temperature")} | ||
| vertical = NullVerticalCoordinate() | ||
| props = DatasetProperties( | ||
| variable_metadata=metadata, | ||
| vertical_coordinate=vertical, | ||
| horizontal_coordinates=coords, | ||
| mask_provider=mask_provider, | ||
| timestep=timestep, | ||
| is_remote=False, | ||
| all_labels=None, | ||
| ) | ||
|
|
||
| local = props.localize() | ||
|
|
||
| # Unchanged fields | ||
| assert local.variable_metadata is metadata | ||
| assert local.vertical_coordinate is vertical | ||
| assert local.timestep == timestep | ||
|
|
||
| # Gather local coordinates to root and verify full, non-overlapping coverage. | ||
| assert isinstance(local.horizontal_coordinates, LatLonCoordinates) | ||
| local_lat = local.horizontal_coordinates.lat.tolist() | ||
| local_lon = local.horizontal_coordinates.lon.tolist() | ||
| all_lats = dist.gather_object(local_lat) | ||
| all_lons = dist.gather_object(local_lon) | ||
| if dist.is_root(): | ||
| assert all_lats is not None and all_lons is not None | ||
| # Spatial co-ranks (same data_parallel_rank) should have distinct | ||
| # lat/lon slices whose union covers the global coordinates exactly. | ||
| combined_lat = sorted(set(v for rank_lat in all_lats for v in rank_lat)) | ||
| combined_lon = sorted(set(v for rank_lon in all_lons for v in rank_lon)) | ||
| assert combined_lat == lat.tolist(), "lat coverage incomplete or has gaps" | ||
| assert combined_lon == lon.tolist(), "lon coverage incomplete or has gaps" | ||
|
Comment on lines
+1221
to
+1226
|
||
|
|
||
| # Gather local masks to root and verify they tile to the global mask. | ||
| assert isinstance(local.mask_provider, MaskProvider) | ||
| local_mask = local.mask_provider.masks["mask_test"] | ||
| h_slice, w_slice = dist.get_local_slices((n_lat, n_lon)) | ||
| all_slices = dist.gather_object((h_slice, w_slice)) | ||
| all_masks = dist.gather_object(local_mask.tolist()) | ||
| if dist.is_root(): | ||
| assert all_slices is not None and all_masks is not None | ||
| canvas = torch.zeros_like(mask_tensor) | ||
| for (hs, ws), mask_data in zip(all_slices, all_masks): | ||
| canvas[..., hs, ws] += torch.tensor(mask_data) | ||
| # Each cell should be covered once per data-parallel rank (spatial | ||
| # co-ranks have distinct slices, data-parallel ranks have identical ones). | ||
| expected_count = dist.total_data_parallel_ranks | ||
| torch.testing.assert_close(canvas, mask_tensor * expected_count) | ||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |||||||
| from fme.core.corrector.registry import CorrectorABC | ||||||||
| from fme.core.derived_variables import compute_derived_quantities | ||||||||
| from fme.core.device import get_device | ||||||||
| from fme.core.distributed import Distributed | ||||||||
| from fme.core.gridded_ops import GriddedOperations, HEALPixOperations, LatLonOperations | ||||||||
| from fme.core.mask_provider import MaskProvider, MaskProviderABC, NullMaskProvider | ||||||||
| from fme.core.ocean_derived_variables import compute_ocean_derived_quantities | ||||||||
|
|
@@ -693,6 +694,17 @@ def meshgrid(self) -> tuple[torch.Tensor, torch.Tensor]: | |||||||
| def shape(self) -> tuple[int, ...]: | ||||||||
| pass | ||||||||
|
|
||||||||
| @abc.abstractmethod | ||||||||
| def localize(self: HC) -> HC: | ||||||||
| """Return a copy with coordinates sliced to the local spatial chunk. | ||||||||
|
|
||||||||
| Uses ``Distributed.get_instance()`` to determine the local slices. | ||||||||
| Coordinate types that do not support spatial parallelism should raise | ||||||||
| ``SpatialParallelismNotImplemented`` when the distributed layout | ||||||||
| requires slicing. | ||||||||
|
Comment on lines
+703
to
+704
|
||||||||
| ``SpatialParallelismNotImplemented`` when the distributed layout | |
| requires slicing. | |
| an appropriate exception when the distributed layout requires slicing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`DataLoaderConfig` validates that `batch_size` is divisible by `total_data_parallel_ranks`, but there is no analogous validation for `sample_with_replacement`. Since `_get_sampler` uses floor division when splitting `sample_with_replacement` across ranks, a non-divisible value will result in fewer total samples than requested. Consider adding a `__post_init__` check that `sample_with_replacement` (when set) is divisible by `dist.total_data_parallel_ranks`, and raise a clear `ValueError` if not.