Skip to content

Commit 7caacdd

Browse files
author
Orbax Authors
committed
Add checkpoint_manager_test.py to multiprocess tests.
PiperOrigin-RevId: 828450426
1 parent faa0b39 commit 7caacdd

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

.github/workflows/build.yml

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ jobs:
242242
working-directory: checkpoint
243243
strategy:
244244
matrix:
245-
python-version: ["3.10", "3.11", "3.12"]
245+
python-version: ["3.10"]
246246
jax-version: ["0.6.0"]
247247
steps:
248248
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -262,14 +262,22 @@ jobs:
262262
else
263263
pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
264264
fi
265+
pip install fsspec
265266
pip install gcsfs
266267
pip install portpicker
267268
- name: Run benchmarks
268269
env:
269270
GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
270271
run: |
271-
cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
272+
cd orbax/checkpoint/_src/testing/benchmarks
273+
for entry in configs/*; do
274+
if [ -f "$entry" ]; then
275+
echo "Running benchmark for $entry"
276+
python run_benchmarks.py --config_file="$entry" --output_directory=$GCS_BUCKET_PATH
277+
fi
278+
done
272279
cd ../../../../..
280+
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
273281
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
274282
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
275283
# The below step just reports the success or failure of tests as a "commit status".
@@ -278,7 +286,8 @@ jobs:
278286
env:
279287
TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }}
280288
run: |
281-
python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; pytest.main(['-c', '/dev/null'] + test_files)"
289+
python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; pytest.main(['orbax/checkpoint/checkpoint_manager_test.py'])"
290+
# python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; pytest.main(['-c', '/dev/null'] + test_files)"
282291
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py;"
283292
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
284293
# python -m pytest orbax/checkpoint/checkpoint_manager_test.py

checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/array_handler_benchmark.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ suite_name: "ArrayHandler Benchmark"
33
mesh_config:
44
mesh_axes: ["data", "model"]
55
ici_parallelism: {"data": 2, "model": 2}
6-
dcn_parallelism: {"data": 2, "model": 1}
6+
dcn_parallelism: {"data": 4, "model": 1}
77

88
checkpoint_config:
99
spec:

0 commit comments

Comments
 (0)