2222import numpy as np
2323from orbax .checkpoint import args as args_lib
2424from orbax .checkpoint import checkpoint_manager
25+ from orbax .checkpoint import multihost
2526from orbax .checkpoint import utils
2627from orbax .checkpoint ._src .testing .benchmarks .core import core as benchmarks_core
2728from orbax .checkpoint ._src .testing .benchmarks .core import metric as metric_lib
@@ -55,6 +56,7 @@ def test_fn(
5556 options = context .options
5657 assert isinstance (options , CheckpointManagerBenchmarkOptions )
5758
59+
5860 cm_options = checkpoint_manager .CheckpointManagerOptions (
5961 save_interval_steps = options .save_interval_steps ,
6062 max_to_keep = options .max_to_keep ,
@@ -66,13 +68,23 @@ def test_fn(
6668 json_data = {'a' : 1 , 'b' : 'test' }
6769 random_key = jax .random .key (0 )
6870 np_random_key = np .random .get_state ()
71+ pytree_for_restore = self ._get_pytree_for_restore (pytree )
6972
70- composite_args = args_lib .Composite (
71- pytree = args_lib .StandardSave (pytree ),
72- json_item = args_lib .JsonSave (json_data ),
73- jax_random_key = args_lib .JaxRandomKeySave (random_key ),
74- np_random_key = args_lib .NumpyRandomKeySave (np_random_key ),
75- )
73+ save_kwargs = {
74+ 'pytree' : args_lib .StandardSave (pytree ),
75+ 'json_item' : args_lib .JsonSave (json_data ),
76+ 'np_random_key' : args_lib .NumpyRandomKeySave (np_random_key ),
77+ }
78+ restore_kwargs = {
79+ 'pytree' : args_lib .StandardRestore (pytree_for_restore ),
80+ 'json_item' : args_lib .JsonRestore (),
81+ 'np_random_key' : args_lib .NumpyRandomKeyRestore (),
82+ }
83+ if not multihost .is_pathways_backend ():
84+ save_kwargs ['jax_random_key' ] = args_lib .JaxRandomKeySave (random_key )
85+ restore_kwargs ['jax_random_key' ] = args_lib .JaxRandomKeyRestore ()
86+ composite_args = args_lib .Composite (** save_kwargs )
87+ restore_args = args_lib .Composite (** restore_kwargs )
7688
7789 step_saved = - 1
7890 for step in range (options .train_steps ):
@@ -87,14 +99,6 @@ def test_fn(
8799 if step_saved == - 1 :
88100 raise AssertionError ('No checkpoint was saved.' )
89101
90- pytree_for_restore = self ._get_pytree_for_restore (pytree )
91-
92- restore_args = args_lib .Composite (
93- pytree = args_lib .StandardRestore (pytree_for_restore ),
94- json_item = args_lib .JsonRestore (),
95- jax_random_key = args_lib .JaxRandomKeyRestore (),
96- np_random_key = args_lib .NumpyRandomKeyRestore (),
97- )
98102 latest_step = mngr .latest_step ()
99103 assert latest_step == step_saved , (
100104 f'Expected latest step to be { step_saved } , got { latest_step } '
@@ -110,9 +114,10 @@ def test_fn(
110114 assert (
111115 json_data == restored ['json_item' ]
112116 ), f"Expected { json_data } , got { restored ['json_item' ]} "
113- assert jax .numpy .array_equal (
114- random_key , restored ['jax_random_key' ]
115- ), f"Expected { random_key } , got { restored ['jax_random_key' ]} "
117+ if not multihost .is_pathways_backend ():
118+ assert jax .numpy .array_equal (
119+ random_key , restored ['jax_random_key' ]
120+ ), f"Expected { random_key } , got { restored ['jax_random_key' ]} "
116121 jax .tree .map (
117122 np .testing .assert_equal , np_random_key , restored ['np_random_key' ]
118123 )
0 commit comments