feat(dcp): dcp optimized s3reader for faster and partial DCP loading #378
base: main
Conversation
Force-pushed from daef051 to 39853e4.
Force-pushed from 39853e4 to 08a815a.
- Update SequentialS3Reader to support partial reads (and add logs)
- New ListOfRangesS3Reader:
  - Coalesces ranges to form chunks of ranges
  - Manages a ranged SequentialS3Reader instance for each chunk
  - Maps each read / readinto / seek request to the matching S3Reader instance
- Integrate this reader into S3StorageReader (forcing ListOfRangesS3Reader for now) via S3ReaderConstructor params for the list of ranges
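The coalescing step described above can be sketched as follows. This is a minimal illustration, not the PR's actual code; the names `coalesce_ranges` and `max_gap_size` are illustrative (the latter mirrors the constructor parameter this PR introduces):

```python
from typing import List, Tuple

def coalesce_ranges(
    ranges: List[Tuple[int, int]], max_gap_size: int
) -> List[Tuple[int, int]]:
    """Merge (start, end) byte ranges whose gap is <= max_gap_size.

    Nearby ranges are merged into one chunk so a single ranged GET can
    serve several items, trading a little wasted transfer for fewer requests.
    """
    if not ranges:
        return []
    ranges = sorted(ranges)
    merged = [ranges[0]]
    for start, end in ranges[1:]:
        last_start, last_end = merged[-1]
        if start - last_end <= max_gap_size:
            # Close enough: extend the current chunk to cover both ranges.
            merged[-1] = (last_start, max(last_end, end))
        else:
            # Gap too large: start a new chunk (and a new stream).
            merged.append((start, end))
    return merged

print(coalesce_ranges([(0, 10), (15, 30), (500, 600)], max_gap_size=100))
# → [(0, 30), (500, 600)]: the first two ranges merge, the third stays separate.
```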
Add DCPListOfRangesConstructor and a dcp_list_of_ranges() factory method to enable DCP range optimization through the reader_constructor parameter. Includes better range-injection logic and support for both direct ListOfRanges usage and DCP optimization. Users can now opt in via: reader_constructor=S3ReaderConstructor.dcp_list_of_ranges()
- Type annotations, missing arguments / return statements, etc.
- Minor logic/name changes in list_of_ranges.py
- Very minor change to fix a mypy error in test_user_agent.py
This commit improves performance of ListOfRangesS3Reader by up to 30% for DCP load:
- Remove the dependency on SequentialS3Reader in favour of self-managed streams
- Implement direct stream management with per-group buffering
- Optimize the read() method to drop the BytesIO buffer, assuming sequential reading
- Enforce non-seekable behaviour to force sequential reading patterns
This implementation is now significantly faster for distributed checkpoint loading patterns while maintaining correctness for sequential access. It relies on the load-ordering optimisation, which enforces sequential reading for read() operations, but it will not work with readinto() operations, since those still have backward-seek patterns.
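The buffer-free sequential read described in this commit can be sketched as below. This is a hedged, self-contained illustration under assumed names (`SequentialRangeReader` is not the PR's class): because DCP load ordering guarantees forward-only reads, each read() can be served straight from the stream without first staging the whole range in a BytesIO buffer.

```python
import io

class SequentialRangeReader:
    """Serves forward-only reads for one byte range, with no staging buffer."""

    def __init__(self, stream: io.BufferedIOBase, start: int, end: int):
        self._stream = stream      # assumed to already be positioned at `start`
        self._position = start
        self._end = end

    def read(self, size: int) -> bytes:
        # Clamp to the range end; no intermediate copy, no backward seeks.
        size = min(size, self._end - self._position)
        data = self._stream.read(size)
        self._position += len(data)
        return data

stream = io.BytesIO(b"0123456789")
reader = SequentialRangeReader(stream, start=0, end=10)
print(reader.read(4))  # → b'0123'
print(reader.read(4))  # → b'4567'
```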
- Rename "list of ranges" / "dcp list of ranges" to "DCP optimized"
- Allow max_gap_size to be passed through via S3ReaderConstructor
Since the reader now does both list-of-ranges handling AND DCP optimisation by exploiting and requiring sequential access, I am renaming it to "dcp optimized" to better reflect its scope.
- Includes import changes for other files
- Updates some missed renames in comments and docstrings
- Use 200MB (an arbitrary value for now) as DEFAULT_MAX_GAP_SIZE
- Place it in dcp_optimized.py for a single source of truth
- Add a dcp_reader_constructor fixture for DCP tests
- Update test_e2e_s3_file_system.py to use the dcp_reader_constructor fixture
- Update the test_e2e_s3_storage_reader.py load-ordering test to also cover the DCP-optimized S3Reader
Force-pushed from 08a815a to 68165e6.
…ibility

Allows PyTorch versions before torch==2.7.0, which implicitly assume the provided stream is seekable, to use our optimisations. Allows Python 3.8 (which uses torch==2.4.1) tests to pass by allowing backward seeks within each LoadItem. This commit essentially offloads PyTorch's BytesIO logic to our reader by reverting to seekable=True, reading each LoadItem into a new internal BytesIO buffer, and handling the read/readinto calls with the extra offset.
- Add item-based buffering with BytesIO for full seekability
- Move streaming logic from read() into _stream_range_data()
- Add _load_item_buffer() for on-demand item loading
- Add _find_item_for_position() with fast paths (check the current/next item first)
- Rewrite read()/readinto() to use buffer operations on current_item_buffer
- Remove seekable() -> False to let PyTorch seek the S3Reader directly
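The `_find_item_for_position()` fast paths mentioned above can be sketched like this (hypothetical, simplified code, not the PR's implementation): sequential DCP loads almost always land in the current or the next item, so those two are checked before any scan.

```python
from typing import List, NamedTuple, Optional

class ItemRange(NamedTuple):
    start: int  # inclusive byte offset
    end: int    # exclusive byte offset

def find_item_for_position(
    items: List[ItemRange], position: int, current_idx: int
) -> Optional[int]:
    """Return the index of the item containing `position`, or None."""
    # Fast path 1: still inside the current item.
    cur = items[current_idx]
    if cur.start <= position < cur.end:
        return current_idx
    # Fast path 2: moved forward into the next item (the common DCP pattern).
    nxt = current_idx + 1
    if nxt < len(items) and items[nxt].start <= position < items[nxt].end:
        return nxt
    # Slow path: linear scan (a real implementation might bisect instead).
    for idx, item in enumerate(items):
        if item.start <= position < item.end:
            return idx
    return None

items = [ItemRange(0, 100), ItemRange(100, 250), ItemRange(300, 400)]
print(find_item_for_position(items, 120, current_idx=0))  # → 1 (next-item fast path)
```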
Replace multi-stream cache with single active stream state.
…essages

- Refactor core functions to remove redundancies
- Add docstrings and error messages
- Add input validation for read/readinto/seek methods
- Use filename instead of item_md.relative_path in prepare_local_plan
- Use easydict for a cleaner approach
- Rename to ItemRange, since each represents the byte range of a ReadItem in a PyTorch DCP LoadPlan
- Minor docstring updates / kwarg removals / comments
- Remove fragile S3 key parsing in both filename extractions
- The only difference should be that path/file/ will return "file" instead of ""
- Add range validation to detect overlapping and invalid ranges
- Add extra information in error messages to help diagnose potential failures
- Add bounds checking in error messages to prevent IndexError
- Fix the read() method to properly reject None/negative sizes
- Remove a duplicate type check in seek()
- Minor docstring / comment updates
)

if not isinstance(constructor, partial):
if isinstance(constructor, DCPOptimizedConstructor):
Same here - this feels pretty janky to me. What's this used for? Just debugging or to actually do something based on it?
User agent - agree this still feels janky.
# Skip ahead if behind target
if current_pos < item.start:
    skip = min(item.start - current_pos, len(chunk))
The logic here isn't obvious
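To make the flagged logic concrete, here is a standalone sketch (illustrative names, not the PR's exact code): a coalesced GET also carries the "gap" bytes between items, so when the stream position is behind the next item's start, part of the incoming network chunk must be discarded before any useful bytes are consumed.

```python
def consume_chunk(chunk: bytes, current_pos: int, item_start: int):
    """Return (useful_bytes, new_position) after discarding gap bytes.

    `min(item_start - current_pos, len(chunk))` caps the skip at the chunk
    size: if the whole chunk is gap, everything is discarded and the next
    chunk continues skipping from the advanced position.
    """
    if current_pos < item_start:
        skip = min(item_start - current_pos, len(chunk))
        chunk = chunk[skip:]
        current_pos += skip
    return chunk, current_pos

# Position 90, item starts at 100: the first 10 bytes of a 32-byte chunk are gap.
data, pos = consume_chunk(b"x" * 32, current_pos=90, item_start=100)
print(len(data), pos)  # → 22 100
```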
"""
return self._position

def close(self) -> None:
Should set _closed
Agree, this aligns with BytesIO methods, but we would then want _closed checks in all methods, which might affect performance.
Alternatively, do not use close() at all and let GC clean up like the other readers; GetObjectStream here does not seem closable, unlike PutObjectStream anyway.
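A hedged sketch of the `_closed` bookkeeping discussed above, mirroring BytesIO semantics: close() sets a flag, and every subsequent operation raises ValueError. Whether the per-call check costs measurable performance (the concern raised above) would need benchmarking; the class and names here are illustrative.

```python
import io

class ClosableReader:
    """Toy reader demonstrating BytesIO-style closed-state checks."""

    def __init__(self, data: bytes):
        self._buffer = io.BytesIO(data)
        self._closed = False

    def _check_closed(self) -> None:
        # The per-method guard under discussion.
        if self._closed:
            raise ValueError("I/O operation on closed reader")

    def read(self, size: int = -1) -> bytes:
        self._check_closed()
        return self._buffer.read(size)

    def close(self) -> None:
        self._closed = True

reader = ClosableReader(b"abc")
reader.close()
try:
    reader.read()
except ValueError as e:
    print(e)  # → I/O operation on closed reader
```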
item_range = self._item_ranges[self._current_item_idx]
local_pos = self._position - item_range.start

assert self._current_item_buffer is not None
We should check that the read doesn't take us outside of the current item
Proposing to add check within item_idx = self._find_item_for_position(self._position) that read doesn't exceed item range.
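The proposed bounds check could look like the following sketch (hypothetical helper, not the PR's code): before serving a read from the current item's buffer, verify the request stays inside that item's range, so an oversized read fails loudly instead of silently spanning into a neighbouring item.

```python
def check_read_within_item(
    position: int, size: int, item_start: int, item_end: int
) -> None:
    """Raise ValueError if [position, position + size) leaves the item range."""
    if not (item_start <= position and position + size <= item_end):
        raise ValueError(
            f"read of {size} bytes at position {position} exceeds item "
            f"range [{item_start}, {item_end})"
        )

check_read_within_item(position=10, size=20, item_start=0, item_end=100)  # ok
try:
    check_read_within_item(position=90, size=20, item_start=0, item_end=100)
except ValueError as e:
    print(e)  # → read of 20 bytes at position 90 exceeds item range [0, 100)
```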
assert self._current_item_buffer is not None
self._current_item_buffer.seek(local_pos)
bytes_read = self._current_item_buffer.readinto(buf)
Similarly, should verify the lengths involved here
ValueError: If seeking to negative position or accessing previous items.
TypeError: If whence is not SEEK_SET or SEEK_CUR.
"""
if not isinstance(offset, int):
There's no bounds checking here to make sure we stay within the current item
I tried to only put the checks in read/readinto and return the errors there.
Perhaps it's better to place the _find_item_for_position check into seek().
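Moving the check into seek() might look like this sketch (illustrative names, assuming the SEEK_SET/SEEK_CUR-only contract from the docstring above): resolve the target position first, then reject targets that leave the current item's range, since the reader only guarantees seekability within an item.

```python
import os

def validate_seek(
    offset: int, whence: int, position: int, item_start: int, item_end: int
) -> int:
    """Resolve and bounds-check a seek target against the current item."""
    if whence == os.SEEK_SET:
        target = offset
    elif whence == os.SEEK_CUR:
        target = position + offset
    else:
        # SEEK_END is unsupported, matching the docstring's contract.
        raise TypeError(f"unsupported whence: {whence}")
    if target < item_start or target > item_end:
        raise ValueError(
            f"seek to {target} leaves current item range [{item_start}, {item_end}]"
        )
    return target

print(validate_seek(10, os.SEEK_CUR, position=50, item_start=0, item_end=100))  # → 60
```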
- Remove the sort (since we pre-sorted in prepare_local_plan)
- Move all imports to the top of s3_file_system
- Rename ranges to item_ranges for DCPOptimizedS3Reader
- Add comments for human-readable error message construction
- Update a wrong test docstring after changing back to seekable
- Minor TODO comments / typing / docstring changes
…structor

- Simplify integration in S3StorageReader by moving logic to the constructor
- Add the runtime_checkable decorator to protocols for isinstance checks
- Add proper PyTorch DCP type annotations (ReadItem, MetadataIndex, _StorageInfo)
- Renames:
  - S3ReaderConstructorProtocolWithSetRanges to DCPS3ReaderConstructorProtocol
  - set_ranges method to set_item_ranges_by_file for clarity
  - _file_ranges field to _item_ranges_by_file to match the method name
…heck

- Use Dict / List instead of dict / list
- Remove a redundant boolean check: if self._item_ranges_by_file...
- Revert some changes back to the previous version in main
- Remove one TODO that we addressed in a previous commit
Move torch.distributed.checkpoint imports under TYPE_CHECKING to prevent importlib_metadata errors on Python 3.9 when DCP functionality is not used.
Description
DCPOptimizedS3Reader optimizes PyTorch Distributed Checkpoint (DCP) partial loading by (1) exploiting sequential access patterns to avoid the BytesIO buffer copy, and (2) fetching only the required byte ranges instead of entire objects. This can improve DCP loading performance by ~10% to 30%, and even more when loading only parts of a checkpoint.
Usage: pass S3ReaderConstructor.dcp_optimized() as the reader_constructor.
Optimized for partial DCP loading where only specific items are needed from large distributed checkpoint files.
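A hedged usage sketch of the opt-in described above. It assumes the s3torchconnector DCP integration exposes S3StorageReader and S3ReaderConstructor.dcp_optimized() as this PR describes; the import paths, region, and S3 path below are illustrative, so check the package docs for exact signatures.

```python
import torch.distributed.checkpoint as dcp
from s3torchconnector import S3ReaderConstructor
from s3torchconnector.dcp import S3StorageReader  # assumed import path

state_dict = {}  # model/optimizer state to load into (elided here)
dcp.load(
    state_dict,
    storage_reader=S3StorageReader(
        region="us-east-1",                 # illustrative value
        path="s3://my-bucket/checkpoint/",  # illustrative value
        # Opt in to the DCP-optimized reader added by this PR:
        reader_constructor=S3ReaderConstructor.dcp_optimized(),
    ),
)
```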
Additional context
Related items
Testing
By submitting this pull request, I confirm that my contribution is made under the terms of BSD 3-Clause License and I agree to the terms of the LICENSE.