Skip to content

Commit 91d1fb7

Browse files
author
Orbax Authors
committed
Add a GitHub workflow for multiprocess checkpoint benchmarks and tests.
PiperOrigin-RevId: 825391056
1 parent 2e64309 commit 91d1fb7

File tree

5 files changed

+4171
-5
lines changed

5 files changed

+4171
-5
lines changed

.github/workflows/build.yml

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ jobs:
6161
- name: Test with pytest
6262
# TODO(yaning): Move these to an exclude target within pytest.ini.
6363
run: |
64-
python -m pytest --import-mode=importlib --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py
64+
python -m pytest --import-mode=importlib --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py --ignore=orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py --ignore=orbax/checkpoint/checkpoint_manager_test.py
6565
# The below step just reports the success or failure of tests as a "commit status".
6666
# This is needed for copybara integration.
6767
- name: Report success or failure as github status
@@ -231,3 +231,69 @@ jobs:
231231
"description": "'$status'",
232232
"context": "github-actions/build"
233233
}'
234+
multiprocess-checkpoint-benchmarks:
235+
name: "multiprocess-checkpoint-benchmarks (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
236+
runs-on: linux-g2-16-l4-1gpu-x4
237+
# runs-on: linux-x86-ct5lp-4tpu-x4
238+
container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e
239+
defaults:
240+
run:
241+
working-directory: checkpoint
242+
strategy:
243+
matrix:
244+
python-version: ["3.10", "3.11", "3.12"]
245+
jax-version: ["0.6.0"]
246+
steps:
247+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
248+
- name: Set up Python ${{ matrix.python-version }}
249+
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
250+
with:
251+
python-version: ${{ matrix.python-version }}
252+
- name: Install dependencies
253+
run: |
254+
pip install -e .
255+
pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
256+
pip uninstall -y orbax
257+
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
258+
pip install -U jax[k8s,cuda12] jaxlib
259+
elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
260+
pip install -U --pre jax[k8s,cuda12] jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
261+
else
262+
pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
263+
fi
264+
pip install gcsfs
265+
pip install portpicker
266+
- name: Run benchmarks
267+
env:
268+
GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
269+
run: |
270+
cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
271+
cd ../../../../..
272+
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
273+
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
274+
# The below step just reports the success or failure of tests as a "commit status".
275+
# This is needed for copybara integration.
276+
- name: Run multiprocess tests
277+
env:
278+
TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }}
279+
run: |
280+
python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; pytest.main(['orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py'])"
281+
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py;"
282+
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
283+
# python -m pytest orbax/checkpoint/checkpoint_manager_test.py
284+
- name: Report success or failure as github status
285+
if: always()
286+
shell: bash
287+
run: |
288+
status="${{ job.status }}"
289+
lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
290+
curl -sS --request POST \
291+
--url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
292+
--header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
293+
--header 'content-type: application/json' \
294+
--data '{
295+
"state": "'$lowercase_status'",
296+
"target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
297+
"description": "'$status'",
298+
"context": "github-actions/build"
299+
}'
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright 2025 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for ArrayCheckpointHandler."""
16+
17+
from absl import flags
18+
from absl.testing import parameterized
19+
from etils import epath
20+
import jax
21+
import numpy as np
22+
from orbax.checkpoint import test_utils
23+
from orbax.checkpoint._src.handlers import array_checkpoint_handler
24+
from orbax.checkpoint._src.multihost import multihost
25+
from orbax.checkpoint._src.serialization import type_handlers
26+
from orbax.checkpoint._src.testing import multiprocess_test
27+
28+
29+
SaveArgs = type_handlers.SaveArgs
30+
ArraySaveArgs = array_checkpoint_handler.ArraySaveArgs
31+
ArrayRestoreArgs = array_checkpoint_handler.ArrayRestoreArgs
32+
33+
34+
FLAGS = flags.FLAGS
35+
36+
37+
class ArrayCheckpointHandler(array_checkpoint_handler.ArrayCheckpointHandler):
38+
39+
def save(self, directory, *args, **kwargs):
40+
super().save(directory, *args, **kwargs)
41+
test_utils.sync_global_processes('ArrayCheckpointHandler:save')
42+
if multihost.process_index() == 0:
43+
self.finalize(directory)
44+
test_utils.sync_global_processes('ArrayCheckpointHandler:finalize')
45+
46+
47+
class ArrayCheckpointHandlerTest(
48+
parameterized.TestCase, multiprocess_test.MultiProcessTest
49+
):
50+
51+
def setUp(self):
52+
super().setUp()
53+
self.devices = np.asarray(jax.devices())
54+
self.directory = epath.Path(
55+
self.create_tempdir(name='checkpointing_test').full_path
56+
)
57+
58+
test_utils.sync_global_processes(
59+
'ArrayCheckpointHandlerTest:setup_complete'
60+
)
61+
62+
def tearDown(self):
63+
test_utils.sync_global_processes(
64+
'ArrayCheckpointHandlerTest:tests_complete'
65+
)
66+
super().tearDown()
67+
68+
def validate_save(self):
69+
path = self.directory / array_checkpoint_handler.PYTREE_METADATA_FILE
70+
self.assertTrue(path.exists())
71+
72+
def test_array(self):
73+
checkpoint_handler = ArrayCheckpointHandler()
74+
mesh = jax.sharding.Mesh(self.devices, ('x',))
75+
mesh_axes = jax.sharding.PartitionSpec(
76+
'x',
77+
)
78+
arr = test_utils.create_sharded_array(np.arange(16), mesh, mesh_axes)
79+
save_args = SaveArgs()
80+
checkpoint_handler.save(self.directory, args=ArraySaveArgs(arr, save_args))
81+
self.validate_save()
82+
restored = checkpoint_handler.restore(
83+
self.directory,
84+
args=ArrayRestoreArgs(
85+
restore_args=type_handlers.ArrayRestoreArgs(
86+
restore_type=jax.Array, mesh=mesh, mesh_axes=mesh_axes
87+
)
88+
),
89+
)
90+
test_utils.assert_tree_equal(self, [arr], [restored])
91+
checkpoint_handler.close()
92+
93+
def test_numpy_array(self):
94+
checkpoint_handler = ArrayCheckpointHandler()
95+
arr = np.arange(16)
96+
save_args = SaveArgs()
97+
checkpoint_handler.save(self.directory, args=ArraySaveArgs(arr, save_args))
98+
self.validate_save()
99+
restored = checkpoint_handler.restore(
100+
self.directory,
101+
args=ArrayRestoreArgs(
102+
restore_args=type_handlers.RestoreArgs(restore_type=np.ndarray)
103+
),
104+
)
105+
test_utils.assert_tree_equal(self, [arr], [restored])
106+
checkpoint_handler.close()
107+
108+
def test_scalar(self):
109+
checkpoint_handler = ArrayCheckpointHandler()
110+
save_args = SaveArgs()
111+
checkpoint_handler.save(self.directory, args=ArraySaveArgs(5, save_args))
112+
self.validate_save()
113+
restored = checkpoint_handler.restore(
114+
self.directory,
115+
args=ArrayRestoreArgs(
116+
restore_args=type_handlers.RestoreArgs(restore_type=int)
117+
),
118+
)
119+
self.assertEqual(5, restored)
120+
checkpoint_handler.close()
121+
122+
def test_invalid_type(self):
123+
checkpoint_handler = ArrayCheckpointHandler()
124+
with self.assertRaises(TypeError):
125+
checkpoint_handler.save(self.directory, args=ArraySaveArgs('hi'))
126+
checkpoint_handler.close()
127+
128+
def test_different_name(self):
129+
checkpoint_name = 'my_array'
130+
checkpoint_handler = ArrayCheckpointHandler(checkpoint_name=checkpoint_name)
131+
arr = np.arange(16)
132+
save_args = SaveArgs()
133+
checkpoint_handler.save(self.directory, args=ArraySaveArgs(arr, save_args))
134+
self.validate_save()
135+
restored = checkpoint_handler.restore(
136+
self.directory,
137+
args=ArrayRestoreArgs(
138+
restore_args=type_handlers.RestoreArgs(restore_type=np.ndarray)
139+
),
140+
)
141+
test_utils.assert_tree_equal(self, [arr], [restored])
142+
checkpoint_handler.close()
143+
144+
def test_restore_type(self):
145+
pytree = 5
146+
checkpoint_handler = ArrayCheckpointHandler()
147+
148+
checkpoint_handler.save(self.directory, args=ArraySaveArgs(pytree))
149+
restored = checkpoint_handler.restore(
150+
self.directory,
151+
args=ArrayRestoreArgs(
152+
restore_args=type_handlers.RestoreArgs(restore_type=np.ndarray)
153+
),
154+
)
155+
self.assertIsInstance(restored, np.ndarray)
156+
157+
158+
if __name__ == '__main__':
159+
multiprocess_test.main()

checkpoint/orbax/checkpoint/_src/multihost/multihost.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@ def coordination_timeout() -> int:
5858

5959

6060

61+
def is_pathways_backend() -> bool:
62+
# Pathways is single-host.
63+
return (
64+
hasattr(jax.devices()[0].client, 'pathways')
65+
or jax.devices()[0].client.runtime_type == 'pathways'
66+
or jax.devices()[0].client.runtime_type == 'proxy/pathways'
67+
)
68+
69+
6170
def is_runtime_to_distributed_ids_initialized() -> bool:
6271
return _RUNTIME_TO_DISTRIBUTED_ID is not None
6372

checkpoint/orbax/checkpoint/_src/testing/multiprocess_test.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,6 @@ def setUp(self):
304304
jax.process_count() == 1
305305
), "Expected 1 process for Pathways backend."
306306
else:
307-
assert jax.process_count() == NUM_PROCESSES.value, (
308-
jax.process_count(),
309-
NUM_PROCESSES.value,
310-
)
311307
# Make sure all processes are at the same test case.
312308
client = multihost.get_jax_distributed_client()
313309
# Note that the name of this barrier is long and complicated, to prevent

0 commit comments

Comments
 (0)