File tree Expand file tree Collapse file tree 2 files changed +18
-0
lines changed Expand file tree Collapse file tree 2 files changed +18
-0
lines changed Original file line number Diff line number Diff line change @@ -3,6 +3,8 @@ ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:night
33
44FROM $BASE_IMAGE
55
6+ ARG IS_FOR_V7X="false"
7+
68# Remove existing versions of dependencies
79RUN pip uninstall -y torch torch_xla torchvision
810
@@ -39,4 +41,12 @@ RUN pip install -r requirements_benchmarking.txt
3941COPY . .
4042RUN pip install -e .
4143
44+ # TODO (jacobplatin): remove when v7x is supported in JAX/Libtpu officially
45+ # NOTE: it's important that this is done after installing tpu_inference above,
46+ # so that the v7x-specific dependencies can override any existing ones.
47+ COPY requirements_v7x.txt .
48+ RUN if [ "$IS_FOR_V7X" = "true" ]; then \
49+ pip install -r requirements_v7x.txt; \
50+ fi
51+
4252CMD ["/bin/bash" ]
Original file line number Diff line number Diff line change 1+ # This file contains additional dependencies needed for TPU v7x support.
2+ # It is expected to be used in conjunction with the main requirements.txt file.
3+ --pre
4+ -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
5+ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
6+ jax==0.8.0.dev20251013
7+ jaxlib==0.8.0.dev20251013
8+ libtpu==0.0.25.dev20251012+nightly
You can’t perform that action at this time.
0 commit comments