Skip to content

Commit ca0b396

Browse files
yhtangterrykongDwarKapex
authored
Use pip-compile to help with consistent Python dependency resolution (#371)
# Summary - All Python packages, except for a few build dependencies, are now installed using **pip-tools**. - The JAX and upstream T5X/PAX containers are now built in a two-stage procedure: 1. The **'meal kit'** stage: source packages are downloaded, wheels built if necessary (for TE, tensorflow-text, lingvo, etc.), but **no** package is installed. Instead, manifest files are created in the `/opt/pip-tools.d` folder to instruct which packages shall be installed by pip-tools. The stage is named due to its similarity in how ingredients in a meal kit are prepared while deferring the final cooking step. 2. The **'final'** (cooking🔥) stage: this is when pip-tools collectively compile the manifests from the various container layers and then sync-install everything to exactly match the resolved versions. - Note that downstream containers will **build on top of the meal kit image of its base container**, thus ensuring all packages and dependencies are installed exactly once to avoid conflicts and image bloating. - The meal kit and final images are published as - mealkit: `ghcr.io/nvidia/image:mealkit` and `ghcr.io/nvidia/image:mealkit-YYYY-MM-DD` - final: `ghcr.io/nvidia/image:latest` and `ghcr.io/nvidia/image:nightly-YYYY-MM-DD` # Additional changes to the workflows - `/opt/jax-source` is renamed to `/opt/jax`. The `-source` suffix is only added to packages that needs compilation, e.g. XLA and TE. - The CI workflow is now matricized against CPU arch. - The reusable `_build_*.yaml` workflows are simplified to build only one image for a single architecture at a time. The logic for creating multi-arch images is relocated into the `_publish_container.yaml` workflows and involved during the nightly runs only. - TE is now built as a wheel and shipped in the JAX core meal kit image. - TE unit tests will be performed using the upstream-pax image due to the dependency on praxis. - Build workflows now produce sitreps following the paradigm of #229. - Removed the various one-off workflows for pinned CUDA/JAX versions. - Refactored the PAX arm64 Dockerfile in preparation for #338 # What remains to be done - [ ] Update the Rosetta container build + test process to use the upstream T5X/PAX mealkit (ghcr.io/nvidia/upstream-t5x:mealkit, ghcr.io/nvidia/upstream-pax:mealkit) containers # Reviewing tips This PR requires a multitude of reviewers due to its size and scope. I'd truly appreciate code owners to review any changes related to their previous contributions. An incomplete list of reviewer-scope is: - @terrykong, @ashors1, @sharathts, @maanug-nv: Rosetta, TE, T5X and PAX MGMN tests - @nouiz: JAX, TE and T5X build - @joker-eph: PAX arm64 build - @nluehr: Base image, NCCL, PAX - @DwarKapex: base/JAX/XLA build, workflow logic Closes #223 Closes #230 Closes #231 Closes #232 Closes #233 Closes #271 Fixes #328 Fixes #337 Co-authored-by: Terry Kong <[email protected]> --------- Co-authored-by: Terry Kong <[email protected]> Co-authored-by: Vladislav Kozlov <[email protected]>
1 parent 2aa961a commit ca0b396

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1697
-2321
lines changed

.github/container/Dockerfile.base

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
ARG BASE_IMAGE=nvidia/cuda:12.2.0-devel-ubuntu22.04
2+
ARG GIT_USER_NAME="JAX Toolbox"
3+
4+
25
FROM ${BASE_IMAGE}
6+
ARG GIT_USER_EMAIL
7+
ARG GIT_USER_NAME
38

49
###############################################################################
510
## Install Python and essential tools
@@ -17,13 +22,28 @@ RUN apt-get update && \
1722
git \
1823
lld \
1924
vim \
25+
bat \
26+
curl \
27+
git \
28+
gnupg \
29+
rsync \
2030
python-is-python3 \
2131
python3-pip \
32+
liblzma-dev \
2233
wget \
2334
&& \
2435
apt-get clean && \
2536
rm -rf /var/lib/apt/lists/*
26-
RUN pip install --upgrade --no-cache-dir pip
37+
RUN <<"EOF" bash -ex
38+
git config --global user.name "${GIT_USER_NAME}"
39+
git config --global user.email "${GIT_USER_EMAIL}"
40+
EOF
41+
RUN pip install --upgrade --no-cache-dir pip pip-tools && rm -rf ~/.cache/*
42+
RUN mkdir -p /opt/pip-tools.d
43+
ADD --chmod=777 \
44+
get-source.sh \
45+
pip-finalize.sh \
46+
/usr/local/bin/
2747

2848
###############################################################################
2949
## Install cuDNN

.github/container/Dockerfile.jax

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,36 @@
11
ARG BASE_IMAGE=ghcr.io/nvidia/jax-toolbox:base
22
ARG REPO_JAX="https://github.com/google/jax.git"
33
ARG REPO_XLA="https://github.com/openxla/xla.git"
4+
ARG REPO_FLAX="https://github.com/google/flax.git"
5+
ARG REPO_TE="https://github.com/NVIDIA/TransformerEngine.git"
46
ARG REF_JAX=main
57
ARG REF_XLA=main
6-
ARG SRC_PATH_JAX=/opt/jax-source
8+
ARG REF_FLAX=main
9+
ARG REF_TE=main
10+
ARG SRC_PATH_JAX=/opt/jax
711
ARG SRC_PATH_XLA=/opt/xla-source
12+
ARG SRC_PATH_FLAX=/opt/flax
13+
ARG SRC_PATH_TE=/opt/transformer-engine-source
14+
ARG GIT_USER_NAME="JAX Toolbox"
15+
16+
817
ARG BAZEL_CACHE=/tmp
918
ARG BUILD_DATE
1019

1120
###############################################################################
1221
## Build JAX
1322
###############################################################################
1423

15-
FROM ${BASE_IMAGE} as jax-builder
24+
FROM ${BASE_IMAGE} as builder
1625
ARG REPO_JAX
1726
ARG REPO_XLA
1827
ARG REF_JAX
1928
ARG REF_XLA
2029
ARG SRC_PATH_JAX
2130
ARG SRC_PATH_XLA
2231
ARG BAZEL_CACHE
32+
ARG GIT_USER_NAME
33+
ARG GIT_USER_EMAIL
2334

2435
RUN git clone "${REPO_JAX}" "${SRC_PATH_JAX}" && cd "${SRC_PATH_JAX}" && git checkout ${REF_JAX}
2536
RUN --mount=type=ssh \
@@ -30,8 +41,8 @@ RUN --mount=type=ssh \
3041
RUN <<EOF bash -ex
3142
cd ${SRC_PATH_XLA}
3243

33-
git config user.name "JAX Toolbox"
34-
git config user.email "[email protected]"
44+
git config --global user.name "${GIT_USER_NAME}"
45+
git config --global user.email "${GIT_USER_EMAIL}"
3546
git remote add -f ashors1 https://github.com/ashors1/xla
3647
git cherry-pick --allow-empty $(git merge-base ashors/main ashors1/revert-84222)..ashors1/revert-84222
3748
git remote remove ashors1
@@ -47,15 +58,12 @@ RUN build-jax.sh \
4758
--xla-arm64-patch /opt/xla-arm64-neon.patch \
4859
--clean
4960

50-
RUN cp -r ${SRC_PATH_JAX} ${SRC_PATH_JAX}-no-git && rm -rf ${SRC_PATH_JAX}-no-git/.git
51-
RUN cp -r ${SRC_PATH_XLA} ${SRC_PATH_XLA}-no-git && rm -rf ${SRC_PATH_XLA}-no-git/.git
52-
5361
###############################################################################
54-
## Build 'runtime' flavor without the git metadata
62+
## Pack jaxlib wheel and various source dirs into a pre-installation image
5563
###############################################################################
5664

5765
ARG BASE_IMAGE
58-
FROM ${BASE_IMAGE} as runtime-image
66+
FROM ${BASE_IMAGE} as mealkit
5967
ARG SRC_PATH_JAX
6068
ARG SRC_PATH_XLA
6169
ARG BUILD_DATE
@@ -67,29 +75,43 @@ ENV NCCL_IB_SL=1
6775
ENV NCCL_NVLS_ENABLE=0
6876
ENV CUDA_MODULE_LOADING=EAGER
6977

70-
COPY --from=jax-builder ${SRC_PATH_JAX}-no-git ${SRC_PATH_JAX}
71-
COPY --from=jax-builder ${SRC_PATH_XLA}-no-git ${SRC_PATH_XLA}
78+
COPY --from=builder ${SRC_PATH_JAX} ${SRC_PATH_JAX}
79+
COPY --from=builder ${SRC_PATH_XLA} ${SRC_PATH_XLA}
80+
ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/
7281

73-
RUN pip --disable-pip-version-check install ${SRC_PATH_JAX}/dist/*.whl && \
74-
pip --disable-pip-version-check install -e ${SRC_PATH_JAX} && \
75-
rm -rf ~/.cache/pip/
82+
RUN mkdir -p /opt/pip-tools.d
83+
RUN <<"EOF" bash -ex
84+
echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/manifest.jax
85+
echo "jaxlib @ file://$(ls ${SRC_PATH_JAX}/dist/*.whl)" >> /opt/pip-tools.d/manifest.jax
86+
EOF
7687

77-
# Install software stack in JAX ecosystem
78-
# Made this optional since tensorstore cannot build on Ubuntu 20.04 + ARM
79-
RUN { pip install flax || true; } && rm -rf ~/.cache/*
88+
## Flax
89+
ARG REPO_FLAX
90+
ARG REF_FLAX
91+
ARG SRC_PATH_FLAX
92+
RUN get-source.sh -f ${REPO_FLAX} -r ${REF_FLAX} -d ${SRC_PATH_FLAX} -m /opt/pip-tools.d/manifest.flax
93+
94+
## Transformer engine: check out source and build wheel
95+
ARG REPO_TE
96+
ARG REF_TE
97+
ARG SRC_PATH_TE
98+
ENV NVTE_FRAMEWORK=jax
99+
ENV SRC_PATH_TE=${SRC_PATH_TE}
100+
RUN <<"EOF" bash -ex
101+
set -o pipefail
102+
pip install ninja && rm -rf ~/.cache/pip
103+
get-source.sh -f ${REPO_TE} -r ${REF_TE} -d ${SRC_PATH_TE}
104+
pushd ${SRC_PATH_TE}
105+
python setup.py bdist_wheel && rm -rf build
106+
echo "transformer-engine @ file://$(ls ${SRC_PATH_TE}/dist/*.whl)" >> /opt/pip-tools.d/manifest.te
107+
EOF
80108

81109
# TODO: properly configure entrypoint
82-
# COPY entrypoint.d/ /opt/nvidia/entrypoint.d/
83110

84111
###############################################################################
85-
## Build 'devel' image with build scripts and git metadata
112+
## Install primary packages and transitive dependencies
86113
###############################################################################
87114

88-
FROM runtime-image as devel-image
89-
ARG SRC_PATH_JAX
90-
ARG SRC_PATH_XLA
91-
92-
ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/
115+
FROM mealkit as final
93116

94-
COPY --from=jax-builder ${SRC_PATH_JAX}/.git ${SRC_PATH_JAX}/.git
95-
COPY --from=jax-builder ${SRC_PATH_XLA}/.git ${SRC_PATH_XLA}/.git
117+
RUN pip-finalize.sh
Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,54 @@
11
# syntax=docker/dockerfile:1-labs
2-
###############################################################################
3-
## Pax
4-
###############################################################################
52

63
ARG BASE_IMAGE=ghcr.io/nvidia/jax:latest
7-
FROM ${BASE_IMAGE}
8-
9-
ADD install-pax.sh /usr/local/bin
10-
ADD install-flax.sh /usr/local/bin
11-
ADD install-te.sh /usr/local/bin
12-
13-
ENV NVTE_FRAMEWORK=jax
144
ARG REPO_PAXML=https://github.com/google/paxml.git
155
ARG REPO_PRAXIS=https://github.com/google/praxis.git
166
ARG REF_PAXML=main
177
ARG REF_PRAXIS=main
18-
ARG REPO_TE=https://github.com/NVIDIA/TransformerEngine.git
19-
# TODO: This is a temporary pinning of TE as the API in TE no longer matches the TE patch
20-
# This should be reverted to main ASAP
21-
ARG REF_TE=7976bd003fcf084dd068069b92a9a79b1743316a
8+
ARG SRC_PATH_PAXML=/opt/paxml
9+
ARG SRC_PATH_PRAXIS=/opt/praxis
10+
11+
###############################################################################
12+
## Download source and add auxiliary scripts
13+
###############################################################################
14+
15+
FROM ${BASE_IMAGE} as mealkit
16+
ARG REPO_PAXML
17+
ARG REPO_PRAXIS
18+
ARG REF_PAXML
19+
ARG REF_PRAXIS
20+
ARG SRC_PATH_PAXML
21+
ARG SRC_PATH_PRAXIS
22+
23+
# update TE manifest file to install the [test] extras
24+
RUN sed -i "s/transformer-engine @/transformer-engine[test] @/g" /opt/pip-tools.d/manifest.te
25+
2226
RUN <<"EOF" bash -ex
23-
install-pax.sh --defer --from_paxml ${REPO_PAXML} --from_praxis ${REPO_PRAXIS} --ref_paxml ${REF_PAXML} --ref_praxis ${REF_PRAXIS}
24-
install-flax.sh --defer
25-
install-te.sh --defer --from ${REPO_TE} --ref ${REF_TE}
26-
27-
if [[ -f /opt/requirements-defer.txt ]]; then
28-
# SKIP_HEAD_INSTALLS avoids having to install jax from Github source so that
29-
# we do not overwrite the jax that was already installed.
30-
SKIP_HEAD_INSTALLS=true pip install -r /opt/requirements-defer.txt
31-
fi
32-
if [[ -f /opt/cleanup.sh ]]; then
33-
bash -ex /opt/cleanup.sh
34-
fi
27+
get-source.sh -f ${REPO_PAXML} -r ${REF_PAXML} -d ${SRC_PATH_PAXML}
28+
get-source.sh -f ${REPO_PRAXIS} -r ${REF_PRAXIS} -d ${SRC_PATH_PRAXIS}
29+
echo "-e file://${SRC_PATH_PAXML}[gpu]" >> /opt/pip-tools.d/manifest.pax
30+
echo "-e file://${SRC_PATH_PRAXIS}" >> /opt/pip-tools.d/manifest.pax
31+
32+
for src in ${SRC_PATH_PAXML} ${SRC_PATH_PRAXIS}; do
33+
pushd ${src}
34+
sed -i "s| @ git+https://github.com/google/flax||g" requirements.in
35+
sed -i "s| @ git+https://github.com/google/jax||g" requirements.in
36+
if git diff --quiet; then
37+
echo "URL specs no longer present in select dependencies for ${src}"
38+
exit 1
39+
else
40+
git commit -a -m "remove URL specs from select dependencies for ${src}"
41+
fi
42+
popd
43+
done
3544
EOF
3645

3746
ADD test-pax.sh /usr/local/bin
47+
48+
###############################################################################
49+
## Install accumulated packages from the base image and the previous stage
50+
###############################################################################
51+
52+
FROM mealkit as final
53+
54+
RUN pip-finalize.sh

0 commit comments

Comments
 (0)