From 8cc0f15e63f3ce3841d2ac421003fa6317225538 Mon Sep 17 00:00:00 2001 From: A9isha Date: Mon, 20 Oct 2025 17:29:49 +0000 Subject: [PATCH 1/4] test ironwood --- docker_build_dependency_image.sh | 19 ++++----- maxtext_grpo_dependencies.Dockerfile | 40 ++++++++++++++----- .../examples/grpo_llama3_1_70b_demo_pw.py | 7 ++-- 3 files changed, 43 insertions(+), 23 deletions(-) diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index bf0ccaee1c..5c16beac3e 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -125,8 +125,8 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then echo "Building with benchmark-db" docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_db_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . elif [[ ${INSTALL_GRPO} -eq 1 && ${DEVICE} == "tpu" ]]; then - echo "Installing MaxText stable mode dependencies for GRPO" - docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . + echo "Installing MaxText stable mode dependencies for GRPO BASEIMAGE=$BASEIMAGE" + docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . else docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . fi @@ -142,17 +142,18 @@ if [[ ${INSTALL_GRPO} -eq 1 ]] ; then exit 1 fi - # # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__. 
- # # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext). - # rsync -a --exclude='__pycache__' ../tpu_commons . - # # To install vllm from a local path, we copy it into the build context, excluding __pycache__. - # # This assumes vllm is a sibling directory to the current one (maxtext). - # rsync -a --exclude='__pycache__' ../vllm . + # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__. + # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext). + rsync -a --exclude='__pycache__' ../tpu_commons . + # To install vllm from a local path, we copy it into the build context, excluding __pycache__. + # This assumes vllm is a sibling directory to the current one (maxtext). + rsync -a --exclude='__pycache__' ../vllm . # rsync -a --exclude='__pycache__' ../tunix . - # # The cleanup is set to run even if the build fails to remove the copied directory. + # The cleanup is set to run even if the build fails to remove the copied directory. 
# trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM + trap "rm -rf ./tpu_commons ./vllm " EXIT INT TERM docker build \ --network host \ diff --git a/maxtext_grpo_dependencies.Dockerfile b/maxtext_grpo_dependencies.Dockerfile index 065746f698..a78e771318 100644 --- a/maxtext_grpo_dependencies.Dockerfile +++ b/maxtext_grpo_dependencies.Dockerfile @@ -22,7 +22,7 @@ RUN echo "Installing GRPO dependencies (vLLM, tpu-common, tunix) with MODE=${MOD # Uninstall existing jax to avoid conflicts -RUN pip uninstall -y jax jaxlib libtpu +# RUN pip uninstall -y jax jaxlib libtpu RUN pip install aiohttp==3.12.15 @@ -31,9 +31,8 @@ RUN pip install keyring keyrings.google-artifactregistry-auth RUN pip install numba==0.61.2 -# Install vLLM for Jax and TPUs from the artifact registry -RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \ - --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \ +COPY vllm /vllm +RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \ --extra-index-url https://pypi.org/simple/ \ --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ @@ -41,16 +40,35 @@ RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \ --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ --find-links https://storage.googleapis.com/libtpu-releases/index.html \ --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ - --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \ - vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu + --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -# Install tpu-commons from the artifact registry -RUN pip install --no-cache-dir --pre \ - --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \ +# Install tpu-commons from local source +COPY tpu_commons /tpu_commons +RUN pip install 
-e /tpu_commons --no-cache-dir --pre \ --extra-index-url https://pypi.org/simple/ \ --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ - --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ - tpu-commons==0.1.2 + --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html + +# # Install vLLM for Jax and TPUs from the artifact registry +# RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \ +# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \ +# --extra-index-url https://pypi.org/simple/ \ +# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ +# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ +# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ +# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ +# --find-links https://storage.googleapis.com/libtpu-releases/index.html \ +# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ +# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \ +# vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu + +# # Install tpu-commons from the artifact registry +# RUN pip install --no-cache-dir --pre \ +# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \ +# --extra-index-url https://pypi.org/simple/ \ +# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ +# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ +# tpu-commons==0.1.2 RUN if [ "$MODE" = "grpo-experimental" ]; then \ pip uninstall -y jax jaxlib libtpu && \ diff --git a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py index def23ec943..7cd5784abc 100644 --- a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py +++ b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py @@ -142,7 
+142,7 @@ # ====== Input Checkpoint directory ===== -MODEL_CHECKPOINT_PATH = "/path/to/scanned/model/ckpt_load_dir/" +MODEL_CHECKPOINT_PATH = "gs://maxtext-model-checkpoints/llama3.1_70b_instruct/2025-10-15/pathways/scanned/0/items" # ====== Checkpoint directory ===== LOG_DIR = f"{HOME}/content/tensorboard/grpo/logs_llama3/" @@ -150,12 +150,12 @@ os.makedirs(LOG_DIR) # ===== Profiling ===== -PROFILE_DIR = f"/path/to/profile_dir/{run_id}/profiles_llama3/" +PROFILE_DIR = f"gs://mazumdera-test-bucket-europe-west4/rl-tuning/grpo/anisha-{run_id}/profiles_llama3/" if not epath.Path(PROFILE_DIR).exists(): epath.Path(PROFILE_DIR).mkdir(parents=True) # ====== Checkpoint saving ====== -CKPT_DIR = f"/path/to/ckpt_save_dir/{run_id}/ckpts_llama3/" +CKPT_DIR = f"gs://mazumdera-test-bucket-europe-west4/rl-tuning/grpo/anisha-{run_id}/ckpts_llama3/" if not epath.Path(CKPT_DIR).exists(): epath.Path(CKPT_DIR).mkdir(parents=True) @@ -199,6 +199,7 @@ NUM_BATCHES = 4 # 200 # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be # increased to a max. of 330 (if batch size is 4). +NUM_TEST_BATCHES = 330 NUM_TEST_BATCHES = 5 # 200 EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`. 
From 6c579fd09a681fb3db740f0d51fc1009e9602208 Mon Sep 17 00:00:00 2001 From: A9isha Date: Tue, 21 Oct 2025 08:07:49 +0000 Subject: [PATCH 2/4] split vllm dependencies and code --- docker_build_dependency_image.sh | 16 +++--- maxtext_grpo_dependencies.Dockerfile | 79 +++++++++++++--------------- 2 files changed, 44 insertions(+), 51 deletions(-) diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index 5c16beac3e..19f492d270 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -126,7 +126,7 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_db_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . elif [[ ${INSTALL_GRPO} -eq 1 && ${DEVICE} == "tpu" ]]; then echo "Installing MaxText stable mode dependencies for GRPO BASEIMAGE=$BASEIMAGE" - docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . + docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . else docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . fi @@ -142,18 +142,14 @@ if [[ ${INSTALL_GRPO} -eq 1 ]] ; then exit 1 fi - # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__. - # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext). 
+ # To install from local paths, we copy vllm and tpu_commons into the build context. + # This assumes vllm and tpu_commons are sibling directories to the current one (maxtext). + echo "Copying local vllm and tpu_commons directories into the build context..." rsync -a --exclude='__pycache__' ../tpu_commons . - # To install vllm from a local path, we copy it into the build context, excluding __pycache__. - # This assumes vllm is a sibling directory to the current one (maxtext). rsync -a --exclude='__pycache__' ../vllm . - # rsync -a --exclude='__pycache__' ../tunix . - - # The cleanup is set to run even if the build fails to remove the copied directory. - # trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM - trap "rm -rf ./tpu_commons ./vllm " EXIT INT TERM + # The cleanup is set to run even if the build fails to remove the copied directories. + trap "echo 'Cleaning up copied directories...' && rm -rf ./tpu_commons ./vllm" EXIT INT TERM docker build \ --network host \ diff --git a/maxtext_grpo_dependencies.Dockerfile b/maxtext_grpo_dependencies.Dockerfile index a78e771318..98c6d2dff5 100644 --- a/maxtext_grpo_dependencies.Dockerfile +++ b/maxtext_grpo_dependencies.Dockerfile @@ -15,62 +15,59 @@ ARG BASEIMAGE FROM ${BASEIMAGE} ARG MODE - ENV MODE=$MODE RUN echo "Installing GRPO dependencies (vLLM, tpu-common, tunix) with MODE=${MODE}" - # Uninstall existing jax to avoid conflicts # RUN pip uninstall -y jax jaxlib libtpu -RUN pip install aiohttp==3.12.15 +# --- STAGE 1: Install Static Dependencies --- +# Install any packages *not* defined in your project dependency files +RUN --mount=type=cache,target=/root/.cache/pip pip install \ + aiohttp==3.12.15\ + keyring \ + keyrings.google-artifactregistry-auth \ + numba==0.61.2 -# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically. 
-RUN pip install keyring keyrings.google-artifactregistry-auth +# --- STAGE 2: Install Project Dependencies (The Main Cached Layer) --- -RUN pip install numba==0.61.2 +# Copy *only* the dependency definition files. +# This assumes vllm and tpu_commons are in the build context, copied from the parent directory. +COPY vllm/requirements/tpu.txt /tmp/ +COPY vllm/requirements/build.txt /tmp/ +COPY vllm/requirements/common.txt /tmp/ +COPY tpu_commons/requirements.txt /tmp/ -COPY vllm /vllm -RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \ - --extra-index-url https://pypi.org/simple/ \ - --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ - --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ - --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ - --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ - --find-links https://storage.googleapis.com/libtpu-releases/index.html \ - --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ - --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html +# Run the full dependency installation. +# This entire layer is cached and will *only* be rebuilt if +# these .txt files change. 
+RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \ + # Set the target device so pip installs the right JAX/libtpu + export VLLM_TARGET_DEVICE="tpu" && \ + pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \ + --extra-index-url https://pypi.org/simple/ \ + --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ + --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ + --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ + --find-links https://storage.googleapis.com/libtpu-releases/index.html \ + --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' -# Install tpu-commons from local source -COPY tpu_commons /tpu_commons -RUN pip install -e /tpu_commons --no-cache-dir --pre \ - --extra-index-url https://pypi.org/simple/ \ - --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ - --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html +# --- STAGE 3: Install Project Source Code --- -# # Install vLLM for Jax and TPUs from the artifact registry -# RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \ -# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \ -# --extra-index-url https://pypi.org/simple/ \ -# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ -# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ -# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ -# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ -# --find-links https://storage.googleapis.com/libtpu-releases/index.html \ -# --find-links 
https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ -# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \ -# vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu +# Now, copy the full source code. This invalidates cache frequently, +# but the next step is fast. +COPY vllm /vllm/ +COPY tpu_commons /tpu_commons/ -# # Install tpu-commons from the artifact registry -# RUN pip install --no-cache-dir --pre \ -# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \ -# --extra-index-url https://pypi.org/simple/ \ -# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ -# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ -# tpu-commons==0.1.2 +# Install in editable mode. This is lightning-fast because all +# dependencies were installed and cached in STAGE 2. +RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/ -e /tpu_commons/ RUN if [ "$MODE" = "grpo-experimental" ]; then \ + echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \ pip uninstall -y jax jaxlib libtpu && \ pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \ pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \ From 3f819cc94497cd0be7d57cca6079640741872819 Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 22 Oct 2025 23:03:05 +0000 Subject: [PATCH 3/4] grpo runs of new hardware --- docker_build_dependency_image.sh | 11 +++--- maxtext_grpo_dependencies.Dockerfile | 36 +++++++++++++++---- .../examples/grpo_llama3_1_70b_demo_pw.py | 6 ++-- .../examples/grpo_llama3_1_8b_demo_pw.py | 13 +++---- 4 files changed, 45 insertions(+), 21 deletions(-) diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index 19f492d270..a732ec4542 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -142,14 
+142,15 @@ if [[ ${INSTALL_GRPO} -eq 1 ]] ; then exit 1 fi - # To install from local paths, we copy vllm and tpu_commons into the build context. - # This assumes vllm and tpu_commons are sibling directories to the current one (maxtext). - echo "Copying local vllm and tpu_commons directories into the build context..." - rsync -a --exclude='__pycache__' ../tpu_commons . + # To install from local paths, we copy vllm and tpu-inference into the build context. + # This assumes vllm and tpu-inference are sibling directories to the current one (maxtext). + echo "Copying local vllm and tpu-inference directories into the build context..." + rsync -a --exclude='__pycache__' ../tunix . + rsync -a --exclude='__pycache__' ../tpu-inference . rsync -a --exclude='__pycache__' ../vllm . # The cleanup is set to run even if the build fails to remove the copied directories. - trap "echo 'Cleaning up copied directories...' && rm -rf ./tpu_commons ./vllm" EXIT INT TERM + trap "echo 'Cleaning up copied directories...' 
&& rm -rf ./tpu-inference ./vllm" EXIT INT TERM docker build \ --network host \ diff --git a/maxtext_grpo_dependencies.Dockerfile b/maxtext_grpo_dependencies.Dockerfile index 98c6d2dff5..dd8ecd4527 100644 --- a/maxtext_grpo_dependencies.Dockerfile +++ b/maxtext_grpo_dependencies.Dockerfile @@ -17,7 +17,7 @@ FROM ${BASEIMAGE} ARG MODE ENV MODE=$MODE -RUN echo "Installing GRPO dependencies (vLLM, tpu-common, tunix) with MODE=${MODE}" +RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}" # Uninstall existing jax to avoid conflicts # RUN pip uninstall -y jax jaxlib libtpu @@ -27,23 +27,27 @@ RUN echo "Installing GRPO dependencies (vLLM, tpu-common, tunix) with MODE=${MOD RUN --mount=type=cache,target=/root/.cache/pip pip install \ aiohttp==3.12.15\ keyring \ - keyrings.google-artifactregistry-auth \ + keyrings.google-artifactregistry-auth + +RUN --mount=type=cache,target=/root/.cache/pip pip install \ numba==0.61.2 +# RUN VLLM_TARGET_DEVICE="tpu" pip install vllm # --- STAGE 2: Install Project Dependencies (The Main Cached Layer) --- # Copy *only* the dependency definition files. -# This assumes vllm and tpu_commons are in the build context, copied from the parent directory. +# This assumes vllm and tpu-inference are in the build context, copied from the parent directory. COPY vllm/requirements/tpu.txt /tmp/ COPY vllm/requirements/build.txt /tmp/ COPY vllm/requirements/common.txt /tmp/ -COPY tpu_commons/requirements.txt /tmp/ +COPY tpu-inference/requirements.txt /tmp/ # Run the full dependency installation. # This entire layer is cached and will *only* be rebuilt if # these .txt files change. 
-RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \ +RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \ # Set the target device so pip installs the right JAX/libtpu + # Install tpu-inference dependencies export VLLM_TARGET_DEVICE="tpu" && \ pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \ --extra-index-url https://pypi.org/simple/ \ @@ -55,16 +59,34 @@ RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \ --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' + # Install tpu-inference dependencies +RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \ + pip install -r /tmp/requirements.txt --no-cache-dir --pre \ + --extra-index-url https://pypi.org/simple/ \ + --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ + --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ + --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ + --find-links https://storage.googleapis.com/libtpu-releases/index.html \ + --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' + # --- STAGE 3: Install Project Source Code --- # Now, copy the full source code. This invalidates cache frequently, # but the next step is fast. COPY vllm /vllm/ -COPY tpu_commons /tpu_commons/ +COPY tpu-inference /tpu-inference/ +COPY tunix /tunix + # Install in editable mode. This is lightning-fast because all # dependencies were installed and cached in STAGE 2. 
-RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/ -e /tpu_commons/ +RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/ +RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/ + +RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/ +# RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/ RUN if [ "$MODE" = "grpo-experimental" ]; then \ echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \ diff --git a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py index 7cd5784abc..86aff17428 100644 --- a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py +++ b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py @@ -195,12 +195,12 @@ # ====== Training ====== BATCH_SIZE = 1 # Increase `NUM_BATCHES` and `MAX_STEPS` for better results. -# NUM_BATCHES = 3738 -NUM_BATCHES = 4 # 200 +NUM_BATCHES = 3738 +# NUM_BATCHES = 4 # 200 # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be # increased to a max. of 330 (if batch size is 4). NUM_TEST_BATCHES = 330 -NUM_TEST_BATCHES = 5 # 200 +# NUM_TEST_BATCHES = 5 # 200 EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`. 
NUM_EPOCHS = 1 # can potentially train for more epochs diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py index f994e73566..557e8ce9f7 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py @@ -142,7 +142,7 @@ # ====== Input Checkpoint directory ===== -MODEL_CHECKPOINT_PATH = "/path/to/scanned/model/ckpt_load_dir/" +MODEL_CHECKPOINT_PATH = "gs://mazumdera-test-bucket-europe-west4/llama3.1-8b-Instruct/scanned-pathways/0/items" # ====== Checkpoint directory ===== LOG_DIR = f"{HOME}/content/tensorboard/grpo/logs_llama3/" @@ -150,12 +150,12 @@ os.makedirs(LOG_DIR) # ===== Profiling ===== -PROFILE_DIR = f"/path/to/profile_dir/{run_id}/profiles_llama3/" +PROFILE_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/v5p-64/llama3-1-8b/profile_dir/{run_id}/profiles_llama3/" if not epath.Path(PROFILE_DIR).exists(): epath.Path(PROFILE_DIR).mkdir(parents=True) # ====== Checkpoint saving ====== -CKPT_DIR = f"/path/to/ckpt_save_dir/{run_id}/ckpts_llama3/" +CKPT_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/v5p-64/llama3-1-8b/ckpt_save_dir/{run_id}/ckpts_llama3/" if not epath.Path(CKPT_DIR).exists(): epath.Path(CKPT_DIR).mkdir(parents=True) @@ -195,11 +195,12 @@ # ====== Training ====== BATCH_SIZE = 1 # Increase `NUM_BATCHES` and `MAX_STEPS` for better results. -# NUM_BATCHES = 3738 -NUM_BATCHES = 4 # 200 +NUM_BATCHES = 3738 +# NUM_BATCHES = 4 # 200 # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be # increased to a max. of 330 (if batch size is 4). -NUM_TEST_BATCHES = 5 # 200 +NUM_TEST_BATCHES = 330 +# NUM_TEST_BATCHES = 5 # 200 EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`. 
NUM_EPOCHS = 1 # can potentially train for more epochs From 7de3241adec96e77ee6def86a8d63955308c4755 Mon Sep 17 00:00:00 2001 From: A9isha Date: Mon, 27 Oct 2025 21:50:50 +0000 Subject: [PATCH 4/4] optionally separate out the dpendencies of vllm --- base_requirements/requirements.txt | 2 +- maxtext_grpo_dependencies.Dockerfile | 179 +++++++++++------- maxtext_grpo_dependencies_split.Dockerfile | 96 ++++++++++ requirements.txt | 2 +- requirements_with_jax_ai_image.txt | 2 +- .../examples/grpo_llama3_1_70b_demo_pw.py | 10 +- .../extra_deps_from_github.txt | 2 +- 7 files changed, 222 insertions(+), 71 deletions(-) create mode 100644 maxtext_grpo_dependencies_split.Dockerfile diff --git a/base_requirements/requirements.txt b/base_requirements/requirements.txt index bc3e683451..1d3c99f270 100644 --- a/base_requirements/requirements.txt +++ b/base_requirements/requirements.txt @@ -39,5 +39,5 @@ tiktoken tokamax transformers qwix -google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip +google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip diff --git a/maxtext_grpo_dependencies.Dockerfile b/maxtext_grpo_dependencies.Dockerfile index dd8ecd4527..306a904054 100644 --- a/maxtext_grpo_dependencies.Dockerfile +++ b/maxtext_grpo_dependencies.Dockerfile @@ -18,75 +18,126 @@ ARG MODE ENV MODE=$MODE RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}" +RUN pip uninstall -y jax jaxlib libtpu + +RUN pip install aiohttp==3.12.15 + +# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically. 
+RUN pip install keyring keyrings.google-artifactregistry-auth + +RUN pip install numba==0.61.2 -# Uninstall existing jax to avoid conflicts -# RUN pip uninstall -y jax jaxlib libtpu - -# --- STAGE 1: Install Static Dependencies --- -# Install any packages *not* defined in your project dependency files -RUN --mount=type=cache,target=/root/.cache/pip pip install \ - aiohttp==3.12.15\ - keyring \ - keyrings.google-artifactregistry-auth - -RUN --mount=type=cache,target=/root/.cache/pip pip install \ - numba==0.61.2 - -# RUN VLLM_TARGET_DEVICE="tpu" pip install vllm -# --- STAGE 2: Install Project Dependencies (The Main Cached Layer) --- - -# Copy *only* the dependency definition files. -# This assumes vllm and tpu-inference are in the build context, copied from the parent directory. -COPY vllm/requirements/tpu.txt /tmp/ -COPY vllm/requirements/build.txt /tmp/ -COPY vllm/requirements/common.txt /tmp/ -COPY tpu-inference/requirements.txt /tmp/ - -# Run the full dependency installation. -# This entire layer is cached and will *only* be rebuilt if -# these .txt files change. 
-RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \ - # Set the target device so pip installs the right JAX/libtpu - # Install tpu-inference dependencies - export VLLM_TARGET_DEVICE="tpu" && \ - pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \ - --extra-index-url https://pypi.org/simple/ \ - --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ - --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ - --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ - --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ - --find-links https://storage.googleapis.com/libtpu-releases/index.html \ - --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ - --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' - - # Install tpu-inference dependencies -RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \ - pip install -r /tmp/requirements.txt --no-cache-dir --pre \ - --extra-index-url https://pypi.org/simple/ \ - --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ - --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ - --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ - --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ - --find-links https://storage.googleapis.com/libtpu-releases/index.html \ - --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ - --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' - -# --- STAGE 3: Install Project Source Code --- - -# Now, copy the full source code. This invalidates cache frequently, -# but the next step is fast. 
-COPY vllm /vllm/ -COPY tpu-inference /tpu-inference/ COPY tunix /tunix +RUN pip install -e /tunix --no-cache-dir + + +COPY vllm /vllm +RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \ + --extra-index-url https://pypi.org/simple/ \ + --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ + --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ + --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ + --find-links https://storage.googleapis.com/libtpu-releases/index.html \ + --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + + +COPY tpu-inference /tpu-inference +RUN pip install -e /tpu-inference --no-cache-dir --pre \ + --extra-index-url https://pypi.org/simple/ \ + --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ + --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html + +# # Install vLLM for Jax and TPUs from the artifact registry +# RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \ +# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \ +# --extra-index-url https://pypi.org/simple/ \ +# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ +# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ +# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ +# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ +# --find-links https://storage.googleapis.com/libtpu-releases/index.html \ +# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ +# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \ +# vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu + +# # 
Install tpu-commons from the artifact registry +# RUN pip install --no-cache-dir --pre \ +# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \ +# --extra-index-url https://pypi.org/simple/ \ +# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ +# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ +# tpu-commons==0.1.2 + +# # Uninstall existing jax to avoid conflicts +# # RUN pip uninstall -y jax jaxlib libtpu + +# # --- STAGE 1: Install Static Dependencies --- +# # Install any packages *not* defined in your project dependency files +# RUN --mount=type=cache,target=/root/.cache/pip pip install \ +# aiohttp==3.12.15\ +# keyring \ +# keyrings.google-artifactregistry-auth + +# RUN --mount=type=cache,target=/root/.cache/pip pip install \ +# numba==0.61.2 + +# # RUN VLLM_TARGET_DEVICE="tpu" pip install vllm +# # --- STAGE 2: Install Project Dependencies (The Main Cached Layer) --- + +# # Copy *only* the dependency definition files. +# # This assumes vllm and tpu-inference are in the build context, copied from the parent directory. +# COPY vllm/requirements/tpu.txt /tmp/ +# COPY vllm/requirements/build.txt /tmp/ +# COPY vllm/requirements/common.txt /tmp/ +# COPY tpu-inference/requirements.txt /tmp/ + +# # Run the full dependency installation. +# # This entire layer is cached and will *only* be rebuilt if +# # these .txt files change. 
+# RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \ +# # Set the target device so pip installs the right JAX/libtpu +# # Install tpu-inference dependencies +# export VLLM_TARGET_DEVICE="tpu" && \ +# pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \ +# --extra-index-url https://pypi.org/simple/ \ +# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ +# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ +# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ +# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ +# --find-links https://storage.googleapis.com/libtpu-releases/index.html \ +# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ +# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' + +# # Install tpu-inference dependencies +# RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \ +# pip install -r /tmp/requirements.txt --no-cache-dir --pre \ +# --extra-index-url https://pypi.org/simple/ \ +# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ +# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ +# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ +# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ +# --find-links https://storage.googleapis.com/libtpu-releases/index.html \ +# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ +# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' + +# # --- STAGE 3: Install Project Source Code --- + +# # Now, copy the full source code. This invalidates cache frequently, +# # but the next step is fast. +# COPY vllm /vllm/ +# COPY tpu-inference /tpu-inference/ +# COPY tunix /tunix -# Install in editable mode. 
This is lightning-fast because all -# dependencies were installed and cached in STAGE 2. -RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/ -RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/ +# # Install in editable mode. This is lightning-fast because all +# # dependencies were installed and cached in STAGE 2. +# RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/ +# RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/ -RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/ -# RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/ +# RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/ +# # RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/ RUN if [ "$MODE" = "grpo-experimental" ]; then \ echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \ diff --git a/maxtext_grpo_dependencies_split.Dockerfile b/maxtext_grpo_dependencies_split.Dockerfile new file mode 100644 index 0000000000..dd8ecd4527 --- /dev/null +++ b/maxtext_grpo_dependencies_split.Dockerfile @@ -0,0 +1,96 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +ARG BASEIMAGE +FROM ${BASEIMAGE} +ARG MODE +ENV MODE=$MODE + +RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}" + +# Uninstall existing jax to avoid conflicts +# RUN pip uninstall -y jax jaxlib libtpu + +# --- STAGE 1: Install Static Dependencies --- +# Install any packages *not* defined in your project dependency files +RUN --mount=type=cache,target=/root/.cache/pip pip install \ + aiohttp==3.12.15\ + keyring \ + keyrings.google-artifactregistry-auth + +RUN --mount=type=cache,target=/root/.cache/pip pip install \ + numba==0.61.2 + +# RUN VLLM_TARGET_DEVICE="tpu" pip install vllm +# --- STAGE 2: Install Project Dependencies (The Main Cached Layer) --- + +# Copy *only* the dependency definition files. +# This assumes vllm and tpu-inference are in the build context, copied from the parent directory. +COPY vllm/requirements/tpu.txt /tmp/ +COPY vllm/requirements/build.txt /tmp/ +COPY vllm/requirements/common.txt /tmp/ +COPY tpu-inference/requirements.txt /tmp/ + +# Run the full dependency installation. +# This entire layer is cached and will *only* be rebuilt if +# these .txt files change. 
+RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \ + # Set the target device so pip installs the right JAX/libtpu + # Install tpu-inference dependencies + export VLLM_TARGET_DEVICE="tpu" && \ + pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \ + --extra-index-url https://pypi.org/simple/ \ + --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ + --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ + --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ + --find-links https://storage.googleapis.com/libtpu-releases/index.html \ + --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' + + # Install tpu-inference dependencies +RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \ + pip install -r /tmp/requirements.txt --no-cache-dir --pre \ + --extra-index-url https://pypi.org/simple/ \ + --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ + --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ + --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ + --find-links https://storage.googleapis.com/libtpu-releases/index.html \ + --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' + +# --- STAGE 3: Install Project Source Code --- + +# Now, copy the full source code. This invalidates cache frequently, +# but the next step is fast. +COPY vllm /vllm/ +COPY tpu-inference /tpu-inference/ +COPY tunix /tunix + + +# Install in editable mode. 
This is lightning-fast because all +# dependencies were installed and cached in STAGE 2. +RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/ +RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/ + +RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/ +# RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/ + +RUN if [ "$MODE" = "grpo-experimental" ]; then \ + echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \ + pip uninstall -y jax jaxlib libtpu && \ + pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \ + pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \ + fi diff --git a/requirements.txt b/requirements.txt index a5d940a9c3..cf3bf2d0d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -38,5 +38,5 @@ tensorflow-text tensorflow tiktoken transformers -google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip +google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt index 406dc60c8e..4c2bd727db 100644 --- a/requirements_with_jax_ai_image.txt +++ b/requirements_with_jax_ai_image.txt @@ -3,7 +3,7 @@ datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72184f89e7814cf784360.zip flax>=0.11.0 google-api-python-client -google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip +google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip grain[parquet]>=0.2.12 jaxtyping 
 jsonlines
diff --git a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py
index 86aff17428..f50661f6ac 100644
--- a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py
+++ b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py
@@ -142,20 +142,24 @@
 
 
 # ====== Input Checkpoint directory =====
-MODEL_CHECKPOINT_PATH = "gs://maxtext-model-checkpoints/llama3.1_70b_instruct/2025-10-15/pathways/scanned/0/items"
+MODEL_CHECKPOINT_PATH = "gs://mazumdera-test-bucket-europe-west4/llama3.1_70b_instruct/2025-10-15/pathways/scanned/0/items"
+# MODEL_CHECKPOINT_PATH = "gs://maxtext-model-checkpoints/llama3.1_70b_instruct/2025-10-15/pathways/scanned/0/items"
 
 # ====== Checkpoint directory =====
-LOG_DIR = f"{HOME}/content/tensorboard/grpo/logs_llama3/"
+LOG_DIR = f"gs://mazumdera-test-bucket-europe-west4/rl-tuning/grpo/anisha-{run_id}/tensorboard/grpo/logs_llama3/"
+# LOG_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/rl-tuning/grpo/anisha-{run_id}/tensorboard/grpo/logs_llama3/"
-if not os.path.exists(LOG_DIR):
-  os.makedirs(LOG_DIR)
+if not epath.Path(LOG_DIR).exists():
+  epath.Path(LOG_DIR).mkdir(parents=True)
 
 # ===== Profiling =====
 PROFILE_DIR = f"gs://mazumdera-test-bucket-europe-west4/rl-tuning/grpo/anisha-{run_id}/profiles_llama3/"
+# PROFILE_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/rl-tuning/grpo/anisha-{run_id}/profiles_llama3/"
 if not epath.Path(PROFILE_DIR).exists():
   epath.Path(PROFILE_DIR).mkdir(parents=True)
 
 # ====== Checkpoint saving ======
 CKPT_DIR = f"gs://mazumdera-test-bucket-europe-west4/rl-tuning/grpo/anisha-{run_id}/ckpts_llama3/"
+# CKPT_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/rl-tuning/grpo/anisha-{run_id}/ckpts_llama3/"
 if not epath.Path(CKPT_DIR).exists():
   epath.Path(CKPT_DIR).mkdir(parents=True)
diff --git a/src/install_maxtext_extra_deps/extra_deps_from_github.txt b/src/install_maxtext_extra_deps/extra_deps_from_github.txt
index 676f2e58e7..9f7bf08afb 100644
--- a/src/install_maxtext_extra_deps/extra_deps_from_github.txt
+++ b/src/install_maxtext_extra_deps/extra_deps_from_github.txt @@ -1,2 +1,2 @@ -google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip +google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/daedc21c393f23449fb54ddc4f75fca34348ea9c.zip mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip