Stingram/network issue fix #592

Draft
wants to merge 9 commits into master
42 changes: 42 additions & 0 deletions dags/common/model_configs.py
@@ -0,0 +1,42 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common model perf configs"""

import enum


class MaxTextV5eModelConfigs(enum.Enum):
# Refers to model configs in https://github.com/AI-Hypercomputer/maxtext/blob/main/benchmarks/maxtext_v5e_model_configs.py
DEFAULT_16B = "default_16b_v5e_256"
DEFAULT_32B = "default_32b_v5e_256"
DEFAULT_64B = "default_64b_v5e_256"
DEFAULT_128B = "default_128b_v5e_256"
GPT3_175B = "gpt_3_175b_v5e_256"
LLAMA2_7B = "llama2_7b_v5e_256"
LLAMA2_13B = "llama2_13b_v5e_256"
LLAMA2_70B = "llama2_70b_v5e_256"


class MaxTextTrilliumModelConfigs(enum.Enum):
# Refers to model configs in https://github.com/AI-Hypercomputer/maxtext/blob/main/benchmarks/maxtext_trillium_model_configs.py
GPT3_175B = "gpt_3_175b"
LLAMA2_70B_4096 = "llama2_70b_4096_synthetic"
LLAMA3_1_8B_8192 = "llama3_1_8b_8192"
LLAMA3_1_70B_8192 = "llama3_1_70b_8192"
LLAMA3_1_70B_129024 = "llama3_1_70b_129024"
LLAMA3_1_405B_8192 = "llama3_1_405b_8192_fsdp_dcn"
MIXTRAL_8X7B_DROPLESS = "mixtral_8x7b_dropless"
MIXTRAL_8X7B_DROPPED = "mixtral_8x7b_dropped"
MIXTRAL_8X7B_DROPPED_INT8 = "mixtral_8x7b_dropped_int8"
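
For context, DAGs consume these enums through their string values, which must match the config names defined in the maxtext benchmark scripts linked in the comments above. A minimal usage sketch, not part of this PR (the build_run_name helper is hypothetical):

    from dags.common.model_configs import MaxTextTrilliumModelConfigs

    def build_run_name(config: MaxTextTrilliumModelConfigs) -> str:
        # The enum value is the config name the maxtext benchmark scripts
        # expect, e.g. "llama2_70b_4096_synthetic".
        return f"maxtext-{config.value}-nightly"

    for config in MaxTextTrilliumModelConfigs:
        print(build_run_name(config))
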
25 changes: 0 additions & 25 deletions dags/common/quarantined_tests.py
@@ -80,18 +80,6 @@ class QuarantineTests:
# DAG: maxtext_end_to_end
"chained_tests_gemma-7b_stable": TestInfo(team.LLM_DEVX, "2024-11-12"),
"chained_tests_gemma-7b_nightly": TestInfo(team.LLM_DEVX, "2024-11-12"),
"chained_tests_mixtral-8x7b_stable": TestInfo(
team.SPARSITY_DIFFUSION_DEVX, "2024-11-12"
),
"chained_tests_mixtral-8x7b_nightly": TestInfo(
team.SPARSITY_DIFFUSION_DEVX, "2024-11-12"
),
"maxtext_stable_mixtral-8x22b-v4-128": TestInfo(
team.SPARSITY_DIFFUSION_DEVX, "2024-11-12"
),
"maxtext_nightly_mixtral-8x22b-v4-128": TestInfo(
team.SPARSITY_DIFFUSION_DEVX, "2024-11-12"
),
"chained_tests_llama2-70b_stable": TestInfo(team.LLM_DEVX, "2024-11-12"),
"chained_tests_llama2-70b_nightly": TestInfo(team.LLM_DEVX, "2024-11-12"),
# DAG: jax_stable_stack_gpu_e2e
@@ -177,19 +165,6 @@ class QuarantineTests:
"mxla-gpt3-6b-nightly-gke-8xv5p-8": TestInfo(
team.PERFORMANCE, "2024-11-12"
),
# DAG: mxla_maxtext_nightly_gke
"mxla-maxtext-nightly-gke-v5p-8": TestInfo(
team.PERFORMANCE, "2024-11-12"
),
"mxla-maxtext-nightly-gke-2xv5p-8": TestInfo(
team.PERFORMANCE, "2024-11-12"
),
"mxla-maxtext-nightly-gke-4xv5p-8": TestInfo(
team.PERFORMANCE, "2024-11-12"
),
"mxla-maxtext-nightly-gke-8xv5p-8": TestInfo(
team.PERFORMANCE, "2024-11-12"
),
# DAG: maxtext_trillium_configs_perf
"maxtext-llama2_70b_4096-stable-3-2xv6e-256": TestInfo(
team.PERFORMANCE, "2024-11-12"
6 changes: 3 additions & 3 deletions dags/common/vm_resource.py
@@ -231,11 +231,11 @@ class XpkClusters:
zone=Zone.US_CENTRAL2_B.value,
)
TPU_V5P_8_CLUSTER = XpkClusterConfig(
name="v5p-8-bodaborg-us-east5-a",
name="v5p-8-bodaborg-europe-west4-b",
device_version=TpuVersion.V5P,
core_count=8,
project=Project.TPU_PROD_ENV_LARGE_CONT.value,
zone=Zone.US_EAST5_A.value,
project=Project.CLOUD_TPU_MULTIPOD_DEV.value,
zone=Zone.EUROPE_WEST4_B.value,
)
TPU_V5E_256_CLUSTER = XpkClusterConfig(
name="v5e-256-bodaborg-europe-west4",
2 changes: 1 addition & 1 deletion dags/legacy_test/tests/pytorch/nightly/common.libsonnet
@@ -95,8 +95,8 @@ local volumes = import 'templates/volumes.libsonnet';
sudo systemctl disable unattended-upgrades || true
sudo killall --signal SIGKILL unattended-upgrades || true
sudo dpkg --configure -a || true
sudo apt purge unattended-upgrades -y || true
sudo rm /var/lib/dpkg/lock-frontend || true
sudo apt purge unattended-upgrades -y || true
echo "unattended-upgrades stopped."

sudo apt-get -y update
100 changes: 100 additions & 0 deletions dags/legacy_test/tests/pytorch/nightly/llama2-model.libsonnet
@@ -130,14 +130,114 @@ local utils = import 'templates/utils.libsonnet';
||| % common.HuggingfacePipVersionConstraints,
},
},
local llama3_train = self.llama3_train,
llama3_train:: common.PyTorchTest + common.Functional + common.PyTorchTpuVmMixin {
modelName: 'llama3-train',
command: [
'python',
'transformers/examples/pytorch/language-modeling/run_clm.py',
'--dataset_name=wikitext',
'--dataset_config_name=wikitext-2-raw-v1',
'--per_device_train_batch_size=4',
'--do_train',
'--output_dir=./tmp/test-clm',
'--overwrite_output_dir',
'--config_name=./llama_3/config.json',
'--cache_dir=./cache',
'--tokenizer_name=./llama_3/tokenizer/',
'--block_size=8192',
'--optim=adafactor',
'--save_strategy=no',
'--logging_strategy=no',
'--fsdp=full_shard',
'--fsdp_config=./llama_3/fsdp_config.json',
'--torch_dtype=bfloat16',
'--dataloader_drop_last=yes',
'--flash_attention',
'--max_steps=10',
],
tpuSettings+: {
tpuVmExports+: |||
export PJRT_DEVICE=TPU
export XLA_USE_SPMD=1
|||,
tpuVmExtraSetup: |||
cat > ~/hf-constraints.txt << 'HF_CONSTRAINTS_EOF'
%s
HF_CONSTRAINTS_EOF

git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git

# install tokenizer model
curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-linux-x86_64.tar.gz
tar -xf google-cloud-cli-linux-x86_64.tar.gz
yes | ./google-cloud-sdk/install.sh
google-cloud-sdk/bin/gsutil cp -r gs://pytorch-airflow/llama_3/ .

cd transformers
sudo pip3 install -e . -c ~/hf-constraints.txt
pip3 install 'torch_xla[pallas]' -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip3 install datasets evaluate scikit-learn accelerate -c ~/hf-constraints.txt
||| % common.HuggingfacePipVersionConstraints,
},
},

local llama3_train_2_slice = self.llama3_train_2_slice,
llama3_train_2_slice:: llama3_train {
modelName: 'llama3-train-2-slice',
command: [
'python',
'transformers/examples/pytorch/language-modeling/run_clm.py',
'--dataset_name=wikitext',
'--dataset_config_name=wikitext-2-raw-v1',
'--per_device_train_batch_size=8',
'--do_train',
'--output_dir=./tmp/test-clm',
'--overwrite_output_dir',
'--config_name=./llama_3/config.json',
'--cache_dir=./cache',
'--tokenizer_name=./llama_3/tokenizer/',
'--block_size=8192',
'--optim=adafactor',
'--save_strategy=no',
'--logging_strategy=no',
'--fsdp=full_shard',
'--fsdp_config=./llama_3/fsdp_config.json',
'--torch_dtype=bfloat16',
'--dataloader_drop_last=yes',
'--flash_attention',
'--max_steps=10',
]
},

local v4_8 = self.v4_8,
v4_8:: {
accelerator: tpus.v4_8,
},

local v5p_8 = self.v5p_8,
v5p_8:: {
tpuSettings+: {
softwareVersion: 'v2-alpha-tpuv5',
},
accelerator: tpus.v5p_8,
},

local trillium_4 = self.trillium_4,
trillium_4:: {
tpuSettings+: {
softwareVersion: 'v2-alpha-tpuv6e',
},
accelerator: tpus.trillium_4,
},

configs: [
llama2 + infer + v4_8 + timeouts.Hours(3),
llama2 + spmd + v4_8 + timeouts.Hours(3),
llama2 + infer + v5p_8 + timeouts.Hours(3),
llama2 + spmd + v5p_8 + timeouts.Hours(3),
llama3_train + v5p_8 + timeouts.Hours(3),
llama3_train + trillium_4 + timeouts.Hours(3),
llama3_train_2_slice + v5p_8 + timeouts.Hours(3),
],
}
18 changes: 8 additions & 10 deletions dags/legacy_test/tests/pytorch/r2.6/common.libsonnet
@@ -18,15 +18,13 @@ local mixins = import 'templates/mixins.libsonnet';
local utils = import 'templates/utils.libsonnet';
local volumes = import 'templates/volumes.libsonnet';

local rcVersion = 'rc10';

{
local r2_6 = {
frameworkPrefix: 'pt-2-6',
tpuSettings+: {
softwareVersion: 'tpu-ubuntu2204-base',
},
imageTag: 'r2.6.0-%(rc)s_3.10' % {rc: rcVersion},
imageTag: 'r2.6.0_3.10',
},
PyTorchTest:: common.PyTorchTest + r2_6 {
local config = self,
@@ -97,8 +95,8 @@ local rcVersion = 'rc10';
sudo systemctl disable unattended-upgrades || true
sudo killall --signal SIGKILL unattended-upgrades || true
sudo dpkg --configure -a || true
sudo apt purge unattended-upgrades -y || true
sudo rm /var/lib/dpkg/lock-frontend || true
sudo apt purge unattended-upgrades -y || true
echo "unattended-upgrades stopped."

sudo apt-get -y update
@@ -109,13 +107,13 @@ local rcVersion = 'rc10';
pip install torch==2.6 --index-url https://download.pytorch.org/whl/test/cpu
# torchvision commit reference: https://github.com/pytorch/pytorch/blob/release/2.6/.github/ci_commit_pins/vision.txt
pip install --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@d23a6e1664d20707c11781299611436e1f0c104f"
pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%(rc)s-cp310-cp310-manylinux_2_28_x86_64.whl
pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0-cp310-cp310-manylinux_2_28_x86_64.whl
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
pip install pillow
git clone --depth=1 https://github.com/pytorch/pytorch.git
cd pytorch
git clone -b v2.6.0-%(rc)s https://github.com/pytorch/xla.git
||| % {rc: rcVersion},
git clone -b v2.6.0 https://github.com/pytorch/xla.git
|||,
},
podTemplate+:: {
spec+: {
@@ -152,16 +150,16 @@ local rcVersion = 'rc10';
pip uninstall -y torch torchvision
pip install torch==2.6 --index-url https://download.pytorch.org/whl/test/cpu
pip install --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@d23a6e1664d20707c11781299611436e1f0c104f"
pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%(rc)s-cp310-cp310-manylinux_2_28_x86_64.whl
pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0-cp310-cp310-manylinux_2_28_x86_64.whl

mkdir -p pytorch/xla
git clone -b v2.6.0-%(rc)s https://github.com/pytorch/xla.git pytorch/xla
git clone -b v2.6.0 https://github.com/pytorch/xla.git pytorch/xla

%(cmd)s

# Run whatever is in `command` here
"${@:0}"
||| % {cmd: config.tpuSettings.tpuVmExports, rc: rcVersion},
||| % {cmd: config.tpuSettings.tpuVmExports},
],
command: [
'torchrun',
31 changes: 30 additions & 1 deletion dags/legacy_test/tests/pytorch/r2.6/llama2-model.libsonnet
@@ -138,7 +138,7 @@ local utils = import 'templates/utils.libsonnet';
'transformers/examples/pytorch/language-modeling/run_clm.py',
'--dataset_name=wikitext',
'--dataset_config_name=wikitext-2-raw-v1',
'--per_device_train_batch_size=2',
'--per_device_train_batch_size=4',
'--do_train',
'--output_dir=./tmp/test-clm',
'--overwrite_output_dir',
@@ -182,6 +182,34 @@ local utils = import 'templates/utils.libsonnet';
},
},

local llama3_train_2_slice = self.llama3_train_2_slice,
llama3_train_2_slice:: llama3_train {
modelName: 'llama3-train-2-slice',
command: [
'python',
'transformers/examples/pytorch/language-modeling/run_clm.py',
'--dataset_name=wikitext',
'--dataset_config_name=wikitext-2-raw-v1',
'--per_device_train_batch_size=8',
'--do_train',
'--output_dir=./tmp/test-clm',
'--overwrite_output_dir',
'--config_name=./llama_3/config.json',
'--cache_dir=./cache',
'--tokenizer_name=./llama_3/tokenizer/',
'--block_size=8192',
'--optim=adafactor',
'--save_strategy=no',
'--logging_strategy=no',
'--fsdp=full_shard',
'--fsdp_config=./llama_3/fsdp_config.json',
'--torch_dtype=bfloat16',
'--dataloader_drop_last=yes',
'--flash_attention',
'--max_steps=10',
]
},

local v4_8 = self.v4_8,
v4_8:: {
accelerator: tpus.v4_8,
Expand Down Expand Up @@ -210,5 +238,6 @@ local utils = import 'templates/utils.libsonnet';
llama2 + spmd + v5p_8 + timeouts.Hours(3),
llama3_train + v5p_8 + timeouts.Hours(3),
llama3_train + trillium_4 + timeouts.Hours(3),
llama3_train_2_slice + v5p_8 + timeouts.Hours(3),
],
}
6 changes: 4 additions & 2 deletions dags/map_reproducibility/a3ultra_mixtral_8_7b_nemo.py
@@ -55,7 +55,7 @@
VALUE_YAML_PATH = (
f"training/{HYPERCOMPUTER}/{MODEL_ID}/nemo-pretraining-gke/values.yaml"
)
CLUSTER = "gke-a3ultra-map"
CLUSTER = "gke-a3u-map-01-31"
CLUSTER_REGION = "europe-west1"
SOFTWARE_ID = "pytorch_nemo"
IMAGE_VERSION = "nemo_workload:24.07"
@@ -130,7 +130,9 @@ def run_aotc_workload():
accelerator_type,
tmpdir,
)
+ cleanup_cmds()
# DEBUG: to clean up manually, list the stale releases with: helm list | grep regression | awk '{print $1}'
# + cleanup_cmds()

),
],
cwd=tmpdir,
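
With cleanup_cmds() commented out above, stale helm releases will accumulate; the DEBUG comment's manual path could be scripted roughly as below. This is a sketch, not part of the PR, and it assumes stale releases are identifiable by the substring "regression":

    import subprocess

    def uninstall_stale_releases(match: str = "regression") -> None:
        # Equivalent to: helm list | grep regression | awk '{print $1}' | xargs -n1 helm uninstall
        releases = subprocess.run(
            ["helm", "list", "--short"], capture_output=True, text=True, check=True
        ).stdout.splitlines()
        for release in releases:
            if match in release:
                subprocess.run(["helm", "uninstall", release], check=True)
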
6 changes: 4 additions & 2 deletions dags/map_reproducibility/utils/common_utils.py
@@ -120,10 +120,12 @@ def helm_apply_cmds(
gcs_cmd = ""
if hypercomputer == "a3ultra":
gcs_cmd = f" --set volumes.gcsMounts[0].bucketName={BUCKET_NAME}"
network_prefix = "gke-a3u-map-01-31"
gcs_cmd += f" --set clusterName={network_prefix}"
else:
gcs_cmd = f" --set workload.gcsBucketForDataCataPath={BUCKET_NAME}"
set_aotc = ""
if aotc is True:
if aotc:
set_aotc = " --set-string workload.aotc=true "
helm_cmds = (
" helm install -f values.yaml "
@@ -178,10 +180,10 @@ def get_nemo_metrics_cmds(

def cleanup_cmds():
cleanup = (
"helm uninstall $JOB_NAME",
"kubectl get pods "
"--no-headers=true | awk '{print $1}' "
"| grep $JOB_NAME | xargs kubectl delete pods",
"helm uninstall $JOB_NAME",
)
return cleanup

2 changes: 1 addition & 1 deletion dags/multipod/configs/gke_config.py
@@ -116,7 +116,7 @@ def get_gke_maxtext_nightly_config(
f" python3 MaxText/train.py MaxText/configs/base.yml run_name={run_name}"
f" base_output_directory={base_output_directory}"
" dataset_path=gs://max-datasets-rogue dataset_type=synthetic"
" per_device_batch_size=12 reuse_example_batch=1 global_parameter_scale=1 metrics_file='metrics.txt'"
" model_name=llama3-8b per_device_batch_size=12 reuse_example_batch=1 metrics_file='metrics.txt'"
" steps=50 enable_checkpointing=false profiler=xplane upload_all_profiler_results=true skip_first_n_steps_for_profiler=10 profiler_steps=10 gcs_metrics=true"
),
)