maxtext_configs_hybridsim.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A DAG to run AOT compilation and HybridSim tests for MaxText model configs on TPU v4, v5e.
"""
import datetime
from airflow import models
from airflow.utils.task_group import TaskGroup
from dags import composer_env
from dags.common.quarantined_tests import QuarantineTests
from dags.common import test_owner
from dags.common.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters, Project
from dags.multipod.configs import gke_config
from xlml.utils import name_format
from xlml.apis import metric_config

# Run once a day at 1 pm UTC (5 am PST / 6 am PDT)
SCHEDULED_TIME = "0 13 * * *" if composer_env.is_prod_env() else None
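
# Each test group wires three tasks together: create a shared GCS folder, run the
# AOT compile workload that dumps HLO into it, then run HybridSim on that dump to
# produce an estimated step time. Relies on the module-level loop variables
# (tpu, model_size, num_cores, n) set in the DAG body below.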
def hybridsim_compile_and_run(test_group_id):
  with TaskGroup(group_id=test_group_id, prefix_group_id=False) as group:
    gcs_subfolder = f"{test_owner.Team.MULTIPOD.value}/maxtext"
    shared_gcs_location = name_format.generate_gcs_folder_location.override(
        task_id=f"{test_group_id}_generate_gcs_folder_location"
    )(
        f"{gcs_subfolder}/maxtext_configs_hybridsim/v{tpu.value}",
        test_group_id,
    )
    # Run AOT workload: generate HLO, upload to GCS
    aot_cmd = (
        'export XLA_FLAGS="--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants"',
        f"bash MaxText/configs/v{v5e_alt if tpu.value == TpuVersion.V5E.value else tpu.value}/{model_size}.sh EXECUTABLE=train_compile.py M_COMPILE_TOPOLOGY=v{v5e_alt if tpu.value == TpuVersion.V5E.value else tpu.value}-{num_cores} M_COMPILE_TOPOLOGY_NUM_SLICES={n}",
        "gsutil -m cp -r /tmp/xla_dump/ ${GCS_OUTPUT}",
    )
    maxtext_aot = gke_config.get_gke_config(
        time_out_in_min=240,
        test_name=f"maxtext-{model_size}-{n}xv{tpu.value}-{num_cores}-aot",
        run_model_cmds=aot_cmd,
        docker_image=DockerImage.MAXTEXT_TPU_JAX_NIGHTLY.value,
        test_owner=test_owner.RAYMOND_Z,
    ).run(gcs_location=shared_gcs_location)

    # Run HybridSim workload: read HLO from GCS, generate estimated step time
    cluster = clusters[tpu]
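    # HybridSim needs the target chip configuration: v5e has no MegaCore, so it
    # runs with the default config, while v4 runs with megacore.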
    chip_config = "default" if tpu == TpuVersion.V5E else "megacore"
    hybridsim_cmd = (
        "gsutil cp gs://cloud-hybridsim-prod/run_hybridsim.sh .",
        f"bash run_hybridsim.sh GCS_XLA_DUMP_PATH=${{GCS_OUTPUT}}xla_dump GCS_OUTPUT_PATH=${{GCS_OUTPUT}}estimated_cost_ns.jsonl CHIP_CONFIG={chip_config}",
    )
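    # HybridSim writes its cost estimate (in nanoseconds) to estimated_cost_ns.jsonl
    # under the runtime-generated GCS folder; this metric config points the test
    # framework at that file so the estimate is collected as a job metric.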
    job_metric_config = metric_config.MetricConfig(
        json_lines=metric_config.JSONLinesConfig(
            file_location="estimated_cost_ns.jsonl",
        ),
        use_runtime_generated_gcs_folder=True,
    )
    maxtext_hybridsim = gke_config.get_gke_config(
        cluster=cluster,
        time_out_in_min=240,
        test_name=f"maxtext-{model_size}-{n}xv{tpu.value}-{num_cores}-hybridsim",
        run_model_cmds=hybridsim_cmd,
        docker_image=DockerImage.CLOUD_HYBRIDSIM_NIGHTLY.value,
        test_owner=test_owner.RAYMOND_Z,
        user_specified_job_metric_config=job_metric_config,
    ).run(gcs_location=shared_gcs_location)

    shared_gcs_location >> maxtext_aot >> maxtext_hybridsim

with models.DAG(
    dag_id="maxtext_configs_hybridsim",
    schedule=SCHEDULED_TIME,
    tags=["multipod_team", "maxtext", "nightly", "mlscale_onduty"],
    start_date=datetime.datetime(2024, 2, 19),
    catchup=False,
    concurrency=10,
) as dag:
  # Test setup values
  model_configs = {
      # accelerator: [(model_size, num_cores), ...],
      TpuVersion.V4: [("22b", 128), ("52b", 384)],
      TpuVersion.V5E: [("16b", 256), ("32b", 256), ("64b", 256), ("128b", 256)],
  }
  num_slices = [1, 2, 4, 8]
  clusters = {
      TpuVersion.V4: XpkClusters.TPU_V4_8_MAXTEXT_CLUSTER,
      TpuVersion.V5E: XpkClusters.TPU_V5E_256_CLUSTER,
  }
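  # MaxText config paths and compile topologies name the chip "v5e", which differs
  # from TpuVersion.V5E.value, so this string is substituted when building v5e commands.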
  v5e_alt = "5e"
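
  # Quarantined test groups are still built, but they are collected under a
  # separate "Quarantine" TaskGroup rather than at the top level of the DAG.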
  quarantine_task_group = TaskGroup(
      group_id="Quarantine", dag=dag, prefix_group_id=False
  )
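
  # Sweep every (model size, core count) pair for each TPU version across
  # 1, 2, 4, and 8 slices.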
  for tpu, models in model_configs.items():
    for model_size, num_cores in models:
      for n in num_slices:
        test_group_id = (
            f"{model_size}-{n}xv{tpu.value}-{num_cores}-aot-hybridsim"
        )
        if QuarantineTests.is_quarantined(test_group_id):
          with quarantine_task_group:
            hybridsim_compile_and_run(test_group_id)
        else:
          hybridsim_compile_and_run(test_group_id)