Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions dags/common/vm_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class Zone(enum.Enum):
# reserved a3+ cluster in supercomputer-testing
AUSTRALIA_SOUTHEAST1_C = "australia-southeast1-c"
# reserved H200 capacity in cloud-tpu-inference-test
# & reserved a3u cluster in supercomputer-testing
EUROPE_WEST1_B = "europe-west1-b"
# reserved TRILLIUM capacity
EUROPE_WEST4_A = "europe-west4-a"
Expand Down Expand Up @@ -179,6 +180,7 @@ class GpuVersion(enum.Enum):
H200 = "nvidia-h200-80gb"
XPK_H100 = "h100-80gb-8"
XPK_H100_MEGA = "h100-mega-80gb-8"
XPK_H200 = "h200-141gb-8"
V100 = "nvidia-tesla-v100"


Expand Down Expand Up @@ -280,6 +282,13 @@ class XpkClusters:
project=Project.SUPERCOMPUTER_TESTING.value,
zone=Zone.AUSTRALIA_SOUTHEAST1_C.value,
)
GPU_A3ULTRA_CLUSTER = XpkClusterConfig(
name="gke-a3u-map-01-31",
device_version=GpuVersion.XPK_H200,
core_count=8,
project=Project.SUPERCOMPUTER_TESTING.value,
zone=Zone.EUROPE_WEST1_B.value,
)
CPU_M1_MEGAMEM_96_CLUSTER = XpkClusterConfig(
name="m1-megamem-96-shared",
device_version=CpuVersion.M1_MEGAMEM,
Expand All @@ -295,6 +304,11 @@ class XpkClusters:
zone=Zone.US_CENTRAL1_B.value,
)

class XpkVersions(enum.Enum):
  """Supported XPK versions.

  Each value is a release tag of the AI-Hypercomputer/xpk repository; it is
  passed to `git clone --branch <value>` when setting up the xpk tool, so the
  value must match an existing tag/branch name exactly.
  """

  # Double-quoted for consistency with the other enums in this module
  # (GpuVersion, Zone, etc.), which all use double quotes.
  V0_4_1 = "v0.4.1"
  V0_6_0 = "v0.6.0"

class DockerImage(enum.Enum):
"""Common docker images."""
Expand Down
4 changes: 3 additions & 1 deletion dags/multipod/configs/gke_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from xlml.apis import gcp_config, metric_config, task, test_config
from xlml.apis.xpk_cluster_config import XpkClusterConfig
from dags import gcs_bucket
from dags.common.vm_resource import TpuVersion, Project, XpkClusters, GpuVersion, CpuVersion
from dags.common.vm_resource import TpuVersion, Project, XpkClusters, GpuVersion, CpuVersion, XpkVersions
from typing import Iterable
import datetime

Expand Down Expand Up @@ -150,6 +150,7 @@ def get_maxtext_end_to_end_gpu_gke_test_config(
test_owner: str,
docker_image: str,
num_slices: int = 1,
xpk_version: str = XpkVersions.V0_4_1.value,
) -> task.GpuCreateResourceTask:
job_gcp_config = gcp_config.GCPConfig(
project_name=cluster.project,
Expand Down Expand Up @@ -178,6 +179,7 @@ def get_maxtext_end_to_end_gpu_gke_test_config(
return task.XpkTask(
task_test_config=job_test_config,
task_gcp_config=job_gcp_config,
xpk_version=xpk_version,
)


Expand Down
24 changes: 22 additions & 2 deletions dags/sparsity_diffusion_devx/maxtext_moe_gpu_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from airflow.utils.task_group import TaskGroup
from dags import composer_env
from dags.common import test_owner
from dags.common.vm_resource import XpkClusters, DockerImage
from dags.common.vm_resource import XpkClusters, DockerImage, XpkVersions
from dags.multipod.configs import gke_config


Expand Down Expand Up @@ -69,7 +69,27 @@ def run_maxtext_tests():
docker_image=DockerImage.MAXTEXT_GPU_JAX_STABLE_STACK.value,
test_owner=test_owner.MICHELLE_Y,
).run()
pinned_a3plus_gpu >> stable_a3plus_gpu
pinned_a3ultra_gpu = gke_config.get_maxtext_end_to_end_gpu_gke_test_config(
time_out_in_min=90,
test_name=f"{test_name_prefix}-pinned-{model}",
run_model_cmds=(test_script,),
num_slices=nnodes,
cluster=XpkClusters.GPU_A3ULTRA_CLUSTER,
docker_image=DockerImage.MAXTEXT_GPU_JAX_PINNED.value,
test_owner=test_owner.MICHELLE_Y,
xpk_version=XpkVersions.V0_6_0.value,
).run()
stable_a3ultra_gpu = gke_config.get_maxtext_end_to_end_gpu_gke_test_config(
time_out_in_min=90,
test_name=f"{test_name_prefix}-stable-{model}",
run_model_cmds=(test_script,),
num_slices=nnodes,
cluster=XpkClusters.GPU_A3ULTRA_CLUSTER,
docker_image=DockerImage.MAXTEXT_GPU_JAX_STABLE_STACK.value,
test_owner=test_owner.MICHELLE_Y,
xpk_version=XpkVersions.V0_6_0.value,
).run()
pinned_a3plus_gpu >> stable_a3plus_gpu >> pinned_a3ultra_gpu >> stable_a3ultra_gpu


with models.DAG(
Expand Down
4 changes: 4 additions & 0 deletions xlml/apis/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from airflow.utils.task_group import TaskGroup
from xlml.apis import gcp_config, metric_config, test_config
from xlml.utils import gpu, metric, name_format, ssh, tpu, xpk, gke
from dags.common.vm_resource import XpkVersions


class BaseTask(abc.ABC):
Expand Down Expand Up @@ -164,6 +165,7 @@ class XpkTask(BaseTask):
workload_provision_timeout: datetime.timedelta = datetime.timedelta(
minutes=300
)
xpk_version: str = XpkVersions.V0_4_1.value

def run(
self,
Expand Down Expand Up @@ -283,6 +285,7 @@ def run_model(
project_id=self.task_gcp_config.project_name,
zone=self.task_gcp_config.zone,
cluster_name=self.task_test_config.cluster_name,
xpk_version=self.xpk_version,
)

(
Expand Down Expand Up @@ -318,6 +321,7 @@ def launch_workload(
num_slices=self.task_test_config.num_slices,
use_vertex_tensorboard=use_vertex_tensorboard,
use_pathways=use_pathways,
xpk_version=self.xpk_version,
)
wait_for_workload_start = xpk.wait_for_workload_start.override(
timeout=self.workload_provision_timeout.total_seconds()
Expand Down
17 changes: 11 additions & 6 deletions xlml/utils/xpk.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from kubernetes import client as k8s_client
from xlml.apis import metric_config
from xlml.utils import gke
from dags.common.vm_resource import GpuVersion
from dags.common.vm_resource import GpuVersion, XpkVersions

# Duration = past 7 days
LOGGING_URL_FORMAT = (
Expand All @@ -41,10 +41,10 @@
)


def get_xpk_setup_cmd(tmpdir):
def get_xpk_setup_cmd(tmpdir, version=XpkVersions.V0_4_1.value):
return [
"set -xue",
f"git clone --branch v0.4.1 https://github.com/AI-Hypercomputer/xpk {tmpdir}/xpk",
f"git clone --branch {version} https://github.com/AI-Hypercomputer/xpk {tmpdir}/xpk",
"pip install ruamel.yaml docker",
]

Expand Down Expand Up @@ -76,6 +76,7 @@ def run_workload(
num_slices: int = 1,
use_vertex_tensorboard: bool = False,
use_pathways: bool = False,
xpk_version: str = XpkVersions.V0_4_1.value,
):
"""Run workload through xpk tool."""

Expand All @@ -100,7 +101,7 @@ def run_workload(
f" --env {metric_config.SshEnvVars.GCS_OUTPUT.name}={gcs_path}"
" --restart-on-user-code-failure"
)
cmds = get_xpk_setup_cmd(tmpdir)
cmds = get_xpk_setup_cmd(tmpdir, xpk_version)
if accelerator_type == GpuVersion.XPK_H100_MEGA.value:
workload_create_cmd += " --scheduler=gke.io/topology-aware-auto"
if use_vertex_tensorboard:
Expand Down Expand Up @@ -260,7 +261,11 @@ def wait_for_workload_completion(

@task(trigger_rule="all_done")
def clean_up_workload(
workload_id: str, project_id: str, zone: str, cluster_name: str
workload_id: str,
project_id: str,
zone: str,
cluster_name: str,
xpk_version: str = XpkVersions.V0_4_1.value,
) -> bool:
"""Delete workload."""
with tempfile.TemporaryDirectory() as tmpdir:
Expand All @@ -270,7 +275,7 @@ def clean_up_workload(
f" --project={project_id} --zone={zone}"
)

cmds = get_xpk_setup_cmd(tmpdir)
cmds = get_xpk_setup_cmd(tmpdir, xpk_version)
cmds.append(workload_delete_cmd)
hook = SubprocessHook()
result = hook.run_command(
Expand Down
Loading