diff --git a/CHANGELOG.md b/CHANGELOG.md index 90dfc9aa..c578c56e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ ## TBD +### New features +* Add DCP-optimized s3reader for up to 2x faster DCP loading and partial checkpoint loading (#378) + ### Bug fixes * Override S3Writer closed property and block writes after close (#360) * Fix SequentialS3Reader seek beyond EOF to clamp position to object size (#362) diff --git a/README.md b/README.md index f4f0bc1e..720f415c 100644 --- a/README.md +++ b/README.md @@ -128,7 +128,9 @@ Amazon S3 Connector for PyTorch provides robust support for PyTorch distributed - `S3StorageWriter`: Implementation of PyTorch's StorageWriter interface. -- `S3StorageReader`: Implementation of PyTorch's StorageReader interface. Supports configurable reading strategies via the `reader_constructor` parameter (see [Reader Configurations](#reader-configurations)). +- `S3StorageReader`: Implementation of PyTorch's StorageReader interface. + - Supports configurable reading strategies via the `reader_constructor` parameter (see [Reader Configurations](#reader-configurations)). + - `S3ReaderConstructor.dcp_optimized()` is recommended for up to 2x faster loading with partial checkpoint optimizations. - `S3FileSystem`: An implementation of PyTorch's FileSystemBase. These tools enable seamless integration of Amazon S3 with @@ -151,6 +153,7 @@ can be found in the [examples/dcp](https://github.com/awslabs/s3-connector-for-p ```py from s3torchconnector.dcp import S3StorageWriter, S3StorageReader +from s3torchconnector import S3ReaderConstructor import torchvision import torch.distributed.checkpoint as DCP @@ -175,7 +178,13 @@ DCP.save( # Load distributed checkpoint from S3 model = torchvision.models.resnet18() model_state_dict = model.state_dict() -s3_storage_reader = S3StorageReader(region=REGION, path=CHECKPOINT_URI) +# Use DCP-optimized reader for faster loading +reader_constructor = S3ReaderConstructor.dcp_optimized() +s3_storage_reader = S3StorageReader( + region=REGION, + path=CHECKPOINT_URI, + reader_constructor=reader_constructor, # optional; constructor for S3Reader types +) DCP.load( state_dict=model_state_dict, storage_reader=s3_storage_reader, @@ -409,7 +418,7 @@ data = s3reader.read() ## Reader Configurations -Amazon S3 Connector for PyTorch supports two types of readers, configurable through `S3ReaderConstructor`. +Amazon S3 Connector for PyTorch supports three types of readers, configurable through `S3ReaderConstructor`. ### Reader Types @@ -420,21 +429,32 @@ Amazon S3 Connector for PyTorch supports two types of readers, configurable thro #### 2. Range-based Reader -- Performs byte-range requests to read specific portions of S3 objects without downloading the entire file. -- Prioritizes memory efficiency, with performance gains only for sparse partial reads. +- Performs byte-range requests to read specific portions of S3 objects without downloading the entire object. +- Prioritizes memory efficiency, with performance gains only for sparse partial reads in large objects. - Features adaptive buffering with forward overlap handling: - **Small reads** (< `buffer_size`): Use internal buffer to reduce S3 API calls. - **Large reads** (≥ `buffer_size`): Bypass buffer for direct transfer. +#### 3. DCP-Optimized Reader (DCP only) + +- Specialized reader for PyTorch Distributed Checkpoint (DCP) loading. +- Provides up to 2x performance improvement through zero-copy buffers and sequential access patterns. +- Enables efficient partial checkpoint loading (e.g. 
model-only) through range-based streams and range coalescing. +- Automatically handles range metadata injection from DCP load plan. +- Requires sequential access patterns (automatically enforced in `S3StorageReader.prepare_local_plan()`). + ### When to Use Each Reader -- **Sequential Reader**: For processing entire files, and when repeated access to the data is required. Best for most general use cases. +- **Sequential Reader**: For processing entire objects, and when repeated access to the data is required. Best for most general use cases. - **Range-based Reader**: For larger objects (100MB+) that require sparse partial reads, and in memory-constrained environments. +- **DCP-Optimized Reader**: For PyTorch Distributed Checkpoint loading; delivers the highest performance and memory efficiency in typical checkpoint loading scenarios. **Note**: S3Reader instances are not thread-safe and should not be shared across threads. For multiprocessing with DataLoader, each worker process creates its own S3Reader instance automatically. ### Examples +For `S3ReaderConstructor` usage details, please refer to the [`S3ReaderConstructor` documentation](https://awslabs.github.io/s3-connector-for-pytorch/autoapi/s3torchconnector/s3reader/constructor/index.html). Below are some `S3ReaderConstructor` usage examples. + Direct method - `S3Client` usage with range-based reader without buffer: ```py # Direct S3Client usage for zero-copy partial reads into pre-allocated buffers, for memory efficiency and fast data transfer @@ -456,15 +476,13 @@ s3reader.seek(100 * 1024 * 1024) # Skip to 100MB offset bytes_read = s3reader.readinto(buffer) # Direct read into buffer ``` -DCP interface - `S3StorageReader` usage with range-based reader with buffer: +DCP interface - `S3StorageReader` usage with DCP-optimized reader: ```py -# Load distributed checkpoint with range-based reader to optimize memory usage for large checkpoint files +# Load checkpoint with DCP-optimized reader for faster loading from s3torchconnector.dcp import S3StorageReader from s3torchconnector import S3ReaderConstructor -reader_constructor = S3ReaderConstructor.range_based( - buffer_size=16*1024*1024 # 16MB buffer -) +reader_constructor = S3ReaderConstructor.dcp_optimized() s3_storage_reader = S3StorageReader( region=REGION, path=CHECKPOINT_URI, @@ -492,7 +510,6 @@ for item in dataset: ... ``` -For `S3ReaderConstructor` usage details, please refer to the [`S3ReaderConstructor` documentation](https://awslabs.github.io/s3-connector-for-pytorch/autoapi/s3torchconnector/s3reader/constructor/index.html). 
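The DCP-Optimized Reader bullets above call out partial checkpoint loading (e.g. model-only); the sketch below illustrates that case. It is a minimal, illustrative example rather than part of the README's existing samples: `REGION` and `CHECKPOINT_URI` are the same placeholders used in the earlier examples, and the explicit `max_gap_size` simply restates the default. Because PyTorch DCP only plans reads for the keys present in the state dict you pass, supplying a model-only state dict makes the DCP-optimized reader fetch just those byte ranges, coalescing nearby ranges into a small number of range-based streams.

```py
# Minimal sketch: partial (model-only) checkpoint load with the DCP-optimized reader.
import torch.distributed.checkpoint as DCP
import torchvision

from s3torchconnector import S3ReaderConstructor
from s3torchconnector.dcp import S3StorageReader

# Only the keys present in this state dict are read from the checkpoint on S3.
model = torchvision.models.resnet18()
model_state_dict = model.state_dict()

reader_constructor = S3ReaderConstructor.dcp_optimized(
    max_gap_size=32 * 1024 * 1024  # default; byte ranges closer than this are coalesced into one stream
)
s3_storage_reader = S3StorageReader(
    region=REGION,
    path=CHECKPOINT_URI,
    reader_constructor=reader_constructor,
)
DCP.load(
    state_dict=model_state_dict,
    storage_reader=s3_storage_reader,
)
```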
## Contributing diff --git a/s3torchconnector/pyproject.toml b/s3torchconnector/pyproject.toml index 82a8d98e..dbe5274a 100644 --- a/s3torchconnector/pyproject.toml +++ b/s3torchconnector/pyproject.toml @@ -34,7 +34,8 @@ test = [ "hypothesis", "flake8", "black", - "mypy" + "mypy", + "importlib_metadata; python_version == '3.9'", # PyTorch 2.7.0+ DCP w/ Python 3.9 requires this module; for dcp_optimized reader unit tests ] e2e = [ @@ -59,12 +60,12 @@ lightning-tests = [ dcp = [ "tenacity", "torch >= 2.3, != 2.5.0", + "importlib_metadata; python_version == '3.9'", # PyTorch 2.7.0+ DCP w/ Python 3.9 requires this module ] dcp-test = [ "s3torchconnector[dcp]", "pytest", - "importlib_metadata; python_version == '3.9'", ] [tool.setuptools.packages] diff --git a/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py b/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py index 3759e5a5..db2e2914 100644 --- a/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py +++ b/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py @@ -1,15 +1,17 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # // SPDX-License-Identifier: BSD +import dataclasses import io import logging import os import urllib.parse from contextlib import contextmanager +from dataclasses import dataclass from pathlib import Path -from typing import Generator, Union, Optional -from typing import List +from typing import Generator, Union, Optional, List +import torch from s3torchconnectorclient._mountpoint_s3_client import S3Exception from tenacity import ( retry, @@ -24,11 +26,16 @@ FileSystemWriter, FileSystemBase, ) -import torch +from torch.distributed.checkpoint.planner import SavePlan, LoadPlan + from s3torchconnector._s3client import S3Client from s3torchconnector._s3dataset_common import parse_s3_uri -from ..s3reader import S3ReaderConstructor, S3ReaderConstructorProtocol +from ..s3reader import ( + S3ReaderConstructor, + S3ReaderConstructorProtocol, + DCPS3ReaderConstructorProtocol, +) from .. import S3ClientConfig from .s3_prefix_strategy import S3PrefixStrategyBase, DefaultPrefixStrategy from .._user_agent import UserAgent @@ -37,6 +44,8 @@ class S3FileSystem(FileSystemBase): + """S3-based implementation of PyTorch's FileSystemBase for distributed checkpointing.""" + def __init__( self, region: str, @@ -252,11 +261,6 @@ def _escape_path(string): return "/".join(parts) -from torch.distributed.checkpoint.planner import SavePlan, LoadPlan -import dataclasses -from dataclasses import dataclass - - @dataclass class StorageMetadata: """Metadata for S3 storage prefix.""" @@ -265,6 +269,8 @@ class StorageMetadata: class S3StorageWriter(FileSystemWriter): + """S3 implementation of PyTorch's FileSystemWriter for distributed checkpoints.""" + def __init__( self, region: str, @@ -319,12 +325,16 @@ def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: class S3StorageReader(FileSystemReader): + """S3 implementation of PyTorch's FileSystemReader with configurable reader strategies.""" + def __init__( self, region: str, path: Union[str, os.PathLike], s3client_config: Optional[S3ClientConfig] = None, - reader_constructor: Optional[S3ReaderConstructorProtocol] = None, + reader_constructor: Optional[ + Union[S3ReaderConstructorProtocol, DCPS3ReaderConstructorProtocol] + ] = None, ) -> None: """ Initialize an S3 reader for distributed checkpointing. @@ -337,7 +347,12 @@ def __init__( e.g. 
S3ReaderConstructor.sequential() or S3ReaderConstructor.range_based() """ super().__init__(path) - self.fs = S3FileSystem(region, s3client_config=s3client_config, reader_constructor=reader_constructor) # type: ignore + self._reader_constructor = reader_constructor or S3ReaderConstructor.default() + self.fs: S3FileSystem = S3FileSystem( # type: ignore[assignment] + region, + s3client_config=s3client_config, + reader_constructor=self._reader_constructor, + ) self.path = self.fs.init_path(path) self.sync_files = False @@ -347,16 +362,31 @@ def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: """ - Sort load items by storage offset for sequential access optimization. + Performs two key optimizations: + + 1. **Load Ordering**: Sorts load items by storage offset to enable sequential access + + 2. **Range Injection**: Provides byte range metadata to DCP reader constructors to enable + usage of DCPOptimizedS3Reader for range-based streams and range coalescing Args: plan (LoadPlan): The load plan from PyTorch DCP. Returns: LoadPlan: The same plan with items sorted by storage offset. + + Note: + Both optimizations are required for DCPOptimizedS3Reader. """ # Sort items in plan based on their offset in checkpoints shards plan.items.sort(key=lambda item: self.storage_data[item.storage_index].offset) + + # Inject ranges if using DCP optimized reader constructor + if isinstance(self._reader_constructor, DCPS3ReaderConstructorProtocol): + self._reader_constructor.set_item_ranges_by_file( + plan.items, self.storage_data + ) + return plan diff --git a/s3torchconnector/src/s3torchconnector/s3reader/__init__.py b/s3torchconnector/src/s3torchconnector/s3reader/__init__.py index 2a8829dd..98d528f5 100644 --- a/s3torchconnector/src/s3torchconnector/s3reader/__init__.py +++ b/s3torchconnector/src/s3torchconnector/s3reader/__init__.py @@ -2,14 +2,20 @@ # // SPDX-License-Identifier: BSD from .s3reader import S3Reader -from .constructor import S3ReaderConstructor +from .constructor import S3ReaderConstructor, DCPOptimizedConstructor from .sequential import SequentialS3Reader from .ranged import RangedS3Reader -from .protocol import GetStreamCallable, S3ReaderConstructorProtocol +from .dcp_optimized import DCPOptimizedS3Reader, ItemRange, RangeGroup +from .protocol import ( + GetStreamCallable, + S3ReaderConstructorProtocol, + DCPS3ReaderConstructorProtocol, +) __all__ = [ "S3Reader", "S3ReaderConstructor", "SequentialS3Reader", "RangedS3Reader", + "DCPOptimizedS3Reader", ] diff --git a/s3torchconnector/src/s3torchconnector/s3reader/constructor.py b/s3torchconnector/src/s3torchconnector/s3reader/constructor.py index 703f6aee..4e4cca23 100644 --- a/s3torchconnector/src/s3torchconnector/s3reader/constructor.py +++ b/s3torchconnector/src/s3torchconnector/s3reader/constructor.py @@ -1,12 +1,71 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# // SPDX-License-Identifier: BSD +import logging from functools import partial -from typing import Optional - -from .protocol import S3ReaderConstructorProtocol +from typing import TYPE_CHECKING, Optional, List, Dict, Union +from collections import defaultdict + +from .s3reader import S3Reader +from .protocol import ( + S3ReaderConstructorProtocol, + DCPS3ReaderConstructorProtocol, +) from .sequential import SequentialS3Reader from .ranged import RangedS3Reader +from .dcp_optimized import DCPOptimizedS3Reader, ItemRange, DEFAULT_MAX_GAP_SIZE + +if TYPE_CHECKING: + from torch.distributed.checkpoint.planner import ReadItem + from torch.distributed.checkpoint.metadata import MetadataIndex + from torch.distributed.checkpoint.filesystem import _StorageInfo + +log = logging.getLogger(__name__) + + +class DCPOptimizedConstructor: + def __init__(self, max_gap_size: Union[int, float] = DEFAULT_MAX_GAP_SIZE) -> None: + + if max_gap_size < 0: + raise ValueError("max_gap_size must be non-negative") + + self._item_ranges_by_file: Dict[str, List[ItemRange]] = {} + self._max_gap_size: Union[int, float] = max_gap_size + + def set_item_ranges_by_file( + self, + plan_items: "List[ReadItem]", + storage_data: "Dict[MetadataIndex, _StorageInfo]", + ) -> None: + + if not plan_items: + return # Allow lack of plan_items, for SequentialS3Reader fallbacks + + self._item_ranges_by_file = defaultdict(list) + for read_item in plan_items: + item_md = storage_data[read_item.storage_index] + self._item_ranges_by_file[item_md.relative_path].append( + ItemRange(item_md.offset, item_md.offset + item_md.length) + ) + + def __call__(self, bucket: str, key: str, get_object_info, get_stream) -> S3Reader: + for relative_path in self._item_ranges_by_file.keys(): + if key.endswith(relative_path): + return DCPOptimizedS3Reader( + bucket, + key, + item_ranges=self._item_ranges_by_file[relative_path], + get_object_info=get_object_info, + get_stream=get_stream, + max_gap_size=self._max_gap_size, + ) + + # Fallback if file_ranges unavailable (e.g. when reading .metadata) + # TODO: Warn users for fallbacks for non-'.metadata' files? + log.debug( + f"DCPOptimizedConstructor: No ranges found for {key}, falling back to SequentialS3Reader" + ) + return SequentialS3Reader(bucket, key, get_object_info, get_stream) class S3ReaderConstructor: @@ -78,6 +137,49 @@ def range_based(buffer_size: Optional[int] = None) -> S3ReaderConstructorProtoco """ return partial(RangedS3Reader, buffer_size=buffer_size) + @staticmethod + def dcp_optimized( + max_gap_size: Union[int, float] = DEFAULT_MAX_GAP_SIZE, + ) -> DCPS3ReaderConstructorProtocol: + """Creates a constructor for DCP-optimized readers for faster checkpoint loading. + + The DCP-optimized reader provides up to 2x performance improvement over the default sequential reader through: + + - Zero-copy buffer management by storing data as memoryview segments + - Sequential access optimization to reduce buffer sizes from file-level to item-level + - Range-based fetching that downloads only required byte ranges and coalesces nearby ranges to reduce S3 request latency + + Args: + max_gap_size: Maximum gap size in bytes between ranges to coalesce into the same S3 read stream. + Most users should use the default value. 
+ + - Default: 32MB (``32 * 1024 * 1024``) + - Use ``float("inf")`` to coalesce all ranges regardless of gaps + - Use 0 to disable coalescing, which creates a new range-based stream for each gap + + Returns: + DCPS3ReaderConstructorProtocol: + Constructor that creates DCPOptimizedS3Reader when ranges are available, falling back to + SequentialS3Reader otherwise. + + Requirements: + Should be used with S3StorageReader, in which ``prepare_local_plan()`` automatically handles: + + - Load ordering: Sorts items by storage offset for sequential access + - Range injection: Provides byte ranges from DCP load plan to the reader + + Advanced users implementing custom storage readers must include these optimizations + in their ``prepare_local_plan()``/``read_data()`` implementations to use the DCP-optimized reader. + + Example:: + + reader_constructor = S3ReaderConstructor.dcp_optimized() + storage_reader = S3StorageReader(region, path, reader_constructor=reader_constructor) + DCP.load(state_dict, storage_reader=storage_reader) + + """ + return DCPOptimizedConstructor(max_gap_size=max_gap_size) + @staticmethod def default() -> S3ReaderConstructorProtocol: """Creates default reader constructor (sequential) @@ -97,10 +199,11 @@ def get_reader_type_string( S3ReaderConstructor.default() ) - if not isinstance(constructor, partial): + if isinstance(constructor, DCPOptimizedConstructor): + return "dcp_optimized" + elif not isinstance(constructor, partial): return "unknown" - - if constructor.func == RangedS3Reader: + elif constructor.func == RangedS3Reader: return "range_based" elif constructor.func == SequentialS3Reader: return "sequential" diff --git a/s3torchconnector/src/s3torchconnector/s3reader/dcp_optimized.py b/s3torchconnector/src/s3torchconnector/s3reader/dcp_optimized.py new file mode 100644 index 00000000..4f62ad0b --- /dev/null +++ b/s3torchconnector/src/s3torchconnector/s3reader/dcp_optimized.py @@ -0,0 +1,588 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# // SPDX-License-Identifier: BSD + +import bisect +import logging +from dataclasses import dataclass +from typing import List, Optional, Callable, Union, Iterator, Dict, cast +from io import SEEK_SET, SEEK_CUR, SEEK_END + +from s3torchconnectorclient._mountpoint_s3_client import ( + ObjectInfo, + GetObjectStream, + HeadObjectResult, +) +from .s3reader import S3Reader + +log = logging.getLogger(__name__) + +DEFAULT_MAX_GAP_SIZE = 32 * 1024 * 1024 # TODO tune this default +FIND_ITEM_ERROR_PREFIX = ( + "DCPOptimizedS3Reader only supports sequentially accessing provided ranges: " +) + + +@dataclass +class ItemRange: + """Byte range for a ReadItem; inclusive start, exclusive end.""" + + start: int + end: int + + +@dataclass +class RangeGroup: + start: int + end: int + item_ranges: List[ItemRange] + + +# TODO: extend buffer for use in other S3Reader implementations after extensive testing +class _ItemViewBuffer: + """ + A tiny, zero-copy, read-only buffer built from multiple memoryview segments. + Replaces io.BytesIO which involved extra copies for creation and buffer growth. 
+ """ + + __slots__ = ("_segments", "_offsets", "_lengths", "_size", "_pos", "_closed") + + def __init__(self) -> None: + self._segments: List[memoryview] = [] # memoryview segments + self._offsets: List[int] = [] # start offset (within the item) of each segment + self._lengths: List[int] = [] # length of each segment + self._size: int = 0 # total item length (sum of _lengths) + self._pos: int = 0 # current read position within the item + self._closed: bool = False + + def append_view(self, view: memoryview) -> None: + """Append a memoryview segment (ignored if empty).""" + assert not self._closed, "Buffer is closed" + + seg_len = len(view) + if seg_len == 0: + return + self._segments.append(view) + self._offsets.append(self._size) + self._lengths.append(seg_len) + self._size += seg_len + + def close(self) -> None: + if not self._closed: + self._closed = True + self._segments.clear() + + def seek(self, offset: int, whence: int = SEEK_SET, /) -> int: + assert isinstance(offset, int), f"integer expected, got {type(offset)!r}" + + if whence == SEEK_SET: + new_pos = offset + elif whence == SEEK_CUR: + new_pos = self._pos + offset + elif whence == SEEK_END: + new_pos = self._size + offset + else: + raise ValueError( + "Seek must be passed io SEEK_CUR, SEEK_SET, or SEEK_END integers" + ) + + assert new_pos >= 0, f"negative seek value {new_pos}" + + # Seeking past EOF is allowed. + self._pos = new_pos + return self._pos + + def tell(self) -> int: + """Return the current pos position (like BytesIO.tell).""" + return self._pos + + def read(self, size: Optional[int] = None) -> bytes: + assert size is not None, "Size cannot be None; full read is not supported" + assert size >= 0, "Size cannot be negative; full read is not supported" + + # Fast path for read(4) at pos=0 (Optimizes pytorch/torch/serialization.py _is_zipfile()) + if size == 4 and self._pos == 0 and self._lengths and self._lengths[0] >= 4: + self._pos = 4 + # TODO: eliminating bytes() conversion can save ~3% time? Requires interface changes. 
+ return bytes(self._segments[0][:4]) + + if size == 0: + return b"" + + # Pass implementation to readinto() + out = bytearray(size) + n = self.readinto(out) + return bytes(out) if n == size else memoryview(out)[:n].tobytes() + + def readinto(self, buf) -> int: + dest = buf if isinstance(buf, memoryview) else memoryview(buf) + assert not dest.readonly, "writable buffer required" + + dest_len = len(dest) + size = self._size + pos = self._pos + + if dest_len == 0 or pos >= size: + return 0 + + # Cache to avoid repeated attribute calls + segments = self._segments + offsets = self._offsets + lengths = self._lengths + + # Starting segment idx: last i where _offsets[i] <= _pos + seg_idx = bisect.bisect_right(offsets, pos) - 1 + if seg_idx < 0: + seg_idx = 0 + + written = 0 + bytes_to_read = min(dest_len, size - pos) + + # Copy from segments to dest + while written < bytes_to_read: + seg_start = offsets[seg_idx] + seg_len = lengths[seg_idx] + seg = segments[seg_idx] + + # Account for first chunk when pos > seg_start + offset_in_seg = pos - seg_start + + # Account for last chunk when bytes_to_read < seg_len + available_in_seg = seg_len - offset_in_seg + bytes_left_to_read = bytes_to_read - written + copy_size = min(bytes_left_to_read, available_in_seg) + + dest[written : written + copy_size] = seg[ + offset_in_seg : offset_in_seg + copy_size + ] + + written += copy_size + pos += copy_size + seg_idx += 1 + + self._pos += written + return written + + +class DCPOptimizedS3Reader(S3Reader): + """S3 reader implementation optimized for PyTorch Distributed Checkpoint (DCP) loading. + + Provides up to 2x performance improvement over default sequential reader through: + + 1. **Zero-Copy Buffer**: Custom ``_ItemViewBuffer`` storing data as memoryview + segments to eliminate BytesIO allocation and copy overhead. + + 2. **Sequential Access Optimization**: Exploits sequential access patterns over tensor + enforced by ``S3StorageReader.prepare_local_plan()`` to reduce buffer sizes from file-level to + item-level. + + 3. **Range-based fetching**: For partial checkpoint loading, uses load plan item ranges information + to group nearby byte ranges within ``max_gap_size`` to minimize S3 first byte latency (compared to + range-based reader), while only fetching required byte ranges instead of entire files + (compared to sequential reader). + + **Requirements**: + + - DCP Loading - reader is only designed for usage via dcp_optimized reader_constructor for ``dcp.load()`` + - Pre-sorted list of item_ranges, injected automatically in ``prepare_local_plan``. + - Sequential Access over exact item_ranges provided, also applied automatically by ``prepare_local_plan`` + + **Usage**: + Typically created automatically by ``DCPOptimizedConstructor`` when used with ``S3StorageReader`` and + ``S3ReaderConstructor.dcp_optimized()``: + + reader_constructor = S3ReaderConstructor.dcp_optimized(max_gap_size=32*1024*1024) + storage_reader = S3StorageReader(region, path, reader_constructor=reader_constructor) + DCP.load(state_dict, storage_reader=storage_reader) + + **Error Handling**: + Non-sequential access attempts raise ValueError with descriptive messages. 
+ """ + + def __init__( + self, + bucket: str, + key: str, + item_ranges: List[ItemRange], + get_object_info: Callable[[], Union[ObjectInfo, HeadObjectResult]], + get_stream: Callable[[Optional[int], Optional[int]], GetObjectStream], + max_gap_size: Union[int, float] = DEFAULT_MAX_GAP_SIZE, + ): + if not bucket: + raise ValueError("Bucket should be specified") + if not key: + raise ValueError("Key should be specified") + if not item_ranges: + raise ValueError("item_ranges must be a non-empty List[ItemRange] object") + if not isinstance(max_gap_size, (int, float)): + raise TypeError( + f"max_gap_size must be int or float, got {type(max_gap_size).__name__}" + ) + if max_gap_size < 0: + raise ValueError("max_gap_size must be non-negative") + + self._bucket = bucket + self._key = key + self._get_object_info = get_object_info + self._get_stream = get_stream + self._max_gap_size = max_gap_size + self._closed = False + + # Filter zero-length ranges + self._item_ranges: List[ItemRange] = [ + r for r in item_ranges if r.end != r.start + ] + if not self._item_ranges: + raise ValueError("No non-empty ranges to read (all ranges were length 0)") + + # Coalesce ranges into range groups + self._group_start_to_group: Dict[int, RangeGroup] = ( + {} + ) # Group lookup using group start offset. for first item in each grou; populated below + self._range_groups: List[RangeGroup] = self._validate_and_coalesce_ranges( + self._item_ranges, self._max_gap_size + ) + + # Stream state + self._stream: Optional[GetObjectStream] = None + self._stream_pos: int = -1 # position at head of stream - dummy int + self._leftover: Optional[memoryview] = None + + # Item buffer state + self._item_iter: Iterator[ItemRange] = iter(self._item_ranges) + self._current_item: ItemRange = next(self._item_iter) + self._current_item_buffer: Optional[_ItemViewBuffer] = None + + self._position: int = 0 + + @property + def bucket(self) -> str: + return self._bucket + + @property + def key(self) -> str: + return self._key + + @property + def closed(self) -> bool: + """ + Returns: + bool: Return whether the object is closed. + """ + return self._closed + + def _validate_and_coalesce_ranges( + self, + ranges: List[ItemRange], + max_gap_size: Union[int, float], + ) -> List[RangeGroup]: + """ + This method: + 1. Validates ranges are valid, sorted, and non-overlapping. + 2. Coalesces nearby ItemRanges within max_gap_size into RangeGroups. 
+ """ + if not ranges: + return [] + + groups: List[RangeGroup] = [] + items: List[ItemRange] = [ranges[0]] + + if ranges[0].start < 0 or ranges[0].end < ranges[0].start: + raise ValueError(f"Invalid range: {ranges[0].start}-{ranges[0].end}") + for r in ranges[1:]: + if r.end <= r.start: # Empty ranges filtered out in __init__ + raise ValueError(f"Invalid range: {r.start}-{r.end}") + if r.start < items[-1].end: + if r.start < items[-1].start: + raise ValueError( + f"Unsorted ranges: {items[-1].start}-{items[-1].end} and {r.start}-{r.end}" + ) + else: + raise ValueError( + f"Overlapping ranges: {items[-1].start}-{items[-1].end} and {r.start}-{r.end}" + ) + # Coalesce or create new group + if r.start - items[-1].end <= max_gap_size: + items.append(r) + else: + group = RangeGroup(items[0].start, items[-1].end, items) + groups.append(group) + self._group_start_to_group[items[0].start] = group + items = [r] + + final_group = RangeGroup(items[0].start, items[-1].end, items) + self._group_start_to_group[items[0].start] = final_group + groups.append(final_group) + return groups + + def _find_item_for_position(self, pos: int) -> ItemRange: + """Find which item contains the given position with validations.""" + + if pos < self._current_item.start: + raise ValueError( + f"{FIND_ITEM_ERROR_PREFIX}Position {pos} before current range " + f"{self._current_item.start}-{self._current_item.end}" + ) + + # Return item if position still in current item + if pos < self._current_item.end: + return self._current_item + + # Check next item + prev_item = self._current_item + try: + item = next(self._item_iter) + + if pos < item.start: + raise ValueError( + f"{FIND_ITEM_ERROR_PREFIX}Position {pos} in gap between ranges " + f"{prev_item.start}-{prev_item.end} and {item.start}-{item.end}" + ) + # Return item if position is in new item + if pos < item.end: + return item + else: + raise ValueError( + f"{FIND_ITEM_ERROR_PREFIX}Position {pos} beyond next range " + f"{item.start}-{item.end}" + ) + except StopIteration: + raise ValueError( + f"{FIND_ITEM_ERROR_PREFIX}Position {pos} beyond last range " + f"{prev_item.start}-{prev_item.end}" + ) + + def _get_stream_for_item(self, item: ItemRange) -> GetObjectStream: + """Find which RangeGroup contains the given position.""" + + # If item is the first item of a new group, create new stream + if item.start in self._group_start_to_group: + group = self._group_start_to_group[item.start] + self._stream = self._get_stream(group.start, group.end) + self._stream_pos = group.start + self._leftover = None + return self._stream + + # Otherwise, we're still in same group - reuse stream created when reading 1st item + if self._stream is None: + raise ValueError( + f"{FIND_ITEM_ERROR_PREFIX}Attempted to read item {item.start}-{item.end} " + f"without starting at the first item of its range-group" + ) + return self._stream + + def _get_item_buffer(self, item: ItemRange) -> _ItemViewBuffer: + """Load entire item into a memoryview-segment buffer from existing stream.""" + + buffer = _ItemViewBuffer() + + # Get stream from the right RangeGroup for start_pos + stream = self._get_stream_for_item(item) + pos = self._stream_pos # local copy + leftover = self._leftover # local copy + bytes_left = item.end - item.start + + # 1. 
Read from leftover bytes if available and needed + if leftover: + lv_len = len(leftover) + lv_end = pos + lv_len + + if pos <= item.start < lv_end: + # Item starts within leftover data + start = item.start - pos + available_bytes = lv_len - start + size = min(bytes_left, available_bytes) + end = start + size + + # Extract needed portion + buffer.append_view(leftover[start:end]) + bytes_left -= size + pos = item.start + size + leftover = leftover[end:] if end < lv_len else None + elif item.start >= lv_end: + # Item beyond leftover: advance pos to end of leftover + pos += lv_len + leftover = None + + # 2. Skip past unwanted data (due to coalescing) + while pos < item.start: + try: + chunk = memoryview(next(stream)) + except StopIteration: + break + + chunk_len = len(chunk) + + if pos + chunk_len <= item.start: + # Entire chunk before item start - skip completely + pos += chunk_len + continue + else: + # Partial Skip - slice off unwanted part first + skip_bytes = item.start - pos + chunk = chunk[skip_bytes:] + pos = item.start + chunk_len -= skip_bytes + + # Now process boundary chunk + if chunk_len <= bytes_left: + # Entire chunk needed - skip slicing + buffer.append_view(chunk) + bytes_left -= chunk_len + pos += chunk_len + leftover = None + else: + # Only part of chunk needed + buffer.append_view(chunk[:bytes_left]) + leftover = chunk[bytes_left:] + pos += bytes_left + bytes_left = 0 + break + + # 3. Take needed data for the item + while bytes_left > 0: + try: + chunk = memoryview(next(stream)) + except StopIteration: + break + + chunk_len = len(chunk) + + if chunk_len <= bytes_left: + # Entire chunk needed - skip slicing + buffer.append_view(chunk) + bytes_left -= chunk_len + pos += chunk_len + leftover = None + else: + # Only part of chunk needed + buffer.append_view(chunk[:bytes_left]) + leftover = chunk[bytes_left:] + pos += bytes_left + bytes_left = 0 + break + + self._stream_pos = pos + self._leftover = leftover + return buffer + + def read(self, size: Optional[int] = None) -> bytes: + """ + Read up to size bytes from the current position. + + Supports backward seeking within the current item buffer, but forward-only + access across DCP items (sequential item access required). + + Args: + size (int | None): how many bytes to read. + + Returns: + bytes: Bytes read from specified range. + + Raises: + TypeError: If size is not an integer. + ValueError: If position is outside valid DCP ranges, and if size is None or negative (full file reads not supported). + S3Exception: An error occurred accessing S3. + """ + if size is None: + raise ValueError("Size cannot be None; full read not supported") + if not isinstance(size, int): + raise TypeError(f"argument should be integer or None, not {type(size)!r}") + if size < 0: + raise ValueError("Size cannot be negative; full read not supported") + if size == 0: + return b"" + + item = self._find_item_for_position(self._position) + + if item is not self._current_item or self._current_item_buffer is None: + self._current_item = item + self._current_item_buffer = self._get_item_buffer(item) + + local_pos = self._position - item.start + self._current_item_buffer.seek(local_pos) + data = self._current_item_buffer.read(size) + + self._position += len(data) + return data + + def readinto(self, buf) -> int: + """ + Read up to len(buf) bytes into a pre-allocated, writable bytes-like object buf. + Return the number of bytes read. If no bytes are available, zero is returned. 
+ + Args: + buf : writable bytes-like object + + Returns: + int : number of bytes read or zero, if no bytes available + + Raises: + ValueError: If position is outside valid DCP ranges. + TypeError: If buf is not writable. + S3Exception: An error occurred accessing S3. + """ + item = self._find_item_for_position(self._position) + + if item is not self._current_item or self._current_item_buffer is None: + self._current_item = item + self._current_item_buffer = self._get_item_buffer(item) + + local_pos = self._position - item.start + self._current_item_buffer.seek(local_pos) + bytes_read = self._current_item_buffer.readinto(buf) + + self._position += bytes_read + return bytes_read + + def seek(self, offset: int, whence: int = SEEK_SET, /) -> int: + """ + Change position within DCP ranges, interpreted relative to whence. + + Supports arbitrary seeking within current item buffer, but only forward + sequential access across DCP items (cannot seek back to previous items). + + Args: + offset (int): How many bytes to seek relative to whence. + whence (int): One of SEEK_SET, and SEEK_CUR. SEEK_END not supported. Default: SEEK_SET. + + Returns: + int: Current position of the stream + + Raises: + TypeError: If whence is not SEEK_SET or SEEK_CUR. + ValueError: If seeking to negative position or accessing previous items. + TypeError: If whence is not SEEK_SET or SEEK_CUR. + """ + if not isinstance(offset, int): + raise TypeError(f"integer argument expected, got {type(offset)!r}") + + if whence == SEEK_SET: + self._position = offset + elif whence == SEEK_CUR: + self._position += offset + else: + raise ValueError("whence must be SEEK_CUR or SEEK_SET integers") + + if self._position < 0: + raise ValueError(f"negative seek value {self._position}") + + return self._position + + def tell(self) -> int: + """ + Returns: + int: Current absolute position in the object. + """ + return self._position + + def close(self) -> None: + """ + Close the stream and release resources. + """ + if not self._closed: + self._closed = True + self._stream = None + self._leftover = None + if self._current_item_buffer: + self._current_item_buffer.close() + self._current_item_buffer = None diff --git a/s3torchconnector/src/s3torchconnector/s3reader/protocol.py b/s3torchconnector/src/s3torchconnector/s3reader/protocol.py index 44cfa0e0..d1b8f0d9 100644 --- a/s3torchconnector/src/s3torchconnector/s3reader/protocol.py +++ b/s3torchconnector/src/s3torchconnector/s3reader/protocol.py @@ -1,14 +1,29 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # // SPDX-License-Identifier: BSD -from typing import Protocol, Callable, Optional, Union +from typing import ( + TYPE_CHECKING, + Protocol, + Callable, + Optional, + Union, + List, + Dict, + runtime_checkable, +) from .s3reader import S3Reader +from .dcp_optimized import ItemRange from s3torchconnectorclient._mountpoint_s3_client import ( ObjectInfo, GetObjectStream, HeadObjectResult, ) +if TYPE_CHECKING: + from torch.distributed.checkpoint.planner import ReadItem + from torch.distributed.checkpoint.metadata import MetadataIndex + from torch.distributed.checkpoint.filesystem import _StorageInfo + class GetStreamCallable(Protocol): def __call__( @@ -16,6 +31,7 @@ def __call__( ) -> GetObjectStream: ... +@runtime_checkable class S3ReaderConstructorProtocol(Protocol): def __call__( self, @@ -24,3 +40,22 @@ def __call__( get_object_info: Callable[[], Union[ObjectInfo, HeadObjectResult]], get_stream: GetStreamCallable, ) -> S3Reader: ... 
+ + +@runtime_checkable +class DCPS3ReaderConstructorProtocol(Protocol): + _item_ranges_by_file: Dict[str, List[ItemRange]] + + def __call__( + self, + bucket: str, + key: str, + get_object_info: Callable[[], Union[ObjectInfo, HeadObjectResult]], + get_stream: GetStreamCallable, + ) -> S3Reader: ... + + def set_item_ranges_by_file( + self, + plan_items: "List[ReadItem]", + storage_data: "Dict[MetadataIndex, _StorageInfo]", + ) -> None: ... diff --git a/s3torchconnector/tst/conftest.py b/s3torchconnector/tst/conftest.py index 8784b93e..583e19c4 100644 --- a/s3torchconnector/tst/conftest.py +++ b/s3torchconnector/tst/conftest.py @@ -6,22 +6,46 @@ from s3torchconnector.s3reader import ( S3ReaderConstructor, S3ReaderConstructorProtocol, + SequentialS3Reader, + RangedS3Reader, + DCPOptimizedS3Reader, ) +READER_TYPE_STRING_TO_CLASS = { + "sequential": SequentialS3Reader, + "range_based": RangedS3Reader, + "dcp_optimized": DCPOptimizedS3Reader, +} + # Shared reader constructors for parametrized tests # TODO: use this variable in test_distributed_training.py and test_multiprocess_dataloading.py READER_CONSTRUCTORS = [ - S3ReaderConstructor.sequential(), # Sequential Reader - S3ReaderConstructor.range_based(), # Default range-based reader, with buffer - S3ReaderConstructor.range_based(buffer_size=0), # range-based reader, no buffer + ("sequential", S3ReaderConstructor.sequential()), + ("range_based_with_buffer", S3ReaderConstructor.range_based()), + ("range_based_no_buffer", S3ReaderConstructor.range_based(buffer_size=0)), +] + +# Include dcp_optimized for DCP tests +DCP_READER_CONSTRUCTORS = READER_CONSTRUCTORS + [ + ("dcp_optimized", S3ReaderConstructor.dcp_optimized()), ] @pytest.fixture( - params=READER_CONSTRUCTORS, - ids=["sequential", "range_based_with_buffer", "range_based_no_buffer"], + params=[constructor for _, constructor in READER_CONSTRUCTORS], + ids=[name for name, _ in READER_CONSTRUCTORS], scope="module", ) def reader_constructor(request) -> S3ReaderConstructorProtocol: """Provide reader constructor (partial(S3Reader)) instances for all supported reader types.""" return request.param + + +@pytest.fixture( + params=[constructor for _, constructor in DCP_READER_CONSTRUCTORS], + ids=[name for name, _ in DCP_READER_CONSTRUCTORS], + scope="module", +) +def dcp_reader_constructor(request) -> S3ReaderConstructorProtocol: + """Provide reader constructor instances for DCP tests including dcp_optimized.""" + return request.param diff --git a/s3torchconnector/tst/e2e/dcp/test_e2e_s3_file_system.py b/s3torchconnector/tst/e2e/dcp/test_e2e_s3_file_system.py index ade93a70..052892b7 100644 --- a/s3torchconnector/tst/e2e/dcp/test_e2e_s3_file_system.py +++ b/s3torchconnector/tst/e2e/dcp/test_e2e_s3_file_system.py @@ -212,7 +212,7 @@ def test_dcp_when_multi_process( tensor_dimensions, thread_count, port_offset, - reader_constructor, + dcp_reader_constructor, ): multi_process_dcp_save_load( world_size=3, @@ -221,7 +221,7 @@ def test_dcp_when_multi_process( tensor_dimensions=tensor_dimensions, port_offset=port_offset, prefix_strategy=None, - reader_constructor=reader_constructor, + reader_constructor=dcp_reader_constructor, ) diff --git a/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py b/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py index d5e62af6..ea43da9b 100644 --- a/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py +++ b/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py @@ -10,7 +10,8 @@ from s3torchconnector import S3ReaderConstructor from 
s3torchconnector.dcp import S3StorageWriter, S3StorageReader -from s3torchconnector.s3reader.sequential import SequentialS3Reader +from s3torchconnector.s3reader import SequentialS3Reader, DCPOptimizedS3Reader +from s3torchconnector._s3client import S3Client SIMPLE_MODEL = torch.nn.Sequential( @@ -39,7 +40,16 @@ def __init__(self): @pytest.mark.parametrize("model", [SIMPLE_MODEL, LARGER_MODEL]) -def test_dcp_load_reads_tensors_in_sequential_order(checkpoint_directory, model): +@pytest.mark.parametrize( + "reader_class,reader_constructor", + [ + (SequentialS3Reader, S3ReaderConstructor.sequential()), + (DCPOptimizedS3Reader, S3ReaderConstructor.dcp_optimized()), + ], +) +def test_dcp_load_reads_tensors_in_sequential_order( + checkpoint_directory, model, reader_class, reader_constructor +): """ Test that prepare_local_plan allows dcp.load() to read items in offset order. @@ -61,8 +71,7 @@ def test_dcp_load_reads_tensors_in_sequential_order(checkpoint_directory, model) dcp.save(state_dict, storage_writer=storage_writer) read_positions = [] - - original_read = SequentialS3Reader.read + original_read = reader_class.read def track_reads(self, size=None): if not self.key.endswith(".metadata"): @@ -70,12 +79,10 @@ def track_reads(self, size=None): return original_read(self, size) # Load with position tracking on read() (called at the start of each torch.load()) - with patch.object(SequentialS3Reader, "read", track_reads): + with patch.object(reader_class, "read", track_reads): loaded_state_dict = {k: torch.empty_like(v) for k, v in state_dict.items()} storage_reader = S3StorageReader( - region=region, - path=s3_uri, - reader_constructor=S3ReaderConstructor.sequential(), + region=region, path=s3_uri, reader_constructor=reader_constructor ) dcp.load(loaded_state_dict, storage_reader=storage_reader) @@ -89,3 +96,90 @@ def track_reads(self, size=None): assert loaded_state_dict.keys() == state_dict.keys() for key in state_dict: assert torch.equal(loaded_state_dict[key], state_dict[key]) + + +@pytest.mark.parametrize("model", [SIMPLE_MODEL, LARGER_MODEL]) +@pytest.mark.parametrize( + "max_gap_size,load_filter,filter_name,expected_streams", + [ + # Full load - all tensors are consecutive, so always 1 stream + (0, lambda k: True, "Full", 1), + (float("inf"), lambda k: True, "Full", 1), + # Weights only - scattered by biases, so stream count depends on max_gap_size + (0, lambda k: k.endswith(".weight"), "Weights", 3), + (float("inf"), lambda k: k.endswith(".weight"), "Weights", 1), + # Layer 2 only - their bias+weight tensors are consecutive, so always 1 stream + (0, lambda k: "2." in k, "Layer 2", 1), + (float("inf"), lambda k: "2." in k, "Layer 2", 1), + ], +) +def test_dcp_optimized_loading_patterns( + checkpoint_directory, + model, + max_gap_size, + load_filter, + filter_name, + expected_streams, +): + """Test DCPOptimized reader with full and partial loading patterns and different max_gap_size. + + Validates that full loads use 1 stream, and partial load stream usage depends + on max_gap_size and whether tensors are consecutive / neighbours. 
+ + SIMPLE_MODEL tensors: ['0.bias', '0.weight', '1.bias', '1.weight', '2.bias', '2.weight'] + LARGER_MODEL tensors: ['linear_relu_stack.0.bias', 'linear_relu_stack.0.weight', 'linear_relu_stack.2.bias', + 'linear_relu_stack.2.weight', 'linear_relu_stack.4.bias', 'linear_relu_stack.4.weight'] + """ + region = checkpoint_directory.region + s3_uri = checkpoint_directory.s3_uri + + state_dict = model.state_dict() + dcp.save(state_dict, storage_writer=S3StorageWriter(region, s3_uri, overwrite=True)) + + # Print model structure (once per model) + all_keys = list(state_dict.keys()) + if max_gap_size == 0 and filter_name == "Full": + print(f"\nTensors: {sorted(all_keys)}") + + # Apply filter for partial load + filtered_keys = [k for k in all_keys if load_filter(k)] + excluded_keys = [k for k in all_keys if not load_filter(k)] + assert filtered_keys, f"No keys match {filter_name} filter for this model" + filtered_dict = {k: torch.empty_like(state_dict[k]) for k in filtered_keys} + + # Load full / partial checkpoint with stream call tracker + stream_calls = [] + original_get_object_stream = S3Client._get_object_stream + + def track_get_object_stream(self, bucket, key, start=None, end=None): + if not key.endswith(".metadata"): + stream_calls.append((start, end)) + return original_get_object_stream(self, bucket, key, start=start, end=end) + + with patch.object(S3Client, "_get_object_stream", track_get_object_stream): + reader_constructor = S3ReaderConstructor.dcp_optimized(max_gap_size) + reader = S3StorageReader(region, s3_uri, reader_constructor=reader_constructor) + dcp.load(filtered_dict, storage_reader=reader) + + # Verify correctness + assert len(filtered_dict) == len(filtered_keys) + for k, v in filtered_dict.items(): + assert torch.equal(v, state_dict[k]) + assert load_filter(k) + + # Verify excluded keys are not loaded + for k in excluded_keys: + assert k not in filtered_dict, f"Key {k} should not be in {filter_name} load" + + # Verify expected stream count + assert len(stream_calls) == expected_streams + if len(stream_calls) > 1: + for i in range(1, len(stream_calls)): + assert stream_calls[i][0] >= stream_calls[i - 1][1] + assert stream_calls[i][0] - stream_calls[i - 1][1] >= max_gap_size + + # Print number of stream calls + coalesce = "no coalesce" if max_gap_size == 0 else "full coalesce" + print( + f"{filter_name} load, {coalesce}: {len(stream_calls)} streams, {len(filtered_keys)} tensors" + ) diff --git a/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py b/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py index f46dd56f..9d2ef668 100644 --- a/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py +++ b/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py @@ -1,13 +1,14 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# // SPDX-License-Identifier: BSD -from unittest.mock import Mock +from unittest.mock import Mock, patch from hypothesis import given from hypothesis.strategies import composite, integers, lists from torch.distributed.checkpoint.planner import LoadPlan, ReadItem from s3torchconnector.dcp import S3StorageReader +from s3torchconnector.s3reader import S3ReaderConstructor, ItemRange TEST_REGION = "eu-east-1" TEST_PATH = "s3://test-bucket/test-checkpoint/" @@ -15,15 +16,18 @@ @composite def load_plan_with_offsets(draw): - """Generate LoadPlan with random offsets.""" + """Generate LoadPlan with random offsets and lengths.""" offsets = draw(lists(integers(0, 10_000_000), min_size=1, max_size=10_000)) + lengths = draw(lists(integers(1, 10_000_000), min_size=1, max_size=10_000)) storage_data = {} items = [] - for i, offset in enumerate(offsets): + for i, (offset, length) in enumerate(zip(offsets, lengths)): storage_index = f"item{i}" - storage_data[storage_index] = Mock(offset=offset) + storage_data[storage_index] = Mock( + relative_path=f"__{i%8}_0.distcp", offset=offset, length=length + ) items.append(Mock(spec=ReadItem, storage_index=storage_index)) return LoadPlan(items), storage_data @@ -41,7 +45,7 @@ def test_s3storage_reader_prepare_local_plan_empty(): @given(load_plan_with_offsets()) -def test_s3storage_reader_prepare_local_plan(loadplan_and_storagedata): +def test_s3storage_reader_prepare_local_plan_sorts_items(loadplan_and_storagedata): """Test prepare local plan sorts items by storage_data offset.""" load_plan, storage_data = loadplan_and_storagedata @@ -62,3 +66,64 @@ def test_s3storage_reader_prepare_local_plan(loadplan_and_storagedata): # Verify Load Ordering keeps items the same assert len(sorted_plan.items) == len(load_plan.items) assert set(sorted_plan.items) == set(load_plan.items) + + +@given(load_plan_with_offsets()) +def test_s3storage_reader_prepare_local_plan_calls_range_injection( + loadplan_and_storagedata, +): + """Test prepare_local_plan calls set_item_ranges_by_file() for DCPS3ReaderConstructor.""" + load_plan, storage_data = loadplan_and_storagedata + + constructor = S3ReaderConstructor.dcp_optimized() + s3_storage_reader = S3StorageReader( + TEST_REGION, TEST_PATH, reader_constructor=constructor + ) + s3_storage_reader.storage_data = storage_data + + with patch.object(constructor, "set_item_ranges_by_file") as mock_method: + s3_storage_reader.prepare_local_plan(load_plan) + mock_method.assert_called_once_with(load_plan.items, storage_data) + + +@given(load_plan_with_offsets()) +def test_s3storage_reader_prepare_local_plan_injects_ranges_correctly( + loadplan_and_storagedata, +): + """Test prepare_local_plan correctly injects ranges into DCPS3ReaderConstructor.""" + load_plan, storage_data = loadplan_and_storagedata + + constructor = S3ReaderConstructor.dcp_optimized() + s3_storage_reader = S3StorageReader( + TEST_REGION, TEST_PATH, reader_constructor=constructor + ) + s3_storage_reader.storage_data = storage_data + s3_storage_reader.prepare_local_plan(load_plan) + + for item in load_plan.items: + storage_info = storage_data[item.storage_index] + offset, length = storage_info.offset, storage_info.length + relative_path = storage_info.relative_path + + expected_range = ItemRange(offset, offset + length) + assert expected_range in constructor._item_ranges_by_file[relative_path] + + +@given(load_plan_with_offsets()) +def test_s3storage_reader_prepare_local_plan_no_injection_for_other_constructors( + reader_constructor, + loadplan_and_storagedata, +): + """Test 
prepare_local_plan does NOT inject ranges for non-DCPOptimized reader constructors.""" + load_plan, storage_data = loadplan_and_storagedata + + s3_storage_reader = S3StorageReader( + TEST_REGION, TEST_PATH, reader_constructor=reader_constructor + ) + s3_storage_reader.storage_data = storage_data + + result = s3_storage_reader.prepare_local_plan(load_plan) + assert len(result.items) == len(load_plan.items) + + # Verify no injection occurred - regular constructors don't have _item_ranges_by_file + assert not hasattr(reader_constructor, "_item_ranges_by_file") diff --git a/s3torchconnector/tst/unit/test_s3reader_constructor.py b/s3torchconnector/tst/unit/test_s3reader_constructor.py index 7524ca8e..ba4f7658 100644 --- a/s3torchconnector/tst/unit/test_s3reader_constructor.py +++ b/s3torchconnector/tst/unit/test_s3reader_constructor.py @@ -2,26 +2,46 @@ # // SPDX-License-Identifier: BSD import pytest +import sys +from unittest.mock import Mock from s3torchconnector import S3ReaderConstructor -from s3torchconnector.s3reader import SequentialS3Reader, RangedS3Reader +from s3torchconnector.s3reader import ( + SequentialS3Reader, + RangedS3Reader, + DCPOptimizedS3Reader, + DCPOptimizedConstructor, +) +from s3torchconnector.s3reader.dcp_optimized import ItemRange, DEFAULT_MAX_GAP_SIZE from s3torchconnector.s3reader.ranged import DEFAULT_BUFFER_SIZE -TEST_BUCKET = "test-bucket" -TEST_KEY = "test-key" +from torch.distributed.checkpoint.planner import ReadItem +from torch.distributed.checkpoint.metadata import MetadataIndex +from torch.distributed.checkpoint.filesystem import _StorageInfo + +from .test_s3reader_common import TEST_BUCKET, TEST_KEY, MOCK_OBJECT_INFO, MOCK_STREAM + +# ---------- basic constructor tests ----------- + + +def test_s3readerconstructor_default_constructor(): + """Test default constructor returns sequential reader""" + constructor = S3ReaderConstructor.default() + s3reader = constructor(TEST_BUCKET, TEST_KEY, MOCK_OBJECT_INFO, MOCK_STREAM) + assert isinstance(s3reader, SequentialS3Reader) def test_s3readerconstructor_sequential_constructor(): """Test sequential reader construction""" constructor = S3ReaderConstructor.sequential() - s3reader = constructor(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter([])) + s3reader = constructor(TEST_BUCKET, TEST_KEY, MOCK_OBJECT_INFO, MOCK_STREAM) assert isinstance(s3reader, SequentialS3Reader) def test_s3readerconstructor_range_based_constructor(): """Test range-based reader construction""" constructor = S3ReaderConstructor.range_based() - s3reader = constructor(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter([])) + s3reader = constructor(TEST_BUCKET, TEST_KEY, MOCK_OBJECT_INFO, MOCK_STREAM) assert isinstance(s3reader, RangedS3Reader) @@ -38,32 +58,255 @@ def test_s3readerconstructor_range_based_constructor_buffer_configurations( ): """Test range-based reader construction with different buffer configurations""" constructor = S3ReaderConstructor.range_based(buffer_size=buffer_size) - s3reader = constructor(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter([])) + s3reader = constructor(TEST_BUCKET, TEST_KEY, MOCK_OBJECT_INFO, MOCK_STREAM) assert isinstance(s3reader, RangedS3Reader) assert s3reader._buffer_size == expected_buffer_size assert s3reader._enable_buffering is expected_enable_buffering -def test_s3readerconstructor_default_constructor(): - """Test default constructor returns sequential reader""" - constructor = S3ReaderConstructor.default() - s3reader = constructor(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter([])) +# 
---------- dcp_optimized constructor tests ---------- + + +def test_s3readerconstructor_dcp_optimized_constructor(): + """Test DCP optimized reader construction""" + constructor = S3ReaderConstructor.dcp_optimized() + assert isinstance(constructor, DCPOptimizedConstructor) + + # Without ranges, should fallback to SequentialS3Reader + s3reader = constructor(TEST_BUCKET, TEST_KEY, MOCK_OBJECT_INFO, MOCK_STREAM) assert isinstance(s3reader, SequentialS3Reader) + # With ranges, should create DCPOptimizedS3Reader + constructor._item_ranges_by_file = {TEST_KEY: [ItemRange(0, 100)]} + s3reader = constructor(TEST_BUCKET, TEST_KEY, MOCK_OBJECT_INFO, MOCK_STREAM) + assert isinstance(s3reader, DCPOptimizedS3Reader) + + +# * max_gap_size tests + + +def test_dcp_optimized_constructor_default_max_gap_size(): + """Test max_gap_size parameter defaults and propagation""" + + constructor = S3ReaderConstructor.dcp_optimized() + assert isinstance(constructor, DCPOptimizedConstructor) + assert constructor._max_gap_size == DEFAULT_MAX_GAP_SIZE + + constructor._item_ranges_by_file = {TEST_KEY: [ItemRange(0, 100)]} + s3reader = constructor(TEST_BUCKET, TEST_KEY, MOCK_OBJECT_INFO, MOCK_STREAM) + assert isinstance(s3reader, DCPOptimizedS3Reader) + assert s3reader._max_gap_size == DEFAULT_MAX_GAP_SIZE -def test_s3readerconstructor_get_reader_type_string(): + +@pytest.mark.parametrize("max_gap_size", [0, 8 * 1024 * 1024, 1024 * 1024 * 1024]) +def test_dcp_optimized_constructor_custom_max_gap_size(max_gap_size): + """Test max_gap_size parameter defaults and propagation""" + + constructor = S3ReaderConstructor.dcp_optimized(max_gap_size=max_gap_size) + assert isinstance(constructor, DCPOptimizedConstructor) + assert constructor._max_gap_size == max_gap_size + + constructor._item_ranges_by_file = {TEST_KEY: [ItemRange(0, 100)]} + s3reader = constructor(TEST_BUCKET, TEST_KEY, MOCK_OBJECT_INFO, MOCK_STREAM) + assert isinstance(s3reader, DCPOptimizedS3Reader) + assert s3reader._max_gap_size == max_gap_size + + +@pytest.mark.parametrize("max_gap_size", [sys.maxsize, float("inf"), 0.5]) +def test_dcp_optimized_constructor_max_gap_size_edge_cases(max_gap_size): + """Test max_gap_size parameter defaults and propagation""" + + constructor = S3ReaderConstructor.dcp_optimized(max_gap_size=max_gap_size) + assert isinstance(constructor, DCPOptimizedConstructor) + assert constructor._max_gap_size == max_gap_size + + constructor._item_ranges_by_file = {TEST_KEY: [ItemRange(0, 100)]} + s3reader = constructor(TEST_BUCKET, TEST_KEY, MOCK_OBJECT_INFO, MOCK_STREAM) + assert isinstance(s3reader, DCPOptimizedS3Reader) + assert s3reader._max_gap_size == max_gap_size + + +@pytest.mark.parametrize( + "max_gap_size,expected_error", + [ + (-1, ValueError), + ("1", TypeError), + ([1], TypeError), + (None, TypeError), + ], +) +def test_dcp_optimized_constructor_invalid_max_gap_size(max_gap_size, expected_error): + """Test parameter validation for max_gap_size""" + with pytest.raises(expected_error): + S3ReaderConstructor.dcp_optimized(max_gap_size) + + +# * set_item_ranges_by_file() + + +def test_dcp_optimized_constructor_set_item_ranges_by_file_empty_file(): + """Test empty plan is allowed (disallowed during call)""" + constructor = S3ReaderConstructor.dcp_optimized() + + constructor.set_item_ranges_by_file([], {}) + assert constructor._item_ranges_by_file == {} + + +@pytest.mark.parametrize( + "relative_path", + [ + ("__0_0.distcp"), + ("nested/path/to/file/__0_0.distcp"), + ("prefix_strategy/shard1/epoch_5/__0_0.distcp"), + ], +) +def 
test_dcp_optimized_constructor_set_item_ranges_by_file_filename_extraction( + relative_path, +): + """Test filename extraction from various path formats""" + constructor = S3ReaderConstructor.dcp_optimized() + + metadata_index = MetadataIndex("idx") + read_item = Mock(spec=ReadItem, storage_index=metadata_index) + storage_data = { + metadata_index: _StorageInfo(relative_path=relative_path, offset=0, length=100) + } + + constructor.set_item_ranges_by_file([read_item], storage_data) + assert relative_path in constructor._item_ranges_by_file + + +def test_dcp_optimized_constructor_set_item_ranges_by_file_multiple_items(): + """Test set_item_ranges_by_file with different ReadItems""" + constructor = S3ReaderConstructor.dcp_optimized() + + metadata_indices = [MetadataIndex(f"idx{i}") for i in range(3)] + read_items = [ + Mock(spec=ReadItem, storage_index=metadata_indices[i]) for i in range(3) + ] + storage_data = { + metadata_indices[0]: _StorageInfo( + relative_path="file1.distcp", offset=0, length=100 + ), + metadata_indices[1]: _StorageInfo( + relative_path="file1.distcp", offset=100, length=50 + ), + metadata_indices[2]: _StorageInfo( + relative_path="file2.distcp", offset=0, length=200 + ), + } + + constructor.set_item_ranges_by_file(read_items, storage_data) # type: ignore + + assert "file1.distcp" in constructor._item_ranges_by_file + assert "file2.distcp" in constructor._item_ranges_by_file + assert len(constructor._item_ranges_by_file["file1.distcp"]) == 2 + assert len(constructor._item_ranges_by_file["file2.distcp"]) == 1 + assert constructor._item_ranges_by_file["file1.distcp"][0] == ItemRange(0, 100) + assert constructor._item_ranges_by_file["file1.distcp"][1] == ItemRange(100, 150) + + +def test_dcp_optimized_constructor_set_item_ranges_by_file_multiple_calls(): + """Test constructor handles multiple calls to set_item_ranges_by_file""" + constructor = S3ReaderConstructor.dcp_optimized() + + # First call + metadata_index1 = MetadataIndex("idx1") + read_item1 = Mock(spec=ReadItem, storage_index=metadata_index1) + storage_data1 = { + metadata_index1: _StorageInfo( + relative_path="file1.distcp", offset=0, length=100 + ) + } + constructor.set_item_ranges_by_file([read_item1], storage_data1) + + # Second call should replace previous ranges + metadata_index2 = MetadataIndex("idx2") + read_item2 = Mock(spec=ReadItem, storage_index=metadata_index2) + storage_data2 = { + metadata_index2: _StorageInfo( + relative_path="file2.distcp", offset=0, length=200 + ) + } + constructor.set_item_ranges_by_file([read_item2], storage_data2) + + # Only second call's data should remain + assert "file1.distcp" not in constructor._item_ranges_by_file + assert "file2.distcp" in constructor._item_ranges_by_file + assert len(constructor._item_ranges_by_file["file2.distcp"]) == 1 + assert constructor._item_ranges_by_file["file2.distcp"][0] == ItemRange(0, 200) + + +# * __call__ tests + + +def test_dcp_optimized_constructor_call_with_invalid_ranges(): + """Test dcp_optimized constructor __call__ falls back to SequentialS3Reader when no item ranges for the file""" + constructor = S3ReaderConstructor.dcp_optimized() + assert isinstance(constructor, DCPOptimizedConstructor) + + # Test with no ranges - should fallback + constructor._item_ranges_by_file = {} + s3reader = constructor(TEST_BUCKET, TEST_KEY, MOCK_OBJECT_INFO, MOCK_STREAM) + assert isinstance(s3reader, SequentialS3Reader) + + # Test with ranges for different file - should fallback + constructor._item_ranges_by_file = {"not_test_key.distcp": [ItemRange(0, 
100)]} + s3reader = constructor(TEST_BUCKET, TEST_KEY, MOCK_OBJECT_INFO, MOCK_STREAM) + assert isinstance(s3reader, SequentialS3Reader) + + # Test with empty range - should raise error + constructor._item_ranges_by_file = {TEST_KEY: []} + with pytest.raises(ValueError): + s3reader = constructor(TEST_BUCKET, TEST_KEY, MOCK_OBJECT_INFO, MOCK_STREAM) + + +def test_dcp_optimized_constructor_call_relative_path_matching(): + """Test __call__ method with relative path matching""" + constructor = S3ReaderConstructor.dcp_optimized() + constructor._item_ranges_by_file = { + "shard1/epoch_5/__0_0.distcp": [ItemRange(0, 100)] + } + + # Should match when key ends with relative_path + key = "base/shard1/epoch_5/__0_0.distcp" + s3reader = constructor(TEST_BUCKET, key, MOCK_OBJECT_INFO, MOCK_STREAM) + assert isinstance(s3reader, DCPOptimizedS3Reader) + assert s3reader.key == key + + key = "shard1/epoch_5/__0_0.distcp" + # Should also match shorter key that ends with relative_path + s3reader = constructor(TEST_BUCKET, key, MOCK_OBJECT_INFO, MOCK_STREAM) + assert isinstance(s3reader, DCPOptimizedS3Reader) + assert s3reader.key == key + + # No match - should fallback to SequentialS3Reader + key = "different/path/__1_0.distcp" + s3reader = constructor(TEST_BUCKET, key, MOCK_OBJECT_INFO, MOCK_STREAM) + assert isinstance(s3reader, SequentialS3Reader) + assert s3reader.key == key + + # Same filename, different path - should fallback to SequentialS3Reader + key = "same/filename/different/path/__0_0.distcp" + s3reader = constructor(TEST_BUCKET, key, MOCK_OBJECT_INFO, MOCK_STREAM) + assert isinstance(s3reader, SequentialS3Reader) + assert s3reader.key == key + + +# ---------- + + +@pytest.mark.parametrize( + "constructor, expected_type", + [ + (S3ReaderConstructor.sequential(), "sequential"), + (S3ReaderConstructor.range_based(), "range_based"), + (S3ReaderConstructor.dcp_optimized(), "dcp_optimized"), + (None, "sequential"), + (S3ReaderConstructor.default(), "sequential"), + ], +) +def test_s3readerconstructor_get_reader_type_string(constructor, expected_type): """Test reader type string generation""" - assert ( - S3ReaderConstructor.get_reader_type_string(S3ReaderConstructor.sequential()) - == "sequential" - ) - assert ( - S3ReaderConstructor.get_reader_type_string(S3ReaderConstructor.range_based()) - == "range_based" - ) - assert S3ReaderConstructor.get_reader_type_string(None) == "sequential" - assert ( - S3ReaderConstructor.get_reader_type_string(S3ReaderConstructor.default()) - == "sequential" - ) + assert S3ReaderConstructor.get_reader_type_string(constructor) == expected_type diff --git a/s3torchconnector/tst/unit/test_s3reader_dcp_optimized.py b/s3torchconnector/tst/unit/test_s3reader_dcp_optimized.py new file mode 100644 index 00000000..75d623b1 --- /dev/null +++ b/s3torchconnector/tst/unit/test_s3reader_dcp_optimized.py @@ -0,0 +1,763 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# // SPDX-License-Identifier: BSD + +import re +import sys +from io import BytesIO, SEEK_SET, SEEK_CUR, SEEK_END +from typing import List, Tuple, Optional + +import pytest +from hypothesis import given, assume +from hypothesis.strategies import integers, composite + +from s3torchconnector.s3reader.dcp_optimized import ( + DCPOptimizedS3Reader, + ItemRange, + RangeGroup, + _ItemViewBuffer, + DEFAULT_MAX_GAP_SIZE, + FIND_ITEM_ERROR_PREFIX, +) +from .test_s3reader_common import ( + TEST_BUCKET, + TEST_KEY, + MOCK_OBJECT_INFO, + MOCK_STREAM, + create_object_info_getter, + create_stream_getter, + bytestream_and_positions, +) + + +def create_dcp_s3reader( + ranges: Optional[List[ItemRange]] = None, + stream_data: Optional[List[bytes]] = None, + max_gap_size: int = DEFAULT_MAX_GAP_SIZE, + chunk_size: int = 5, +): + """Create DCPOptimizedS3Reader with mock stream data""" + if ranges is None: + ranges = [ItemRange(0, 10)] + if stream_data is None: + stream_data = [b"0123456789"] + + return DCPOptimizedS3Reader( + TEST_BUCKET, + TEST_KEY, + ranges, + create_object_info_getter(stream_data), + create_stream_getter(stream_data, chunk_size=chunk_size), # type: ignore + max_gap_size=max_gap_size, + ) + + +@composite +def dcp_ranges_and_stream(draw): + """Generate sorted and non-overlapping item ranges with corresponding stream data""" + num_ranges = draw(integers(min_value=10, max_value=100)) + ranges = [] + current_pos = 0 + + for _ in range(num_ranges): + # Random gap size and random length + start = current_pos + draw(integers(min_value=0, max_value=2000)) + length = draw(integers(min_value=100, max_value=1000)) + end = start + length + ranges.append(ItemRange(start, end)) + current_pos = end + + stream_data = [b"x" * current_pos] + + return ranges, stream_data + + +class TestItemViewBuffer: + """ItemViewBuffer Tests""" + + def test_append_view_and_size_calculation(self): + """Test append_view and size calculation correctness""" + buffer = _ItemViewBuffer() + assert buffer._size == 0 + assert buffer.tell() == 0 + + buffer.append_view(memoryview(b"Hello")) + assert buffer._size == 5 + assert len(buffer._segments) == 1 + assert buffer._segments[0] == b"Hello" + assert buffer._lengths[0] == 5 + assert buffer._offsets[0] == 0 + + buffer.append_view(memoryview(b" World!")) + assert buffer._size == 12 + assert len(buffer._segments) == 2 + assert buffer._segments[1] == b" World!" 
+ assert buffer._lengths[1] == 7 + assert buffer._offsets[1] == 5 + + # Empty views should be ignored + buffer.append_view(memoryview(b"")) + assert buffer._size == 12 + assert len(buffer._segments) == 2 + + @pytest.mark.parametrize( + "start_pos, offset, whence, expected_pos", + [ + # SEEK_SET + (0, 5, SEEK_SET, 5), + (3, 0, SEEK_SET, 0), + (5, 15, SEEK_SET, 15), # Allow seeking past buffer end + # SEEK_CUR + (3, 2, SEEK_CUR, 5), + (5, -2, SEEK_CUR, 3), + (5, 10, SEEK_CUR, 15), + # SEEK_END + (0, -3, SEEK_END, 7), + (5, 0, SEEK_END, 10), + (5, 2, SEEK_END, 12), + ], + ) + def test_buffer_seek_all_modes(self, start_pos, offset, whence, expected_pos): + buffer = _ItemViewBuffer() + buffer.append_view(memoryview(b"0123456789")) + buffer.seek(start_pos) + + pos = buffer.seek(offset, whence) + assert pos == expected_pos + assert buffer.tell() == expected_pos + + @pytest.mark.parametrize( + "offset, whence, expected_error", + [ + # Negative offsets + (-1, SEEK_SET, AssertionError), + (-1, SEEK_CUR, AssertionError), + (-15, SEEK_END, AssertionError), + # Invalid whence + (5, 3, ValueError), + (5, None, ValueError), + (5, "SEEK_SET", ValueError), + ], + ) + def test_buffer_invalid_seek(self, offset, whence, expected_error): + buffer = _ItemViewBuffer() + buffer.append_view(memoryview(b"0123456789")) + + with pytest.raises(expected_error): + buffer.seek(offset, whence) + + @pytest.mark.parametrize( + "start, size, expected_data, expected_pos", + [ + # Zero reads + (0, 0, b"", 0), + (5, 0, b"", 5), + (10, 0, b"", 10), + # Normal reads + (0, 4, b"0123", 4), + (2, 3, b"234", 5), + (5, 5, b"56789", 10), + (0, 10, b"0123456789", 10), + # Past EOF reads + (8, 5, b"89", 10), + (9, 3, b"9", 10), + (0, 15, b"0123456789", 10), + # EOF and beyond EOF reads + (10, 5, b"", 10), + (15, 5, b"", 15), + ], + ) + def test_buffer_read_cases(self, start, size, expected_data, expected_pos): + """Test read() normal reads and edge cases""" + buffer = _ItemViewBuffer() + for segment in [b"012", b"34", b"5", b"6789"]: + buffer.append_view(memoryview(segment)) + + buffer.seek(start) + data = buffer.read(size) + assert data == expected_data + assert buffer.tell() == expected_pos + + @pytest.mark.parametrize( + "segments", + [ + # Fast path works (first segment >= 4 bytes) + [b"PK\x03\x04abcdef"], + # Fast path does not trigger, but still reads correctly + [b"PK", b"\x03\x04abcdef"], + [b"P", b"K\x03", b"\x04abcdef"], + [b"P", b"K", b"\x03", b"\x04abcdef"], + ], + ) + def test_buffer_read_fast_path_optimization(self, segments): + """Test read(4) at pos=0 fast path""" + buffer = _ItemViewBuffer() + for segment in segments: + buffer.append_view(memoryview(segment)) + + # Test read(4) at pos=0 + data = buffer.read(4) + assert data == b"PK\x03\x04" + assert buffer.tell() == 4 + assert isinstance(data, bytes) + + # Test normal read continues + data = buffer.read(3) + assert data == b"abc" + assert buffer.tell() == 7 + + @pytest.mark.parametrize( + "buf", + [ + bytearray(5), + memoryview(bytearray(5)), + ], + ) + def test_buffer_readinto_valid_types(self, buf): + """Test readinto() valid buffer types""" + buffer = _ItemViewBuffer() + buffer.append_view(memoryview(b"hello")) + + bytes_read = buffer.readinto(buf) + + assert bytes_read == 5 + assert bytes(buf) == b"hello" + assert buffer.tell() == 5 + + @pytest.mark.parametrize( + "buf, expected_error", + [ + # Invalid + ("hello", TypeError), + (12345, TypeError), + ([1, 2, 3, 4, 5], TypeError), + (None, TypeError), + # Readonly memoryviews + (memoryview(b"12345"), AssertionError), + 
(memoryview(bytearray(5)).toreadonly(), AssertionError), + ], + ) + def test_buffer_readinto_invalid_types(self, buf, expected_error): + """Test readinto() invalid buffer types""" + buffer = _ItemViewBuffer() + buffer.append_view(memoryview(b"hello")) + + with pytest.raises(expected_error): + buffer.readinto(buf) + + @pytest.mark.parametrize( + "start, buf_size, expected_bytes, expected_pos", + [ + # Zero buffer cases + (0, 0, 0, 0), + (5, 0, 0, 5), + (10, 0, 0, 10), + # Normal readinto cases + (0, 5, 5, 5), + (3, 4, 4, 7), + (7, 2, 2, 9), + # Near EOF cases + (8, 5, 2, 10), + (9, 3, 1, 10), + (9, 1, 1, 10), + # EOF and beyond EOF cases + (10, 5, 0, 10), + (15, 5, 0, 15), + ], + ) + def test_buffer_readinto_edge_cases( + self, start, buf_size, expected_bytes, expected_pos + ): + """Test readinto() normal read and edge cases""" + buffer = _ItemViewBuffer() + for segment in [b"012", b"34", b"5", b"6789"]: + buffer.append_view(memoryview(segment)) + + buffer.seek(start) + buf = bytearray(buf_size) + bytes_read = buffer.readinto(buf) + + assert bytes_read == expected_bytes + assert buffer.tell() == expected_pos + + if expected_bytes > 0: + expected_data = b"0123456789"[start : start + expected_bytes] + assert buf[:expected_bytes] == expected_data + + @pytest.mark.parametrize("buf_size", [1, 4, 10, 50]) + @pytest.mark.parametrize( + "buf_type", [bytearray, lambda size: memoryview(bytearray(size))] + ) + @given(bytestream_and_positions()) + def test_buffer_readinto_hypothesis( + self, buf_size, buf_type, stream_and_positions: Tuple[List[bytes], List[int]] + ): + """Test readinto() operations against BytesIO equivalent""" + segments, read_positions = stream_and_positions + + buffer = _ItemViewBuffer() + for segment in segments: + buffer.append_view(memoryview(segment)) + reference_io = BytesIO(b"".join(segments)) + + for pos in read_positions: + buffer.seek(pos) + reference_io.seek(pos) + + buf = buf_type(buf_size) + ref_buf = buf_type(buf_size) + + bytes_read = buffer.readinto(buf) + ref_bytes_read = reference_io.readinto(ref_buf) + + assert bytes_read == ref_bytes_read + assert buffer.tell() == reference_io.tell() + assert bytes(buf[:bytes_read]) == bytes(ref_buf[:bytes_read]) + + def test_buffer_close(self): + """Test close functionality - segments cleared and flag set""" + buffer = _ItemViewBuffer() + segments = [b"hello", b" ", b"world", b"!"] + for segment in segments: + buffer.append_view(memoryview(segment)) + + assert not buffer._closed + assert len(buffer._segments) == len(segments) + + buffer.close() + assert buffer._closed + assert len(buffer._segments) == 0 + + +class TestCreationAndValidation: + """ + DCPOptimizedS3Reader creation and validation tests. + References some tests in test_s3reader_common.py; not tested there since its behaviour is different. 
+ """ + + def test_s3reader_creation(self): + """Test basic reader creation""" + reader = create_dcp_s3reader() + assert reader + assert reader.bucket == TEST_BUCKET + assert reader.key == TEST_KEY + assert not reader.closed + + @pytest.mark.parametrize( + "bucket, key, expected_error", + [ + (None, TEST_KEY, "Bucket should be specified"), + ("", TEST_KEY, "Bucket should be specified"), + (TEST_BUCKET, None, "Key should be specified"), + (TEST_BUCKET, "", "Key should be specified"), + ], + ) + def test_invalid_bucket_key_validation(self, bucket, key, expected_error): + """Test bucket and key validation""" + ranges = [ItemRange(0, 10)] + with pytest.raises(ValueError, match=expected_error): + DCPOptimizedS3Reader(bucket, key, ranges, MOCK_OBJECT_INFO, MOCK_STREAM) + + @pytest.mark.parametrize( + "offset", + [0.4, 0.0, 1.0, "test", 1 + 2j, [1, 2, 3], {}, {2}], + ) + def test_fails_with_non_int_arg(self, offset): + """Test type validation for seek and read arguments""" + reader = create_dcp_s3reader() + + with pytest.raises(TypeError): + reader.seek(offset) + with pytest.raises(TypeError): + reader.read(offset) + + def test_empty_ranges_rejection(self): + """Test empty ranges are rejected""" + with pytest.raises( + ValueError, + match=r"item_ranges must be a non-empty List\[ItemRange\] object", + ): + create_dcp_s3reader(ranges=[]) + + @pytest.mark.parametrize("max_gap_size", [0, 1024, 32 * 1024 * 1024]) + def test_valid_max_gap_size(self, max_gap_size): + """Test valid max_gap_size values are accepted""" + reader = create_dcp_s3reader(max_gap_size=max_gap_size) + assert reader._max_gap_size == max_gap_size + + @pytest.mark.parametrize( + "max_gap_size,expected_error,error_msg", + [ + (-1, ValueError, "max_gap_size must be non-negative"), + ("1", TypeError, "max_gap_size must be int or float, got str"), + ([1], TypeError, "max_gap_size must be int or float, got list"), + (None, TypeError, "max_gap_size must be int or float, got NoneType"), + ], + ) + def test_invalid_max_gap_size_types(self, max_gap_size, expected_error, error_msg): + """Test max_gap_size type validation""" + with pytest.raises(expected_error, match=error_msg): + create_dcp_s3reader(max_gap_size=max_gap_size) + + +class TestValidateAndCoalesceRanges: + """DCPOptimizedS3Reader _validate_and_coalesce_ranges tests for different ItemRanges and max_gap_sizes""" + + @pytest.mark.parametrize( + "ranges,error_msg", + [ + ([ItemRange(-1, 10)], "Invalid range: -1-10"), + ([ItemRange(0, 5), ItemRange(10, 5)], "Invalid range: 10-5"), + ([ItemRange(10, 20), ItemRange(5, 10)], "Unsorted ranges: 10-20 and 5-10"), + ([ItemRange(0, 10), ItemRange(5, 15)], "Overlapping ranges: 0-10 and 5-15"), + ], + ) + def test_validation_errors(self, ranges, error_msg): + """Test validation error cases""" + with pytest.raises(ValueError, match=error_msg): + create_dcp_s3reader(ranges) + + def test_empty_ranges_filtered_out(self): + """Test empty ranges are filtered out (during initialization)""" + ranges = [ItemRange(10, 10), ItemRange(20, 30), ItemRange(100, 100)] + reader = create_dcp_s3reader(ranges) + + assert len(reader._item_ranges) == 1 + assert reader._item_ranges[0].start == 20 + assert reader._item_ranges[0].end == 30 + + def test_all_empty_ranges_error(self): + """Test all empty ranges causes error""" + ranges = [ItemRange(10, 10), ItemRange(20, 20)] + + with pytest.raises(ValueError, match="No non-empty ranges to read"): + create_dcp_s3reader(ranges) + + @pytest.mark.parametrize( + "max_gap_size,ranges,expected_groups", + [ + # Basic Tests + 
(10, [ItemRange(0, 10)], 1), # Single range + (0, [ItemRange(0, 10), ItemRange(20, 30)], 2), # No coalescing + (10, [ItemRange(0, 10), ItemRange(20, 30)], 1), # Just coalesced + (9, [ItemRange(0, 10), ItemRange(20, 30)], 2), # Just no coalesce + # 3 ranges + (50, [ItemRange(0, 10), ItemRange(20, 30), ItemRange(100, 110)], 2), + (50, [ItemRange(0, 50), ItemRange(50, 100), ItemRange(149, 199)], 1), + # Zero gap + (0, [ItemRange(0, 10), ItemRange(10, 20)], 1), + (0, [ItemRange(0, 10), ItemRange(11, 20)], 2), + # Infinite / large gap size - coalesce all + (float("inf"), [ItemRange(0, 10), ItemRange(1016 * 1024, 1024 * 1024)], 1), + (sys.maxsize, [ItemRange(0, 10), ItemRange(1016 * 1024, 1024 * 1024)], 1), + (2**50, [ItemRange(0, 10), ItemRange(1016 * 1024, 1024 * 1024)], 1), + ], + ) + def test_coalescing_behaviour(self, max_gap_size, ranges, expected_groups): + """Test coalescing with different max_gap_sizes and edge cases""" + stream_data = [b"x" * 200] # Enough data for all ranges + reader = create_dcp_s3reader(ranges, stream_data, max_gap_size) + assert len(reader._range_groups) == expected_groups + + @given(dcp_ranges_and_stream(), integers(min_value=5, max_value=25)) + def test_coalescing_behaviour_hypothesis(self, ranges_and_stream, max_gap_size): + """Check coalescing correctness for different inputs""" + ranges, stream_data = ranges_and_stream + assume(len(ranges) > 1) + + reader = create_dcp_s3reader(ranges, stream_data, max_gap_size=max_gap_size) + groups = reader._range_groups + + # All ranges in all groups are covered (and are in the same order) + covered_ranges = [r for group in groups for r in group.item_ranges] + assert covered_ranges == ranges + + # Groups separated by more than max_gap_size + for i in range(1, len(groups)): + gap = groups[i].start - groups[i - 1].end + assert gap > max_gap_size + + # ItemRanges within groups less than or equal to max_gap_size + for group in groups: + for i in range(1, len(group.item_ranges)): + gap = group.item_ranges[i].start - group.item_ranges[i - 1].end + assert gap <= max_gap_size + + def test_group_start_to_group_mapping(self): + """Test _group_start_to_group correctness, generated after _validate_and_coalesce_ranges method""" + ranges = [ItemRange(0, 10), ItemRange(50, 60), ItemRange(70, 80)] + reader = create_dcp_s3reader(ranges, [b"x" * 100], max_gap_size=15) + + # Should create 2 groups: [0-10] and [50-80] + assert len(reader._range_groups) == 2 + + # Check group starts + expected_group_starts = [0, 50] + actual_group_starts = [group.start for group in reader._range_groups] + assert actual_group_starts == expected_group_starts + + # Check group start mappings + assert set(reader._group_start_to_group.keys()) == {0, 50} + assert reader._group_start_to_group[0].start == 0 + assert reader._group_start_to_group[50].start == 50 + + # Check 70 is not group start (part of 2nd group, not a group start) + assert 70 not in reader._group_start_to_group + + +class TestStreamManagement: + """Tests for _get_stream_for_item and how stream and left data is managed within and between range groups""" + + def test_coalesced_ranges_stream_reuse(self): + """Test stream reuse within coalesced group vs separate streams for different groups""" + # 3 ranges: first 2 coalesce (gap=5 ≤ 10), third is separate (gap=85 > 10) + ranges = [ItemRange(0, 10), ItemRange(15, 25), ItemRange(110, 120)] + test_data = [b"0123456789-----abcdefghij" + b"x" * 85 + b"ABCDEFGHIJ"] + + stream_calls = [] + + def spy_get_stream(start=None, end=None): + stream_calls.append((start, 
end)) + return create_stream_getter(test_data)(start, end) + + reader = DCPOptimizedS3Reader( + TEST_BUCKET, + TEST_KEY, + ranges, + create_object_info_getter(test_data), + spy_get_stream, # type: ignore + max_gap_size=10, + ) + + # 2 groups: [0-25] and [110-120] + assert len(reader._range_groups) == 2 + assert reader._range_groups[0].start == 0 + assert reader._range_groups[0].end == 25 + assert reader._range_groups[1].start == 110 + assert reader._range_groups[1].end == 120 + + # Read from first group + reader.seek(0) + assert reader.read(10) == b"0123456789" + reader.seek(15) + assert reader.read(10) == b"abcdefghij" + # Only 1 stream call so far for both ItemRanges + assert stream_calls == [(0, 25)] + + # Read from second group + reader.seek(110) + assert reader.read(10) == b"ABCDEFGHIJ" + # 2 stream calls + assert stream_calls == [(0, 25), (110, 120)] + + def test_leftover_handling_with_chunks(self): + """Test leftover data handling across items with chunk boundaries""" + ranges = [ItemRange(0, 7), ItemRange(12, 18), ItemRange(21, 25)] + test_data = [b"0123456789ABCDEFGHIJabcdefghij"] + # Chunk size 5 - each next() iteration will return 5 bytes of data + reader = create_dcp_s3reader(ranges, test_data, max_gap_size=10, chunk_size=5) + + # Should coalesce into single group + assert len(reader._range_groups) == 1 + + # Read first item + reader.seek(0) + assert reader.read(7) == b"0123456" + assert reader._leftover + assert bytes(reader._leftover) == b"789" # bytes 8-10 + + # Read second item (should use leftover data) + reader.seek(12) + assert reader.read(6) == b"CDEFGH" + assert reader._leftover + assert bytes(reader._leftover) == b"IJ" # bytes 19-20 + + def test_get_stream_for_item_missing_stream_error(self): + """Test _get_stream_for_item error when stream is None for non-first item""" + ranges = [ItemRange(0, 10), ItemRange(15, 25)] + reader = create_dcp_s3reader(ranges, [b"x" * 30], max_gap_size=10) + + # Corrupt state: advance to second item without creating stream + reader._current_item = ranges[1] + reader._stream = None # force error (should not happen normally) + + with pytest.raises( + ValueError, match="without starting at the first item of its range-group" + ): + reader._get_stream_for_item(ranges[1]) + + +class TestReaderIO: + """Reader Interface (seek/read/readinto) and Sequential Access Tests""" + + @pytest.mark.parametrize( + "offset, whence, expected_error, error_msg", + [ + ("5", SEEK_SET, TypeError, "integer argument expected, got "), + (0, SEEK_END, ValueError, "whence must be SEEK_CUR or SEEK_SET integers"), + (-1, SEEK_SET, ValueError, "negative seek value -1"), + ], + ) + def test_seek_invalid_types(self, offset, whence, expected_error, error_msg): + """Test seek() parameter validation""" + reader = create_dcp_s3reader() + with pytest.raises(expected_error, match=error_msg): + reader.seek(offset, whence) + + @pytest.mark.parametrize( + "size, expected_error, error_msg", + [ + (None, ValueError, "Size cannot be None; full read not supported"), + (-1, ValueError, "Size cannot be negative; full read not supported"), + ("5", TypeError, "argument should be integer or None, not "), + ], + ) + def test_read_invalid_types(self, size, expected_error, error_msg): + """Test read() parameter validation""" + reader = create_dcp_s3reader() + with pytest.raises(expected_error, match=error_msg): + reader.read(size) + + def test_read_zero_size(self): + """Test read(0) returns empty bytes""" + reader = create_dcp_s3reader() + assert reader.read(0) == b"" + + 
@pytest.mark.parametrize( + "buf, expected_error", + [ + ("hello", TypeError), + (12345, TypeError), + ([1, 2, 3], TypeError), + (None, TypeError), + (memoryview(b"test"), AssertionError), # _ItemViewBuffer check + ], + ) + def test_readinto_invalid_types(self, buf, expected_error): + """Test readinto() parameter validation""" + reader = create_dcp_s3reader() + with pytest.raises(expected_error): + reader.readinto(buf) + + @pytest.mark.parametrize("buf", [bytearray(5), memoryview(bytearray(5))]) + def test_readinto_valid_types(self, buf): + """Test readinto() accepts valid buffer types""" + reader = create_dcp_s3reader() + + bytes_read = reader.readinto(buf) + assert bytes_read == 5 + assert bytes(buf) == b"01234" + + def test_sequential_access_enforcement(self): + """Test sequential access pattern enforcement""" + ranges = [ItemRange(0, 10), ItemRange(20, 30)] + stream_data = [b"0123456789" + b"x" * 10 + b"abcdefghij"] + reader = create_dcp_s3reader(ranges, stream_data) + + # Forward access should work + reader.seek(0) + assert reader.read(5) == b"01234" + assert reader.read(5) == b"56789" + + # Move to next item + reader.seek(20) + assert reader.read(5) == b"abcde" + + # Backward access to previous item should fail + reader.seek(5) + with pytest.raises(ValueError, match="Position 5 before current range 20-30"): + reader.read(1) + + def test_within_item_seeking(self): + """Test seeking within current item is allowed""" + ranges = [ItemRange(0, 20)] + reader = create_dcp_s3reader(ranges, [b"0123456789abcdefghij"]) + + reader.seek(5) + assert reader.read(5) == b"56789" + + # Seek backward within same item + reader.seek(2) + assert reader.read(3) == b"234" + + @pytest.mark.parametrize( + "pos, setup_reads, error_suffix", + [ + (5, [], "Position 5 before current range 10-20"), + (25, [10], "Position 25 in gap between ranges 10-20 and 30-40"), + (45, [10], "Position 45 beyond next range 30-40"), + (45, [10, 30], "Position 45 beyond last range 30-40"), + (15, [10, 30], "Position 15 before current range 30-40"), + ], + ) + def test_find_item_for_position_errors(self, pos, setup_reads, error_suffix): + """Test seeking to position outside valid ranges for _find_item_for_position""" + ranges = [ItemRange(10, 20), ItemRange(30, 40)] + reader = create_dcp_s3reader(ranges, [b"x" * 50]) + + # Setup reads to advance iterator state + for read_pos in setup_reads: + reader.seek(read_pos) + reader.read(1) + + reader.seek(pos) + expected_pattern = re.escape(FIND_ITEM_ERROR_PREFIX) + re.escape(error_suffix) + with pytest.raises(ValueError, match=expected_pattern): + reader.read(1) + + @pytest.mark.parametrize( + "ranges, read_pattern, expected_data", + [ + # Single range + ([ItemRange(0, 5)], [(0, 5)], [b"01234"]), + # 2 ranges + ( + [ItemRange(0, 5), ItemRange(10, 15)], + [(0, 3), (3, 2), (10, 5)], + [b"012", b"34", b"abcde"], + ), + # Should stop at item boundary + ( + [ItemRange(0, 5), ItemRange(10, 15)], + [(3, 5)], + [b"34"], + ), + ], + ) + def test_read_patterns(self, ranges, read_pattern, expected_data): + """Test various read patterns and item boundary behaviour""" + test_data = [b"0123456789abcdefghij"] + reader = create_dcp_s3reader(ranges, test_data) + + results = [] + for pos, size in read_pattern: + reader.seek(pos) + data = reader.read(size) + results.append(data) + + assert results == expected_data + + @given(dcp_ranges_and_stream()) + def test_sequential_io_hypothesis(self, ranges_and_stream): + """Quick integration test to read all ItemRanges sequentially""" + ranges, stream_data = 
ranges_and_stream + assume(len(ranges) > 0) + + reader = create_dcp_s3reader(ranges, stream_data) + + for range_item in ranges: + reader.seek(range_item.start) + range_size = range_item.end - range_item.start + data = reader.read(range_size) + + assert len(data) == range_size + assert reader.tell() == range_item.end + + def test_close(self): + """Test close() behaviour""" + reader = create_dcp_s3reader() + reader.close() + + assert reader.closed + assert reader._stream is None + assert reader._leftover is None + assert reader._current_item_buffer is None
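
For readers following the coalescing assertions in `TestValidateAndCoalesceRanges` above, the sketch below restates the grouping invariant those tests check: sorted, non-overlapping ranges whose gap is at most `max_gap_size` are served from one range group (one S3 stream), while larger gaps start a new group. This is an illustrative stand-in only; `Range` and `coalesce` are hypothetical names for this sketch and are not the connector's `DCPOptimizedS3Reader` internals.

```py
# Illustrative sketch of the coalescing invariant asserted in
# test_coalescing_behaviour / test_coalescing_behaviour_hypothesis.
# Not the connector's implementation.
from dataclasses import dataclass
from typing import List


@dataclass
class Range:  # stand-in for ItemRange(start, end)
    start: int
    end: int


def coalesce(ranges: List[Range], max_gap_size: float) -> List[List[Range]]:
    """Group sorted, non-overlapping ranges whose gap is <= max_gap_size."""
    groups: List[List[Range]] = []
    for r in ranges:
        if groups and r.start - groups[-1][-1].end <= max_gap_size:
            groups[-1].append(r)  # small gap: reuse the current group/stream
        else:
            groups.append([r])  # large gap: start a new group (new stream)
    return groups


# Matches the parametrized expectations above: a 10-byte gap is coalesced
# when max_gap_size=10 but split when max_gap_size=9.
assert len(coalesce([Range(0, 10), Range(20, 30)], max_gap_size=10)) == 1
assert len(coalesce([Range(0, 10), Range(20, 30)], max_gap_size=9)) == 2
```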