Skip to content

Commit 5b0c3d2

Browse files
jrplatinJacob Platin
andauthored
[Docker] Add V7X requirements and update Docker to accept option to build using it (#916)
Signed-off-by: Jacob Platin <[email protected]> Co-authored-by: Jacob Platin <[email protected]>
1 parent 163cd94 commit 5b0c3d2

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

docker/Dockerfile

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:night
33

44
FROM $BASE_IMAGE
55

6+
ARG IS_FOR_V7X="false"
7+
68
# Remove existing versions of dependencies
79
RUN pip uninstall -y torch torch_xla torchvision
810

@@ -39,4 +41,12 @@ RUN pip install -r requirements_benchmarking.txt
3941
COPY . .
4042
RUN pip install -e .
4143

44+
# TODO (jacobplatin): remove when v7x is supported in JAX/Libtpu officially
45+
# NOTE: it's important that this is done after installing tpu_inference above,
46+
# so that the v7x-specific dependencies can override any existing ones.
47+
COPY requirements_v7x.txt .
48+
RUN if [ "$IS_FOR_V7X" = "true" ]; then \
49+
pip install -r requirements_v7x.txt; \
50+
fi
51+
4252
CMD ["/bin/bash"]

requirements_v7x.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# This file contains additional dependencies needed for TPU v7x support.
2+
# It is expected to be used in conjunction with the main requirements.txt file.
3+
--pre
4+
-i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
5+
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
6+
jax==0.8.0.dev20251013
7+
jaxlib==0.8.0.dev20251013
8+
libtpu==0.0.25.dev20251012+nightly

0 commit comments

Comments
 (0)