Skip to content

Commit bad624c

Browse files
author
Orbax Authors
committed
Add a presubmit benchmark for CheckpointManager using Pathways.
PiperOrigin-RevId: 825533145
1 parent 6e5d352 commit bad624c

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

checkpoint/orbax/checkpoint/_src/testing/benchmarks/checkpoint_manager_benchmark.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import numpy as np
2323
from orbax.checkpoint import args as args_lib
2424
from orbax.checkpoint import checkpoint_manager
25+
from orbax.checkpoint import multihost
2526
from orbax.checkpoint import utils
2627
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
2728
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib
@@ -55,6 +56,7 @@ def test_fn(
5556
options = context.options
5657
assert isinstance(options, CheckpointManagerBenchmarkOptions)
5758

59+
5860
cm_options = checkpoint_manager.CheckpointManagerOptions(
5961
save_interval_steps=options.save_interval_steps,
6062
max_to_keep=options.max_to_keep,
@@ -66,13 +68,23 @@ def test_fn(
6668
json_data = {'a': 1, 'b': 'test'}
6769
random_key = jax.random.key(0)
6870
np_random_key = np.random.get_state()
71+
pytree_for_restore = self._get_pytree_for_restore(pytree)
6972

70-
composite_args = args_lib.Composite(
71-
pytree=args_lib.StandardSave(pytree),
72-
json_item=args_lib.JsonSave(json_data),
73-
jax_random_key=args_lib.JaxRandomKeySave(random_key),
74-
np_random_key=args_lib.NumpyRandomKeySave(np_random_key),
75-
)
73+
save_kwargs = {
74+
'pytree': args_lib.StandardSave(pytree),
75+
'json_item': args_lib.JsonSave(json_data),
76+
'np_random_key': args_lib.NumpyRandomKeySave(np_random_key),
77+
}
78+
restore_kwargs = {
79+
'pytree': args_lib.StandardRestore(pytree_for_restore),
80+
'json_item': args_lib.JsonRestore(),
81+
'np_random_key': args_lib.NumpyRandomKeyRestore(),
82+
}
83+
if not multihost.is_pathways_backend():
84+
save_kwargs['jax_random_key'] = args_lib.JaxRandomKeySave(random_key)
85+
restore_kwargs['jax_random_key'] = args_lib.JaxRandomKeyRestore()
86+
composite_args = args_lib.Composite(**save_kwargs)
87+
restore_args = args_lib.Composite(**restore_kwargs)
7688

7789
step_saved = -1
7890
for step in range(options.train_steps):
@@ -87,14 +99,6 @@ def test_fn(
8799
if step_saved == -1:
88100
raise AssertionError('No checkpoint was saved.')
89101

90-
pytree_for_restore = self._get_pytree_for_restore(pytree)
91-
92-
restore_args = args_lib.Composite(
93-
pytree=args_lib.StandardRestore(pytree_for_restore),
94-
json_item=args_lib.JsonRestore(),
95-
jax_random_key=args_lib.JaxRandomKeyRestore(),
96-
np_random_key=args_lib.NumpyRandomKeyRestore(),
97-
)
98102
latest_step = mngr.latest_step()
99103
assert latest_step == step_saved, (
100104
f'Expected latest step to be {step_saved}, got {latest_step}'
@@ -110,9 +114,10 @@ def test_fn(
110114
assert (
111115
json_data == restored['json_item']
112116
), f"Expected {json_data}, got {restored['json_item']}"
113-
assert jax.numpy.array_equal(
114-
random_key, restored['jax_random_key']
115-
), f"Expected {random_key}, got {restored['jax_random_key']}"
117+
if not multihost.is_pathways_backend():
118+
assert jax.numpy.array_equal(
119+
random_key, restored['jax_random_key']
120+
), f"Expected {random_key}, got {restored['jax_random_key']}"
116121
jax.tree.map(
117122
np.testing.assert_equal, np_random_key, restored['np_random_key']
118123
)

0 commit comments

Comments
 (0)