Skip to content

Commit f5970e5

Browse files
lukebaumanncopybara-github
authored andcommitted
Support jax.random.PRNGKey serialization in Pathways Orbax handler.
This change allows `CloudPathwaysArrayHandler` to correctly save and restore `jax.random.PRNGKey` objects by extracting and wrapping the key data, and storing metadata about the key implementation using an `ArrayMetadataStore`. This change introduces a dependency on Orbax's internal API. PiperOrigin-RevId: 813796155
1 parent c9fb204 commit f5970e5

File tree

2 files changed

+98
-12
lines changed

2 files changed

+98
-12
lines changed

pathwaysutils/_initialize.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919

2020
import jax
21+
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
2122
from pathwaysutils import profiling
2223
from pathwaysutils import proxy_backend
2324
from pathwaysutils.persistence import orbax_handler
@@ -94,6 +95,7 @@ def initialize() -> None:
9495
if _is_persistence_enabled():
9596
orbax_handler.register_pathways_handlers(
9697
timeout=datetime.timedelta(hours=1),
98+
array_metadata_store=array_metadata_store_lib.Store(),
9799
)
98100

99101
# Turn off JAX compilation cache because Pathways handles its own

pathwaysutils/persistence/orbax_handler.py

Lines changed: 96 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,18 @@
1414
"""TypeHandlers supporting Pathways backend."""
1515

1616
import collections
17-
from collections.abc import Sequence
17+
from collections.abc import Coroutine, Sequence
1818
import concurrent.futures
1919
import datetime
2020
import functools
2121
import logging
22-
import typing
22+
from typing import Any, cast
2323

2424
import jax
2525
from orbax.checkpoint import future
2626
from orbax.checkpoint import type_handlers
27+
from orbax.checkpoint._src.metadata import array_metadata as array_metadata_lib
28+
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
2729
from pathwaysutils.persistence import helper
2830

2931

@@ -33,6 +35,7 @@
3335
SaveArgs = type_handlers.SaveArgs
3436
RestoreArgs = type_handlers.RestoreArgs
3537
ArrayRestoreArgs = type_handlers.ArrayRestoreArgs
38+
ArrayMetadata = array_metadata_lib.ArrayMetadata
3639

3740

3841
def extract_parent_dir_and_name(
@@ -51,26 +54,33 @@ def __init__(
5154
self,
5255
timeout: datetime.timedelta | None = None,
5356
use_ocdbt: bool = False,
57+
array_metadata_store: array_metadata_store_lib.Store | None = None,
5458
):
5559
"""Orbax array handler for Pathways on Cloud with Persistence API.
5660
5761
Args:
5862
timeout: Duration indicating the timeout for reading and writing arrays.
5963
Default is 1 hour.
6064
use_ocdbt: allows using Tensorstore OCDBT driver.
65+
array_metadata_store: An optional store for writing and reading array
66+
metadata. Only required for saving new-style jax random keys.
6167
"""
6268
if timeout is None:
6369
timeout = datetime.timedelta(hours=1)
6470
self.timeout = timeout
6571

6672
if use_ocdbt:
6773
raise ValueError("OCDBT not supported for Pathways.")
68-
super().__init__()
74+
super().__init__(array_metadata_store=array_metadata_store)
6975

7076
async def _background_serialize(
7177
self,
7278
futures_results: Sequence[concurrent.futures.Future[None]],
79+
metadata_coroutine: Coroutine[Any, Any, None] | None = None,
7380
) -> None:
81+
if metadata_coroutine:
82+
await metadata_coroutine
83+
7484
for future_result in futures_results:
7585
future_result.result()
7686

@@ -86,21 +96,61 @@ async def serialize(
8696
values: Sequence[jax.Array],
8797
infos: Sequence[ParamInfo],
8898
args: Sequence[SaveArgs] | None = None,
89-
) -> Sequence[future.Future]:
99+
) -> list[future.Future]:
90100
"""Uses Pathways Persistence API to serialize a jax array."""
91101
type_handlers.check_input_arguments(values, infos, args)
92102

93103
if any([arg.dtype is not None for arg in args]):
94104
raise ValueError("Casting during save not supported for Pathways.")
95105

106+
array_metadatas = []
107+
any_random_key = False
108+
arrays = []
109+
for v, info, arg in zip(values, infos, args):
110+
ext_metadata = None
111+
if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key):
112+
any_random_key = True
113+
v = jax.random.key_data(v)
114+
ext_metadata = {
115+
array_metadata_lib.RANDOM_KEY_IMPL: str(jax.random.key_impl(v))
116+
}
117+
118+
array_metadatas.append(
119+
ArrayMetadata(
120+
param_name=info.name,
121+
shape=v.shape,
122+
dtype=(arg.dtype if arg is not None else v.dtype),
123+
write_shape=getattr(v, "local_shape", v.shape),
124+
chunk_shape=getattr(v, "local_shape", v.shape),
125+
use_ocdbt=False,
126+
use_zarr3=False,
127+
ext_metadata=ext_metadata,
128+
)
129+
)
130+
arrays.append(v)
131+
132+
metadata_coroutine = None
133+
if any_random_key:
134+
if self._array_metadata_store is None:
135+
raise ValueError(
136+
"Array metadata store is not set with a checkpoint that requires"
137+
f" it. Array metadata: {array_metadatas}"
138+
)
139+
140+
metadata_coroutine = self._array_metadata_store.write(
141+
checkpoint_dir=infos[0].parent_dir,
142+
array_metadatas=array_metadatas,
143+
process_index=0,
144+
)
145+
96146
self._wait_for_directory_creation_signals()
97147
locations, names = extract_parent_dir_and_name(infos)
98148
f = functools.partial(helper.write_one_array, timeout=self.timeout)
99-
futures_results = list(map(f, locations, names, values))
149+
futures_results = list(map(f, locations, names, arrays))
100150

101151
return [
102152
future.CommitFutureAwaitingContractedSignals(
103-
self._background_serialize(futures_results),
153+
self._background_serialize(futures_results, metadata_coroutine),
104154
name="cloud_pathways_array_handler",
105155
)
106156
]
@@ -109,7 +159,7 @@ async def deserialize(
109159
self,
110160
infos: Sequence[ParamInfo],
111161
args: Sequence[RestoreArgs] | None = None,
112-
) -> Sequence[jax.Array]:
162+
) -> list[jax.Array]:
113163
"""Uses Pathways Persistence API to deserialize a jax array."""
114164
if args is None:
115165
raise ValueError("Must provide ArrayRestoreArgs to restore as jax.Array.")
@@ -128,7 +178,7 @@ async def deserialize(
128178
"To restore jax.Array, provide ArrayRestoreArgs; found"
129179
f" {type(arg).__name__}"
130180
)
131-
arg = typing.cast(ArrayRestoreArgs, arg)
181+
arg = cast(ArrayRestoreArgs, arg)
132182
if arg.sharding is None and (arg.mesh is None or arg.mesh_axes is None):
133183
raise ValueError(
134184
"Sharding of jax.Array cannot be None. Provide `mesh`"
@@ -143,7 +193,7 @@ async def deserialize(
143193
else:
144194
if not isinstance(arg.sharding, jax.sharding.NamedSharding):
145195
raise ValueError("Pathways only supports jax.sharding.NamedSharding.")
146-
sharding = typing.cast(jax.sharding.NamedSharding, arg.sharding)
196+
sharding = cast(jax.sharding.NamedSharding, arg.sharding)
147197
global_meshes.append(sharding.mesh)
148198
mesh_axes.append(sharding.spec)
149199
shardings.append(sharding)
@@ -163,13 +213,30 @@ async def deserialize(
163213
]
164214
dtypes = [m.dtype if d is None else d for m, d in zip(metadatas, dtypes)]
165215

216+
array_metadatas_cache = {}
217+
if self._array_metadata_store is not None:
218+
if array_metadatas := await self._array_metadata_store.read(
219+
checkpoint_dir=infos[0].parent_dir,
220+
process_index=0,
221+
):
222+
if not isinstance(array_metadatas, list):
223+
raise ValueError(
224+
"Array metadata store returned unexpected result:"
225+
f" {array_metadatas}"
226+
)
227+
228+
array_metadatas_cache = {
229+
array_metadata.param_name: array_metadata
230+
for array_metadata in array_metadatas
231+
}
232+
166233
# Group inputs by global_mesh so that we can perform batched Array
167234
# construction for each global_mesh.
168235
inputs_by_global_mesh = collections.defaultdict(list)
169236
for i, global_mesh in enumerate(global_meshes):
170237
inputs_by_global_mesh[global_mesh].append(i)
171238

172-
results = [None] * len(infos)
239+
results = cast(list[jax.Array], [None] * len(infos))
173240

174241
for global_mesh, idxs in inputs_by_global_mesh.items():
175242
grouped_infos = [infos[idx] for idx in idxs]
@@ -188,13 +255,29 @@ async def deserialize(
188255
)
189256
# each persistence call is awaited serially.
190257
read_future.result()
191-
for idx, arr in zip(idxs, grouped_arrays):
258+
for idx, info, arr in zip(idxs, grouped_infos, grouped_arrays):
259+
if meta := array_metadatas_cache.get(info.name):
260+
assert isinstance(
261+
meta, array_metadata_lib.SerializedArrayMetadata
262+
), f"Expecting SerializedArrayMetadata but got {type(meta)}."
263+
if meta.ext_metadata:
264+
assert isinstance(meta.ext_metadata, dict), (
265+
"Expecting ext_metadata to be a dict but got"
266+
f" {type(meta.ext_metadata)}."
267+
)
268+
269+
if impl := meta.ext_metadata.get(
270+
array_metadata_lib.RANDOM_KEY_IMPL
271+
):
272+
arr = jax.random.wrap_key_data(arr, impl=impl)
192273
results[idx] = arr
193-
return results # pytype: disable=bad-return-type
274+
275+
return results
194276

195277

196278
def register_pathways_handlers(
197279
timeout: datetime.timedelta | None = None,
280+
array_metadata_store: array_metadata_store_lib.Store | None = None,
198281
):
199282
"""Function that must be called before saving or restoring with Pathways."""
200283
logger.debug(
@@ -204,6 +287,7 @@ def register_pathways_handlers(
204287
jax.Array,
205288
CloudPathwaysArrayHandler(
206289
timeout=timeout,
290+
array_metadata_store=array_metadata_store,
207291
),
208292
override=True,
209293
)

0 commit comments

Comments
 (0)