Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
a78932d
feat: new ListOfRangesS3Reader for DCP partial reading
jet-tong Sep 16, 2025
5bb2c7c
feat(dcp): use constructor pattern for ListOfRanges optimization
jet-tong Oct 7, 2025
af01848
fix: resolve mypy errors and minor logic and name changes
jet-tong Oct 9, 2025
5fc014c
perf(s3reader): optimize for sequential DCP workloads
jet-tong Oct 14, 2025
da73d3e
refactor: rename reader as dcp optimized and add max_gap_size support
jet-tong Oct 16, 2025
39fa629
refactor: rename list_of_ranges.py as dcp_optimized.py
jet-tong Oct 16, 2025
a515b4a
refactor: use one default max gap size across classes
jet-tong Oct 16, 2025
68165e6
test(dcp): update dcp e2e tests with DCPOptimizedS3Reader
jet-tong Oct 16, 2025
531b22b
feat: make DCPOptimizedS3Reader seekable for PyTorch backwards compat…
jet-tong Oct 17, 2025
5e1322b
refactor: simplify DCPOptimizedS3Reader to use single active stream
jet-tong Oct 17, 2025
cc5bfc9
refactor(dcp): refactor core functions and add docstrings and error m…
jet-tong Oct 20, 2025
463a1f1
fix(dcp): use filename only for file range key
jet-tong Oct 20, 2025
c985b26
refactor: rename RangeRequest as ItemRange
jet-tong Oct 20, 2025
5720907
fix(dcp): use os to extract basename instead of split
jet-tong Oct 20, 2025
ede4d8d
fix(dcp): improve error handling and validation
jet-tong Oct 21, 2025
b83284b
refactor: address github comments
jet-tong Oct 21, 2025
5ff5fdc
refactor(dcp): improve naming and typing for DCP optimized reader con…
jet-tong Oct 21, 2025
a9ed023
fix(s3reader): use Python 3.9 compatible types and remove redundant c…
jet-tong Oct 21, 2025
bd6f274
refactor: cleanup imports, styling and comments
jet-tong Oct 21, 2025
f9a90b8
fix: move DCP imports under TYPE_CHECKING
jet-tong Oct 22, 2025
f5e810a
refactor(s3reader): switch from index to iterator based state management
jet-tong Oct 24, 2025
3c25552
perf(dcp): minor group lookup optimization with dict
jet-tong Oct 24, 2025
7226079
refactor: improve item buffer method and support closing
jet-tong Oct 24, 2025
3a05cb5
fix(s3reader): resolve stream pos desync when skipping coalescing bytes
jet-tong Oct 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

import dataclasses
import io
import logging
import os
import urllib.parse
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, Union, Optional
from typing import List
from typing import Generator, Union, Optional, List

import torch
from s3torchconnectorclient._mountpoint_s3_client import S3Exception
from tenacity import (
retry,
Expand All @@ -24,11 +26,18 @@
FileSystemWriter,
FileSystemBase,
)
import torch
from torch.distributed.checkpoint.planner import SavePlan, LoadPlan


from s3torchconnector._s3client import S3Client
from s3torchconnector._s3dataset_common import parse_s3_uri
from ..s3reader import S3ReaderConstructor, S3ReaderConstructorProtocol
from ..s3reader import (
S3ReaderConstructor,
DCPOptimizedConstructor,
S3ReaderConstructorProtocol,
DCPS3ReaderConstructorProtocol,
ItemRange,
)
from .. import S3ClientConfig
from .s3_prefix_strategy import S3PrefixStrategyBase, DefaultPrefixStrategy
from .._user_agent import UserAgent
Expand Down Expand Up @@ -252,11 +261,6 @@ def _escape_path(string):
return "/".join(parts)


from torch.distributed.checkpoint.planner import SavePlan, LoadPlan
import dataclasses
from dataclasses import dataclass


@dataclass
class StorageMetadata:
"""Metadata for S3 storage prefix."""
Expand Down Expand Up @@ -324,7 +328,9 @@ def __init__(
region: str,
path: Union[str, os.PathLike],
s3client_config: Optional[S3ClientConfig] = None,
reader_constructor: Optional[S3ReaderConstructorProtocol] = None,
reader_constructor: Optional[
Union[S3ReaderConstructorProtocol, DCPS3ReaderConstructorProtocol]
] = None,
) -> None:
"""
Initialize an S3 reader for distributed checkpointing.
Expand All @@ -337,7 +343,12 @@ def __init__(
e.g. S3ReaderConstructor.sequential() or S3ReaderConstructor.range_based()
"""
super().__init__(path)
self.fs = S3FileSystem(region, s3client_config=s3client_config, reader_constructor=reader_constructor) # type: ignore
self._reader_constructor = reader_constructor or S3ReaderConstructor.default()
self.fs: S3FileSystem = S3FileSystem( # type: ignore[assignment]
region,
s3client_config=s3client_config,
reader_constructor=self._reader_constructor,
)
self.path = self.fs.init_path(path)
self.sync_files = False

Expand All @@ -357,6 +368,13 @@ def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
"""
# Sort items in plan based on their offset in checkpoints shards
plan.items.sort(key=lambda item: self.storage_data[item.storage_index].offset)

# Inject ranges if using DCP optimized reader constructor
if isinstance(self._reader_constructor, DCPS3ReaderConstructorProtocol):
self._reader_constructor.set_item_ranges_by_file(
plan.items, self.storage_data
)

return plan


Expand Down
10 changes: 8 additions & 2 deletions s3torchconnector/src/s3torchconnector/s3reader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@
# // SPDX-License-Identifier: BSD

from .s3reader import S3Reader
from .constructor import S3ReaderConstructor
from .constructor import S3ReaderConstructor, DCPOptimizedConstructor
from .sequential import SequentialS3Reader
from .ranged import RangedS3Reader
from .protocol import GetStreamCallable, S3ReaderConstructorProtocol
from .dcp_optimized import DCPOptimizedS3Reader, ItemRange, RangeGroup
from .protocol import (
GetStreamCallable,
S3ReaderConstructorProtocol,
DCPS3ReaderConstructorProtocol,
)

# Public API of the s3reader subpackage. Kept in sync with the imports above:
# the DCP-related names (DCPOptimizedConstructor, ItemRange, RangeGroup and the
# constructor protocols) are imported from this package by dcp/s3_file_system.py
# and therefore belong to the public surface.
__all__ = [
    "S3Reader",
    "S3ReaderConstructor",
    "SequentialS3Reader",
    "RangedS3Reader",
    "DCPOptimizedS3Reader",
    "DCPOptimizedConstructor",
    "ItemRange",
    "RangeGroup",
    "GetStreamCallable",
    "S3ReaderConstructorProtocol",
    "DCPS3ReaderConstructorProtocol",
]
70 changes: 64 additions & 6 deletions s3torchconnector/src/s3torchconnector/s3reader/constructor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,61 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

import os
from functools import partial
from typing import Optional

from .protocol import S3ReaderConstructorProtocol
from typing import TYPE_CHECKING, Optional, List, Dict, Any
from collections import defaultdict

from .s3reader import S3Reader
from .protocol import (
S3ReaderConstructorProtocol,
DCPS3ReaderConstructorProtocol,
)
from .sequential import SequentialS3Reader
from .ranged import RangedS3Reader
from .dcp_optimized import DCPOptimizedS3Reader, ItemRange, DEFAULT_MAX_GAP_SIZE

if TYPE_CHECKING:
from torch.distributed.checkpoint.planner import ReadItem
from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.filesystem import _StorageInfo


class DCPOptimizedConstructor:
    """Callable reader factory for PyTorch DCP (distributed checkpoint) loads.

    Holds per-file lists of item byte ranges extracted from a DCP load plan.
    When invoked as a reader constructor it builds a ``DCPOptimizedS3Reader``
    for objects whose ranges are known, and falls back to a plain
    ``SequentialS3Reader`` otherwise (e.g. when reading the ``.metadata``
    object, which never appears in the plan).
    """

    def __init__(self, max_gap_size: int = DEFAULT_MAX_GAP_SIZE) -> None:
        # Maps checkpoint shard filename (basename only) -> list of item byte ranges.
        self._item_ranges_by_file: Dict[str, List[ItemRange]] = {}
        # Forwarded to DCPOptimizedS3Reader; presumably the largest gap tolerated
        # when coalescing adjacent ranges — confirm in dcp_optimized.py.
        self._max_gap_size = max_gap_size

    def set_item_ranges_by_file(
        self,
        plan_items: "List[ReadItem]",
        storage_data: "Dict[MetadataIndex, _StorageInfo]",
    ) -> None:
        """Record the byte ranges each checkpoint file will be read at.

        Args:
            plan_items: Read items of a DCP ``LoadPlan``.
            storage_data: Per-item storage metadata (relative path, offset,
                length) keyed by each item's ``storage_index``.

        NOTE(review): an empty ``plan_items`` returns early and leaves any
        previously recorded ranges in place — confirm this is intended.
        """
        # TODO: Check if we want to return DCPOptimizedConstructor for immutability here instead
        if not plan_items:
            return
        self._item_ranges_by_file = defaultdict(list)
        for read_item in plan_items:
            item_md = storage_data[read_item.storage_index]
            # Key by basename only so the lookup in __call__ matches even when
            # the full object key carries a prefix.
            # TODO: write test to check using filename works with S3PrefixStrategy
            filename = os.path.basename(item_md.relative_path)
            self._item_ranges_by_file[filename].append(
                ItemRange(item_md.offset, item_md.offset + item_md.length)
            )

    def __call__(self, bucket: str, key: str, get_object_info, get_stream) -> S3Reader:
        """Construct a reader for ``s3://bucket/key``.

        Returns a ``DCPOptimizedS3Reader`` when item ranges were recorded for
        this key's basename; otherwise a ``SequentialS3Reader``.
        """
        filename = os.path.basename(key)
        if filename in self._item_ranges_by_file:
            return DCPOptimizedS3Reader(
                bucket,
                key,
                item_ranges=self._item_ranges_by_file[filename],
                get_object_info=get_object_info,
                get_stream=get_stream,
                max_gap_size=self._max_gap_size,
            )
        # Fallback if file_ranges unavailable (e.g. when reading .metadata)
        return SequentialS3Reader(bucket, key, get_object_info, get_stream)


class S3ReaderConstructor:
Expand Down Expand Up @@ -78,6 +127,14 @@ def range_based(buffer_size: Optional[int] = None) -> S3ReaderConstructorProtoco
"""
return partial(RangedS3Reader, buffer_size=buffer_size)

@staticmethod
def dcp_optimized(
    max_gap_size: int = DEFAULT_MAX_GAP_SIZE,
) -> DCPS3ReaderConstructorProtocol:
    """Creates a DCPOptimizedConstructor that uses DCPOptimizedS3Reader when ranges are available.

    Args:
        max_gap_size: Maximum byte gap between item ranges, forwarded to
            ``DCPOptimizedS3Reader`` (defaults to ``DEFAULT_MAX_GAP_SIZE``).

    Returns:
        DCPS3ReaderConstructorProtocol: a stateful constructor. Item ranges
        are injected later via ``set_item_ranges_by_file`` (done by
        ``S3StorageReader.prepare_local_plan``); for objects without recorded
        ranges it falls back to sequential reading.
    """
    # TODO update docstring with guide and requirements to use this reader for DCP
    return DCPOptimizedConstructor(max_gap_size=max_gap_size)

@staticmethod
def default() -> S3ReaderConstructorProtocol:
"""Creates default reader constructor (sequential)
Expand All @@ -97,10 +154,11 @@ def get_reader_type_string(
S3ReaderConstructor.default()
)

if not isinstance(constructor, partial):
if isinstance(constructor, DCPOptimizedConstructor):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here - this feels pretty janky to me. What's this used for? Just debugging or to actually do something based on it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

User agent - agree this still feels janky.

return "dcp_optimized"
elif not isinstance(constructor, partial):
return "unknown"

if constructor.func == RangedS3Reader:
elif constructor.func == RangedS3Reader:
return "range_based"
elif constructor.func == SequentialS3Reader:
return "sequential"
Expand Down
Loading