1414"""TypeHandlers supporting Pathways backend."""
1515
1616import collections
17- from collections .abc import Sequence
17+ from collections .abc import Coroutine , Sequence
1818import concurrent .futures
1919import datetime
2020import functools
2121import logging
22- import typing
22+ from typing import Any , cast
2323
2424import jax
2525from orbax .checkpoint import future
2626from orbax .checkpoint import type_handlers
27+ from orbax .checkpoint ._src .metadata import array_metadata as array_metadata_lib
28+ from orbax .checkpoint ._src .metadata import array_metadata_store as array_metadata_store_lib
2729from pathwaysutils .persistence import helper
2830
2931
3335SaveArgs = type_handlers .SaveArgs
3436RestoreArgs = type_handlers .RestoreArgs
3537ArrayRestoreArgs = type_handlers .ArrayRestoreArgs
38+ ArrayMetadata = array_metadata_lib .ArrayMetadata
3639
3740
3841def extract_parent_dir_and_name (
@@ -51,26 +54,33 @@ def __init__(
5154 self ,
5255 timeout : datetime .timedelta | None = None ,
5356 use_ocdbt : bool = False ,
57+ array_metadata_store : array_metadata_store_lib .Store | None = None ,
5458 ):
5559 """Orbax array handler for Pathways on Cloud with Persistence API.
5660
5761 Args:
5862 timeout: Duration indicating the timeout for reading and writing arrays.
5963 Default is 1 hour.
6064 use_ocdbt: allows using Tensorstore OCDBT driver.
65+ array_metadata_store: An optional store for writing and reading array
66+ metadata. Only required for saving new-style jax random keys.
6167 """
6268 if timeout is None :
6369 timeout = datetime .timedelta (hours = 1 )
6470 self .timeout = timeout
6571
6672 if use_ocdbt :
6773 raise ValueError ("OCDBT not supported for Pathways." )
68- super ().__init__ ()
74+ super ().__init__ (array_metadata_store = array_metadata_store )
6975
7076 async def _background_serialize (
7177 self ,
7278 futures_results : Sequence [concurrent .futures .Future [None ]],
79+ metadata_coroutine : Coroutine [Any , Any , None ] | None = None ,
7380 ) -> None :
81+ if metadata_coroutine :
82+ await metadata_coroutine
83+
7484 for future_result in futures_results :
7585 future_result .result ()
7686
@@ -86,21 +96,61 @@ async def serialize(
8696 values : Sequence [jax .Array ],
8797 infos : Sequence [ParamInfo ],
8898 args : Sequence [SaveArgs ] | None = None ,
89- ) -> Sequence [future .Future ]:
99+ ) -> list [future .Future ]:
90100 """Uses Pathways Persistence API to serialize a jax array."""
91101 type_handlers .check_input_arguments (values , infos , args )
92102
93103 if any ([arg .dtype is not None for arg in args ]):
94104 raise ValueError ("Casting during save not supported for Pathways." )
95105
106+ array_metadatas = []
107+ any_random_key = False
108+ arrays = []
109+ for v , info , arg in zip (values , infos , args ):
110+ ext_metadata = None
111+ if jax .dtypes .issubdtype (v .dtype , jax .dtypes .prng_key ):
112+ any_random_key = True
113+ v = jax .random .key_data (v )
114+ ext_metadata = {
115+ array_metadata_lib .RANDOM_KEY_IMPL : str (jax .random .key_impl (v ))
116+ }
117+
118+ array_metadatas .append (
119+ ArrayMetadata (
120+ param_name = info .name ,
121+ shape = v .shape ,
122+ dtype = (arg .dtype if arg is not None else v .dtype ),
123+ write_shape = getattr (v , "local_shape" , v .shape ),
124+ chunk_shape = getattr (v , "local_shape" , v .shape ),
125+ use_ocdbt = False ,
126+ use_zarr3 = False ,
127+ ext_metadata = ext_metadata ,
128+ )
129+ )
130+ arrays .append (v )
131+
132+ metadata_coroutine = None
133+ if any_random_key :
134+ if self ._array_metadata_store is None :
135+ raise ValueError (
136+ "Array metadata store is not set with a checkpoint that requires"
137+ f" it. Array metadata: { array_metadatas } "
138+ )
139+
140+ metadata_coroutine = self ._array_metadata_store .write (
141+ checkpoint_dir = infos [0 ].parent_dir ,
142+ array_metadatas = array_metadatas ,
143+ process_index = 0 ,
144+ )
145+
96146 self ._wait_for_directory_creation_signals ()
97147 locations , names = extract_parent_dir_and_name (infos )
98148 f = functools .partial (helper .write_one_array , timeout = self .timeout )
99- futures_results = list (map (f , locations , names , values ))
149+ futures_results = list (map (f , locations , names , arrays ))
100150
101151 return [
102152 future .CommitFutureAwaitingContractedSignals (
103- self ._background_serialize (futures_results ),
153+ self ._background_serialize (futures_results , metadata_coroutine ),
104154 name = "cloud_pathways_array_handler" ,
105155 )
106156 ]
@@ -109,7 +159,7 @@ async def deserialize(
109159 self ,
110160 infos : Sequence [ParamInfo ],
111161 args : Sequence [RestoreArgs ] | None = None ,
112- ) -> Sequence [jax .Array ]:
162+ ) -> list [jax .Array ]:
113163 """Uses Pathways Persistence API to deserialize a jax array."""
114164 if args is None :
115165 raise ValueError ("Must provide ArrayRestoreArgs to restore as jax.Array." )
@@ -128,7 +178,7 @@ async def deserialize(
128178 "To restore jax.Array, provide ArrayRestoreArgs; found"
129179 f" { type (arg ).__name__ } "
130180 )
131- arg = typing . cast (ArrayRestoreArgs , arg )
181+ arg = cast (ArrayRestoreArgs , arg )
132182 if arg .sharding is None and (arg .mesh is None or arg .mesh_axes is None ):
133183 raise ValueError (
134184 "Sharding of jax.Array cannot be None. Provide `mesh`"
@@ -143,7 +193,7 @@ async def deserialize(
143193 else :
144194 if not isinstance (arg .sharding , jax .sharding .NamedSharding ):
145195 raise ValueError ("Pathways only supports jax.sharding.NamedSharding." )
146- sharding = typing . cast (jax .sharding .NamedSharding , arg .sharding )
196+ sharding = cast (jax .sharding .NamedSharding , arg .sharding )
147197 global_meshes .append (sharding .mesh )
148198 mesh_axes .append (sharding .spec )
149199 shardings .append (sharding )
@@ -163,13 +213,30 @@ async def deserialize(
163213 ]
164214 dtypes = [m .dtype if d is None else d for m , d in zip (metadatas , dtypes )]
165215
216+ array_metadatas_cache = {}
217+ if self ._array_metadata_store is not None :
218+ if array_metadatas := await self ._array_metadata_store .read (
219+ checkpoint_dir = infos [0 ].parent_dir ,
220+ process_index = 0 ,
221+ ):
222+ if not isinstance (array_metadatas , list ):
223+ raise ValueError (
224+ "Array metadata store returned unexpected result:"
225+ f" { array_metadatas } "
226+ )
227+
228+ array_metadatas_cache = {
229+ array_metadata .param_name : array_metadata
230+ for array_metadata in array_metadatas
231+ }
232+
166233 # Group inputs by global_mesh so that we can perform batched Array
167234 # construction for each global_mesh.
168235 inputs_by_global_mesh = collections .defaultdict (list )
169236 for i , global_mesh in enumerate (global_meshes ):
170237 inputs_by_global_mesh [global_mesh ].append (i )
171238
172- results = [ None ] * len (infos )
239+ results = cast ( list [ jax . Array ], [ None ] * len (infos ) )
173240
174241 for global_mesh , idxs in inputs_by_global_mesh .items ():
175242 grouped_infos = [infos [idx ] for idx in idxs ]
@@ -188,13 +255,29 @@ async def deserialize(
188255 )
189256 # each persistence call is awaited serially.
190257 read_future .result ()
191- for idx , arr in zip (idxs , grouped_arrays ):
258+ for idx , info , arr in zip (idxs , grouped_infos , grouped_arrays ):
259+ if meta := array_metadatas_cache .get (info .name ):
260+ assert isinstance (
261+ meta , array_metadata_lib .SerializedArrayMetadata
262+ ), f"Expecting SerializedArrayMetadata but got { type (meta )} ."
263+ if meta .ext_metadata :
264+ assert isinstance (meta .ext_metadata , dict ), (
265+ "Expecting ext_metadata to be a dict but got"
266+ f" { type (meta .ext_metadata )} ."
267+ )
268+
269+ if impl := meta .ext_metadata .get (
270+ array_metadata_lib .RANDOM_KEY_IMPL
271+ ):
272+ arr = jax .random .wrap_key_data (arr , impl = impl )
192273 results [idx ] = arr
193- return results # pytype: disable=bad-return-type
274+
275+ return results
194276
195277
196278def register_pathways_handlers (
197279 timeout : datetime .timedelta | None = None ,
280+ array_metadata_store : array_metadata_store_lib .Store | None = None ,
198281):
199282 """Function that must be called before saving or restoring with Pathways."""
200283 logger .debug (
@@ -204,6 +287,7 @@ def register_pathways_handlers(
204287 jax .Array ,
205288 CloudPathwaysArrayHandler (
206289 timeout = timeout ,
290+ array_metadata_store = array_metadata_store ,
207291 ),
208292 override = True ,
209293 )
0 commit comments