
Commit 4bf572f

abhishek002002 authored and Orbax Authors committed
Add PyTorchLayout for loading PyTorch checkpoints.
PiperOrigin-RevId: 821011541
1 parent c6b4b10 commit 4bf572f
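
The new layout targets the zip-based archive format written by torch.save: a single data.pkl pickle describing the tree of tensors, plus one data/<key> entry per tensor storage. A minimal inspection sketch, assuming a checkpoint saved to a hypothetical model.pt (the top-level folder name inside the archive varies between PyTorch versions, which is why the loader below matches entries by suffix rather than a fixed prefix):

  import zipfile

  # List the entries PyTorchLayout cares about: the structure pickle and the
  # raw storage blobs referenced by persistent IDs during unpickling.
  with zipfile.ZipFile("model.pt") as zf:  # "model.pt" is a placeholder name.
    for name in zf.namelist():
      print(name)
  # Typically prints entries like: <prefix>/data.pkl, <prefix>/data/0, ...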

3 files changed: +466 -0 lines
Lines changed: 360 additions & 0 deletions
@@ -0,0 +1,360 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines `PyTorchLayout` for loading PyTorch checkpoint files."""

import asyncio
import dataclasses
import io
import os
import pickle
from typing import Any, Awaitable
import zipfile

from absl import logging
import jax
import numpy as np
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
from orbax.checkpoint.experimental.v1._src.path import types


CheckpointLayout = checkpoint_layout.CheckpointLayout
InvalidLayoutError = checkpoint_layout.InvalidLayoutError
Path = types.Path


_PICKLE_FILENAME = "data.pkl"
_STORAGE_PREFIX = "data"

# Maps torch.dtype to an equivalent numpy dtype.
_TORCH_TO_NP_DTYPE = {
    "torch.float16": np.float16,
    "torch.float32": np.float32,
    "torch.float64": np.float64,
    # JAX's numpy supports bfloat16, but we use a string to avoid a direct
    # dependency on a specific numpy implementation having np.bfloat16.
    "torch.bfloat16": "bfloat16",
    "torch.uint8": np.uint8,
    "torch.int8": np.int8,
    "torch.int16": np.int16,
    "torch.int32": np.int32,
    "torch.int64": np.int64,
    "torch.bool": np.bool_,
    "torch.complex64": np.complex64,
    "torch.complex128": np.complex128,
    # Map quantized types to their numpy equivalents. Note that this loses
    # quantization information (scale and zero-point).
    "torch.qint8": np.int8,
    "torch.quint8": np.uint8,
    "torch.qint32": np.int32,
}


def _parse_storage_pid(pid: Any) -> tuple[Any, str]:
  """Parses a PyTorch storage persistent ID.

  Args:
    pid: The persistent id.

  Returns:
    A tuple of (storage_type, key).

  Raises:
    pickle.UnpicklingError: If the pid is not a valid storage pid.
  """
  # pid is typically a tuple like:
  # ('storage', torch.LongStorage, '0', 'cpu', 8)
  if not isinstance(pid, tuple) or pid[0] != "storage":
    raise pickle.UnpicklingError(f"Unsupported persistent id object: {pid}")
  storage_type, key = pid[1], pid[2]
  return storage_type, key


class CustomTorchUnpickler(pickle.Unpickler):
  """An unpickler that handles PyTorch's 'storage' persistent IDs.

  Storage data is looked up in an externally provided dictionary of bytes.
  """

  def __init__(
      self,
      file: io.BytesIO,
      storage_data: dict[str, bytes],
  ):
    super().__init__(file)
    self._storage_data = storage_data

  def persistent_load(self, pid: Any) -> Any:
    """Handles persistent load calls encountered during unpickling."""
    storage_type, key = _parse_storage_pid(pid)
    if key not in self._storage_data:
      raise pickle.UnpicklingError(
          f"Storage key '{key}' not found in checkpoint archive."
      )

    storage_bytes = self._storage_data[key]
    return storage_type.from_buffer(storage_bytes, "little")


@dataclasses.dataclass
class _StorageMetadata:
  """A placeholder for torch.Storage metadata, containing only the dtype."""

  dtype: str

  def __init__(self, dtype: str):
    self.dtype = dtype


def _rebuild_tensor_as_sds(
    storage: Any,
    storage_offset: int,
    size: tuple[int, ...],
    stride: tuple[int, ...],
    requires_grad: bool = False,
    backward_hooks: Any = (),
) -> jax.ShapeDtypeStruct:
  """Pickle reduction function to rebuild a tensor as a ShapeDtypeStruct."""
  del storage_offset, stride, requires_grad, backward_hooks  # Unused.
  if not isinstance(storage, _StorageMetadata):
    # This error indicates that the unpickler's persistent_load did not return
    # the expected placeholder. This can happen with unsupported PyTorch
    # versions or corrupted files.
    raise pickle.UnpicklingError(
        "Expected to find _StorageMetadata, but got"
        f" {type(storage).__name__}. This may indicate an unsupported PyTorch"
        " version."
    )
  if storage.dtype not in _TORCH_TO_NP_DTYPE:
    raise pickle.UnpicklingError(
        f"Unsupported torch dtype for conversion to numpy: {storage.dtype}"
    )
  numpy_dtype = np.dtype(_TORCH_TO_NP_DTYPE[storage.dtype])
  return jax.ShapeDtypeStruct(shape=tuple(size), dtype=numpy_dtype)


class MetadataUnpickler(pickle.Unpickler):
  """An unpickler that reconstructs tensors as ShapeDtypeStructs."""

  def find_class(self, module: str, name: str) -> Any:
    """Overrides class lookup to intercept tensor creation."""
    if (module == "torch._utils" and name == "_rebuild_tensor_v2") or (
        module == "torch" and name == "_rebuild_tensor"
    ):
      return _rebuild_tensor_as_sds
    return super().find_class(module, name)

  def persistent_load(self, pid: Any) -> Any:
    """Handles persistent load calls for torch.Storage."""
    storage_type, _ = _parse_storage_pid(pid)
    # For metadata, we only need the dtype from the storage type.
    return _StorageMetadata(dtype=str(storage_type.dtype))


def _unpickle_metadata_sync(pickle_bytes: bytes) -> Any:
  """Unpickles metadata using MetadataUnpickler."""
  data_stream = io.BytesIO(pickle_bytes)
  unpickler = MetadataUnpickler(data_stream)
  return unpickler.load()


def _read_zip_contents_sync(path: Path) -> tuple[bytes, dict[str, bytes]]:
  """Sync helper for `_read_zip_contents`."""
  pickle_bytes = None
  storage_data = {}
  with zipfile.ZipFile(path, "r") as zf:
    for name in zf.namelist():
      if name.endswith(_PICKLE_FILENAME):
        pickle_bytes = zf.read(name)
      elif os.path.basename(os.path.dirname(name)) == _STORAGE_PREFIX:
        storage_id = os.path.basename(name)
        # Accommodate different key formats. Some PyTorch versions may use
        # storage keys with underscores.
        if storage_id.isdigit() or "_" in storage_id:
          storage_data[storage_id] = zf.read(name)
  if pickle_bytes is None:
    raise FileNotFoundError(f"{_PICKLE_FILENAME} not found in {path}")
  return pickle_bytes, storage_data


async def _read_zip_contents(path: Path) -> tuple[bytes, dict[str, bytes]]:
  """Reads pickle data and all storage files from a PyTorch zip archive."""
  return await asyncio.to_thread(_read_zip_contents_sync, path)


def _structure_to_numpy(pytorch_data: Any) -> Any:
  """Converts torch.Tensors in pytorch_data to NumPy arrays."""

  def _to_numpy(leaf: Any) -> Any:
    if hasattr(leaf, "numpy"):
      return leaf.numpy()
    return leaf

  return jax.tree.map(_to_numpy, pytorch_data)


def _load_pytorch_on_device(
    pytorch_data: Any,
    abstract_pytree: Any,
) -> Any:
  """Loads tensors from pytorch_data into on-device JAX arrays based on abstract_pytree."""

  def _load_leaf(leaf: Any, abstract_leaf: Any) -> jax.Array:
    if not hasattr(leaf, "numpy"):
      raise ValueError(
          "Item in PyTorch checkpoint is not a tensor-like object with a"
          " 'numpy' method or is missing from the checkpoint."
      )

    sharding = abstract_leaf.sharding
    target_shape = abstract_leaf.shape
    target_dtype = abstract_leaf.dtype

    device_indices_map = sharding.addressable_devices_indices_map(target_shape)
    device_arrays = []
    for device in device_indices_map:
      idx = device_indices_map[device]
      shard_tensor = leaf[idx]
      shard_np = shard_tensor.numpy()
      if shard_np.dtype != target_dtype:
        shard_np = shard_np.astype(target_dtype)
      device_arrays.append(jax.device_put(shard_np, device))

    return jax.make_array_from_single_device_arrays(
        target_shape, sharding, device_arrays
    )

  return jax.tree.map(_load_leaf, pytorch_data, abstract_pytree)


def _unpickle_structure_sync(
    pickle_bytes: bytes, storage_data: dict[str, bytes]
) -> Any:
  """Unpickles the structure using CustomTorchUnpickler."""
  data_stream = io.BytesIO(pickle_bytes)
  unpickler = CustomTorchUnpickler(data_stream, storage_data)
  return unpickler.load()


async def _load_pytorch(
    path: Path, abstract_pytree: dict[str, Any] | None = None
) -> dict[str, Any]:
  """Loads pytorch checkpoint as numpy arrays or sharded jax arrays."""
  pickle_bytes, storage_data = await _read_zip_contents(path)

  pytorch_data = await asyncio.to_thread(
      _unpickle_structure_sync, pickle_bytes, storage_data
  )

  if abstract_pytree is None:
    # Return NumPy arrays.
    restored_pytree = _structure_to_numpy(pytorch_data)
  else:
    # Return on-device JAX arrays.
    restored_pytree = _load_pytorch_on_device(pytorch_data, abstract_pytree)

  return {checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY: restored_pytree}


class PyTorchLayout(CheckpointLayout):
  """Layout for loading PyTorch checkpoints (.pt, .pth).

  Uses zipfile and a custom unpickler to handle torch.Tensors
  without calling torch.load().
  """

  def __init__(self, path: Path):
    self._path = path

  @property
  def path(self) -> Path:
    """Returns the path of the PyTorch checkpoint file."""
    return self._path

  def _check_zip_structure(self):
    """Sync helper to check zip file contents."""
    try:
      with zipfile.ZipFile(self._path, "r") as zf:
        if not any(name.endswith(_PICKLE_FILENAME) for name in zf.namelist()):
          raise InvalidLayoutError(
              f"'{self._path}' is not a valid PyTorch zip archive"
              " (missing data.pkl)."
          )
    except zipfile.BadZipFile as e:
      raise InvalidLayoutError(
          f"'{self._path}' is not a valid ZIP file."
      ) from e

  async def validate(self) -> None:
    """Checks if the path is a file and a valid PyTorch ZIP archive."""
    if not await async_path.is_file(self._path):
      raise InvalidLayoutError(f"Path is not a file: {self._path}")
    if self._path.suffix not in [".pt", ".pth"]:
      logging.warning(
          "File %s lacks .pt or .pth suffix but attempting to "
          "load as PyTorch checkpoint.",
          self._path,
      )
    try:
      await asyncio.to_thread(self._check_zip_structure)
    except InvalidLayoutError as e:
      raise e
    except OSError as e:
      raise InvalidLayoutError(
          f"Failed to validate {self._path} as PyTorch checkpoint: {e}"
      ) from e

  async def validate_pytree(self, checkpointable_name: str | None) -> None:
    """No-op, as PyTorchLayout treats the entire file as the 'pytree' item."""
    return

  async def metadata(self) -> metadata_types.CheckpointMetadata[dict[str, Any]]:
    """Extracts ShapeDtypeStruct metadata without loading tensor data."""
    pickle_bytes, _ = await _read_zip_contents(self._path)
    metadata_tree = await asyncio.to_thread(
        _unpickle_metadata_sync, pickle_bytes
    )
    stat_result = await asyncio.to_thread(os.stat, self._path)
    commit_timestamp_nsecs = stat_result.st_mtime_ns

    return metadata_types.CheckpointMetadata[dict[str, Any]](
        metadata={checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY: metadata_tree},
        commit_timestamp_nsecs=commit_timestamp_nsecs,
    )

  async def load(
      self,
      abstract_checkpointables: dict[str, Any] | None = None,
  ) -> Awaitable[dict[str, Any]]:
    """Loads a PyTorch checkpoint file.

    If abstract_checkpointables are provided, it attempts to load tensors as
    sharded jax.Arrays onto devices. Otherwise, it loads tensors as host
    NumPy arrays.

    Args:
      abstract_checkpointables: An optional PyTree of abstract arrays
        specifying sharding information.

    Returns:
      An awaitable of a dictionary containing the loaded PyTree.
    """
    abstract_pytree = None
    if abstract_checkpointables:
      abstract_pytree = abstract_checkpointables.get(
          checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY
      )
    return _load_pytorch(self._path, abstract_pytree)
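
Not shown in this commit: how the new layout is wired into the surrounding v1 loading machinery. A minimal sketch of driving it directly, assuming Orbax's Path type accepts an etils epath.Path and using the hypothetical file name model.pt; note that load() is itself async and returns a second awaitable, hence the double await:

  import asyncio

  from etils import epath  # Assumption: Orbax's Path type is epath-compatible.

  # PyTorchLayout and checkpoint_layout refer to the definitions in the module
  # above; its import path is not shown in this commit.


  async def restore_as_numpy(checkpoint_file: str):
    layout = PyTorchLayout(epath.Path(checkpoint_file))
    await layout.validate()          # File and zip-structure checks only.
    meta = await layout.metadata()   # jax.ShapeDtypeStructs; tensor bytes not read.
    restored = await (await layout.load())  # No abstract pytree -> host NumPy arrays.
    return meta, restored[checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY]


  meta, state = asyncio.run(restore_as_numpy("model.pt"))  # Hypothetical path.

Passing abstract_checkpointables={checkpoint_layout.PYTREE_CHECKPOINTABLE_KEY: abstract_tree} to load(), where the tree's leaves are jax.ShapeDtypeStructs carrying shardings, instead returns device-resident jax.Arrays sharded accordingly.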
