Skip to content

Commit 07b2ba2

Browse files
author
Orbax Authors
committed
Add checkpoint_manager_test.py to multiprocess tests.
PiperOrigin-RevId: 828450426
1 parent 1f5536e commit 07b2ba2

File tree

8 files changed: 69 additions (+69) and 32 deletions (-32).

.github/workflows/build.yml

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ jobs:
242242
working-directory: checkpoint
243243
strategy:
244244
matrix:
245-
python-version: ["3.10", "3.11", "3.12"]
245+
python-version: ["3.10"]
246246
jax-version: ["0.6.0"]
247247
steps:
248248
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -262,14 +262,30 @@ jobs:
262262
else
263263
pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
264264
fi
265+
pip install fsspec
265266
pip install gcsfs
266267
pip install portpicker
267268
- name: Run benchmarks
268269
env:
269270
GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
270271
run: |
271-
cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
272+
cd orbax/checkpoint/_src/testing/benchmarks
273+
failed_benchmarks=""
274+
for entry in configs/*; do
275+
if [ -f "$entry" ]; then
276+
echo "Running benchmark for $entry"
277+
if ! python run_benchmarks.py --config_file="$entry" --output_directory=$GCS_BUCKET_PATH; then
278+
echo "Benchmark $entry failed"
279+
failed_benchmarks="$failed_benchmarks $entry"
280+
fi
281+
fi
282+
done
272283
cd ../../../../..
284+
if [ -n "$failed_benchmarks" ]; then
285+
echo "The following benchmarks failed:$failed_benchmarks"
286+
exit 1
287+
fi
288+
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
273289
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
274290
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
275291
# The below step just reports the success or failure of tests as a "commit status".
@@ -278,7 +294,8 @@ jobs:
278294
env:
279295
TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }}
280296
run: |
281-
python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; pytest.main(['-c', '/dev/null'] + test_files)"
297+
python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; pytest.main(['orbax/checkpoint/checkpoint_manager_test.py'])"
298+
# python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; pytest.main(['-c', '/dev/null'] + test_files)"
282299
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py;"
283300
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
284301
# python -m pytest orbax/checkpoint/checkpoint_manager_test.py

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/array_handler_benchmark.yaml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
suite_name: "ArrayHandler Benchmark"
22

3-
mesh_config:
4-
mesh_axes: ["data", "model"]
5-
ici_parallelism: {"data": 2, "model": 2}
6-
dcn_parallelism: {"data": 2, "model": 1}
3+
mesh_configs:
4+
- mesh_axes: ["data", "model"]
5+
ici_parallelism: {"data": 2, "model": 2}
6+
dcn_parallelism: {"data": 2, "model": 1}
7+
- mesh_axes: ["data", "model"]
8+
ici_parallelism: {"data": 1, "model": 1}
9+
dcn_parallelism: {"data": 4, "model": 1}
710

811
checkpoint_config:
912
spec:
@@ -17,5 +20,4 @@ benchmarks:
1720
use_replica_parallel: [True, False]
1821
enable_replica_parallel_separate_folder: [True, False]
1922
use_metadata_store: [True, False]
20-
use_colocated_python: [True, False]
2123

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/checkpoint_manager_benchmark.yaml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
suite_name: "CheckpointManager Benchmark"
22

3-
mesh_config:
4-
mesh_axes: ["data", "model"]
5-
ici_parallelism: {"data": 2, "model": 2}
3+
mesh_configs:
4+
- mesh_axes: ["data", "model"]
5+
ici_parallelism: {"data": 2, "model": 2}
6+
- mesh_axes: ["data", "model"]
7+
ici_parallelism: {"data": 1, "model": 1}
8+
dcn_parallelism: {"data": 4, "model": 1}
69

710
checkpoint_config:
811
spec:

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/checkpoint_policy_benchmark.yaml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
suite_name: "CheckpointPolicy Benchmark"
33

44

5-
mesh_config:
6-
mesh_axes: ["data", "model"]
7-
ici_parallelism: {"data": 4, "model": 8}
8-
dcn_parallelism: {"data": 1, "model": 1}
5+
mesh_configs:
6+
- mesh_axes: ["data", "model"]
7+
ici_parallelism: {"data": 4, "model": 8}
8+
dcn_parallelism: {"data": 1, "model": 1}
9+
- mesh_axes: ["data", "model"]
10+
ici_parallelism: {"data": 1, "model": 1}
11+
dcn_parallelism: {"data": 4, "model": 1}
912

1013
checkpoint_config:
1114
spec:

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/emergency_checkpoint_manager_benchmark.yaml

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
# The name for the entire test suite run.
22
suite_name: "EmergencyCheckpointManager Benchmark"
33

4-
mesh_config:
5-
mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
6-
# ICI: Within a slice. Assuming 8 devices per slice.
7-
# DCN: Across slices.
8-
ici_parallelism: {"fsdp": 4, "tensor": 1, "data": 2}
9-
dcn_parallelism: {"data": 2} # num_slices on the axis at replica_axis_index
10-
allow_split_physical_axes: true
4+
mesh_configs:
5+
- mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
6+
# ICI: Within a slice. Assuming 8 devices per slice.
7+
# DCN: Across slices.
8+
ici_parallelism: {"fsdp": 4, "tensor": 1, "data": 2}
9+
dcn_parallelism: {"data": 2} # num_slices on the axis at replica_axis_index
10+
allow_split_physical_axes: true
11+
- mesh_axes: ["data", "model"]
12+
ici_parallelism: {"data": 1, "model": 1}
13+
dcn_parallelism: {"data": 4, "model": 1}
1114

1215
checkpoint_config:
1316
spec:

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/multihost_dispatchers_benchmark.yaml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
suite_name: "Multihost Dispatchers Benchmark"
22

3-
mesh_config:
4-
mesh_axes: ["data", "model"]
5-
ici_parallelism: {"data": 2, "model": 2}
6-
dcn_parallelism: {"data": 2}
3+
mesh_configs:
4+
- mesh_axes: ["data", "model"]
5+
ici_parallelism: {"data": 2, "model": 2}
6+
dcn_parallelism: {"data": 2}
7+
- mesh_axes: ["data", "model"]
8+
ici_parallelism: {"data": 1, "model": 1}
9+
dcn_parallelism: {"data": 4, "model": 1}
710

811
checkpoint_config:
912
spec:

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/multislice_broadcast_benchmark.yaml

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
# The name for the entire test suite run.
22
suite_name: "Multislice Broadcast Benchmark"
33

4-
mesh_config:
5-
mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
6-
# ICI: Within a slice.
7-
ici_parallelism: {"fsdp": 4, "data": 2}
8-
# DCN: Across slices.
9-
dcn_parallelism: {"data": 2}
4+
mesh_configs:
5+
- mesh_axes: ["data", "stage", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
6+
# ICI: Within a slice.
7+
ici_parallelism: {"fsdp": 4, "data": 2}
8+
# DCN: Across slices.
9+
dcn_parallelism: {"data": 2}
10+
- mesh_axes: ["data", "model"]
11+
ici_parallelism: {"data": 1, "model": 1}
12+
dcn_parallelism: {"data": 4, "model": 1}
1013

1114
checkpoint_config:
1215
spec:

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/single_replica_benchmark.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ mesh_configs:
2929
ici_parallelism: {"fsdp": 32, "tensor": 2, "data": 1}
3030
dcn_parallelism: {"data": 2} # num_slices on the axis at replica_axis_index
3131
allow_split_physical_axes: True
32+
- mesh_axes: ["data", "model"]
33+
ici_parallelism: {"data": 1, "model": 1}
34+
dcn_parallelism: {"data": 4, "model": 1}
3235

3336
checkpoint_config:
3437
spec:

0 commit comments

Comments (0)