@@ -242,7 +242,7 @@ jobs:
242242 working-directory : checkpoint
243243 strategy :
244244 matrix :
245- python-version : ["3.10", "3.11", "3.12" ]
245+ python-version : ["3.10"]
246246 jax-version : ["0.6.0"]
247247 steps :
248248 - uses : actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -262,14 +262,22 @@ jobs:
262262 else
263263 pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
264264 fi
265+ pip install fsspec
265266 pip install gcsfs
266267 pip install portpicker
267268 - name : Run benchmarks
268269 env :
269270 GCS_BUCKET_PATH : gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
270271 run : |
271- cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
272+ cd orbax/checkpoint/_src/testing/benchmarks
273+ for entry in configs/*; do
274+ if [ -f "$entry" ]; then
275+ echo "Running benchmark for $entry"
276+ python run_benchmarks.py --config_file="$entry" --output_directory=$GCS_BUCKET_PATH
277+ fi
278+ done
272279 cd ../../../../..
280+ # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
273281 # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
274282 # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
275283 # The below step just reports the success or failure of tests as a "commit status".
@@ -278,7 +286,8 @@ jobs:
278286 env :
279287 TEST_TMPDIR : gs://orbax-benchmarks/unit-tests/${{ github.run_id }}
280288 run : |
281- python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; pytest.main(['-c', '/dev/null'] + test_files)"
289+ python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; pytest.main(['orbax/checkpoint/checkpoint_manager_test.py'])"
290+ # python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; pytest.main(['-c', '/dev/null'] + test_files)"
282291 # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py;"
283292 # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
284293 # python -m pytest orbax/checkpoint/checkpoint_manager_test.py
0 commit comments