Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions docker_build_dependency_image.sh
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then
echo "Building with benchmark-db"
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_db_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
elif [[ ${INSTALL_GRPO} -eq 1 && ${DEVICE} == "tpu" ]]; then
echo "Installing MaxText stable mode dependencies for GRPO"
docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
echo "Installing MaxText stable mode dependencies for GRPO BASEIMAGE=$BASEIMAGE"
docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
else
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
fi
Expand All @@ -142,17 +142,15 @@ if [[ ${INSTALL_GRPO} -eq 1 ]] ; then
exit 1
fi

# # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__.
# # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext).
# rsync -a --exclude='__pycache__' ../tpu_commons .
# # To install vllm from a local path, we copy it into the build context, excluding __pycache__.
# # This assumes vllm is a sibling directory to the current one (maxtext).
# rsync -a --exclude='__pycache__' ../vllm .
# To install from local paths, we copy tunix, tpu-inference, and vllm into the build context.
# This assumes tunix, tpu-inference, and vllm are sibling directories to the current one (maxtext).
echo "Copying local vllm and tpu-inference directories into the build context..."
rsync -a --exclude='__pycache__' ../tunix .
rsync -a --exclude='__pycache__' ../tpu-inference .
rsync -a --exclude='__pycache__' ../vllm .

# rsync -a --exclude='__pycache__' ../tunix .

# # The cleanup is set to run even if the build fails to remove the copied directory.
# trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM
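# Remove the copied directories on exit, even if the build fails.
# NOTE(review): ./tunix is also copied into the build context above but is not removed by this trap - confirm whether that is intended.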
trap "echo 'Cleaning up copied directories...' && rm -rf ./tpu-inference ./vllm" EXIT INT TERM

docker build \
--network host \
Expand Down
101 changes: 69 additions & 32 deletions maxtext_grpo_dependencies.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,44 +15,81 @@
ARG BASEIMAGE
FROM ${BASEIMAGE}
ARG MODE

ENV MODE=$MODE

RUN echo "Installing GRPO dependencies (vLLM, tpu-common, tunix) with MODE=${MODE}"

RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}"

# Uninstall existing jax to avoid conflicts
RUN pip uninstall -y jax jaxlib libtpu

RUN pip install aiohttp==3.12.15

# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically.
RUN pip install keyring keyrings.google-artifactregistry-auth

RUN pip install numba==0.61.2

# Install vLLM for Jax and TPUs from the artifact registry
RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
--index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
--extra-index-url https://pypi.org/simple/ \
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu

# Install tpu-commons from the artifact registry
RUN pip install --no-cache-dir --pre \
--index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
--extra-index-url https://pypi.org/simple/ \
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
tpu-commons==0.1.2
# RUN pip uninstall -y jax jaxlib libtpu

# --- STAGE 1: Install Static Dependencies ---
# Install any packages *not* defined in your project dependency files
RUN --mount=type=cache,target=/root/.cache/pip pip install \
aiohttp==3.12.15\
keyring \
keyrings.google-artifactregistry-auth

RUN --mount=type=cache,target=/root/.cache/pip pip install \
numba==0.61.2

# RUN VLLM_TARGET_DEVICE="tpu" pip install vllm
# --- STAGE 2: Install Project Dependencies (The Main Cached Layer) ---

# Copy *only* the dependency definition files.
# This assumes vllm and tpu-inference are in the build context, copied from the parent directory.
COPY vllm/requirements/tpu.txt /tmp/
COPY vllm/requirements/build.txt /tmp/
COPY vllm/requirements/common.txt /tmp/
COPY tpu-inference/requirements.txt /tmp/

# Run the full dependency installation.
# This entire layer is cached and will *only* be rebuilt if
# these .txt files change.
RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
# Set the target device so pip installs the right JAX/libtpu
# Install the vllm (tpu, build, common) and tpu-inference requirements in a single pip invocation
export VLLM_TARGET_DEVICE="tpu" && \
pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \
--extra-index-url https://pypi.org/simple/ \
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

# Install tpu-inference dependencies
RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \
pip install -r /tmp/requirements.txt --no-cache-dir --pre \
--extra-index-url https://pypi.org/simple/ \
--extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
--find-links https://storage.googleapis.com/libtpu-wheels/index.html \
--find-links https://storage.googleapis.com/libtpu-releases/index.html \
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

# --- STAGE 3: Install Project Source Code ---

# Now, copy the full source code. This invalidates cache frequently,
# but the next step is fast.
COPY vllm /vllm/
COPY tpu-inference /tpu-inference/
COPY tunix /tunix


# Install in editable mode. This is lightning-fast because all
# dependencies were installed and cached in STAGE 2.
RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/
RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/

RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/
# RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/

RUN if [ "$MODE" = "grpo-experimental" ]; then \
echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \
pip uninstall -y jax jaxlib libtpu && \
pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \
pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \
Expand Down
13 changes: 7 additions & 6 deletions src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,20 +142,20 @@


# ====== Input Checkpoint directory =====
MODEL_CHECKPOINT_PATH = "/path/to/scanned/model/ckpt_load_dir/"
MODEL_CHECKPOINT_PATH = "gs://maxtext-model-checkpoints/llama3.1_70b_instruct/2025-10-15/pathways/scanned/0/items"

# ====== Log directory =====
LOG_DIR = f"{HOME}/content/tensorboard/grpo/logs_llama3/"
if not os.path.exists(LOG_DIR):
os.makedirs(LOG_DIR)

# ===== Profiling =====
PROFILE_DIR = f"/path/to/profile_dir/{run_id}/profiles_llama3/"
PROFILE_DIR = f"gs://mazumdera-test-bucket-europe-west4/rl-tuning/grpo/anisha-{run_id}/profiles_llama3/"
if not epath.Path(PROFILE_DIR).exists():
epath.Path(PROFILE_DIR).mkdir(parents=True)

# ====== Checkpoint saving ======
CKPT_DIR = f"/path/to/ckpt_save_dir/{run_id}/ckpts_llama3/"
CKPT_DIR = f"gs://mazumdera-test-bucket-europe-west4/rl-tuning/grpo/anisha-{run_id}/ckpts_llama3/"

if not epath.Path(CKPT_DIR).exists():
epath.Path(CKPT_DIR).mkdir(parents=True)
Expand Down Expand Up @@ -195,11 +195,12 @@
# ====== Training ======
BATCH_SIZE = 1
# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
# NUM_BATCHES = 3738
NUM_BATCHES = 4 # 200
NUM_BATCHES = 3738
# NUM_BATCHES = 4 # 200
# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
# increased to a max. of 330 (if batch size is 4).
NUM_TEST_BATCHES = 5 # 200
NUM_TEST_BATCHES = 330
# NUM_TEST_BATCHES = 5 # 200

EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
NUM_EPOCHS = 1 # can potentially train for more epochs
Expand Down
13 changes: 7 additions & 6 deletions src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,20 +142,20 @@


# ====== Input Checkpoint directory =====
MODEL_CHECKPOINT_PATH = "/path/to/scanned/model/ckpt_load_dir/"
MODEL_CHECKPOINT_PATH = "gs://mazumdera-test-bucket-europe-west4/llama3.1-8b-Instruct/scanned-pathways/0/items"

# ====== Log directory =====
LOG_DIR = f"{HOME}/content/tensorboard/grpo/logs_llama3/"
if not os.path.exists(LOG_DIR):
os.makedirs(LOG_DIR)

# ===== Profiling =====
PROFILE_DIR = f"/path/to/profile_dir/{run_id}/profiles_llama3/"
PROFILE_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/v5p-64/llama3-1-8b/profile_dir/{run_id}/profiles_llama3/"
if not epath.Path(PROFILE_DIR).exists():
epath.Path(PROFILE_DIR).mkdir(parents=True)

# ====== Checkpoint saving ======
CKPT_DIR = f"/path/to/ckpt_save_dir/{run_id}/ckpts_llama3/"
CKPT_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/v5p-64/llama3-1-8b/ckpt_save_dir/{run_id}/ckpts_llama3/"

if not epath.Path(CKPT_DIR).exists():
epath.Path(CKPT_DIR).mkdir(parents=True)
Expand Down Expand Up @@ -195,11 +195,12 @@
# ====== Training ======
BATCH_SIZE = 1
# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.
# NUM_BATCHES = 3738
NUM_BATCHES = 4 # 200
NUM_BATCHES = 3738
# NUM_BATCHES = 4 # 200
# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
# increased to a max. of 330 (if batch size is 4).
NUM_TEST_BATCHES = 5 # 200
NUM_TEST_BATCHES = 330
# NUM_TEST_BATCHES = 5 # 200

EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
NUM_EPOCHS = 1 # can potentially train for more epochs
Expand Down
Loading