Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Do Not Merge] Simple Maxdiffusion SDXL inference integration #299

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions dags/inference/configs/maxdiffusion_inference_gce_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to construct configs for maxdiffusion inference DAG."""

import json
from typing import Dict
from xlml.apis import gcp_config, metric_config, task, test_config
from dags import test_owner
from dags.multipod.configs import common
from dags.vm_resource import TpuVersion, Project, RuntimeVersion

PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value
RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value
GCS_SUBFOLDER_PREFIX = test_owner.Team.INFERENCE.value


def get_maxdiffusion_inference_nightly_config(
tpu_version: TpuVersion,
tpu_cores: int,
tpu_zone: str,
time_out_in_min: int,
test_name: str,
test_mode: common.SetupMode,
project_name: str = PROJECT_NAME,
runtime_version: str = RUNTIME_IMAGE,
network: str = "default",
subnetwork: str = "default",
is_tpu_reserved: bool = True,
num_slices: int = 1,
model_configs: Dict = {},
) -> task.TpuQueuedResourceTask:
job_gcp_config = gcp_config.GCPConfig(
project_name=project_name,
zone=tpu_zone,
dataset_name=metric_config.DatasetOption.BENCHMARK_DATASET,
)

set_up_cmds = (
"pip install --upgrade pip",
# Download maxdiffusion
"git clone -b inference_utils https://github.com/google/maxdiffusion.git",
# Create a python virtual environment
"sudo apt-get -y update",
"sudo apt-get -y install python3.10-venv",
"python -m venv .env",
"source .env/bin/activate",
# Setup Maxdiffusion
"cd maxdiffusion",
"pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html",
"pip3 install -r requirements.txt",
"pip3 install ."
)

additional_metadata_dict = {
mailvijayasingh marked this conversation as resolved.
Show resolved Hide resolved
"per_device_batch_size": f"{model_configs['per_device_batch_size']}",
}

run_model_cmds = (
# Start virtual environment
"source .env/bin/activate",
### Benchmark
"cd maxdiffusion",
# Configure flags
""" python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run" """,
# Give server time to start
f"sleep {model_configs['sleep_time']}",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed? I had to have this for my jetengine server

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can delete it.

f"gsutil cp metrics.json {metric_config.SshEnvVars.GCS_OUTPUT.value}",
)

job_test_config = test_config.TpuVmTest(
test_config.Tpu(
version=tpu_version,
cores=tpu_cores,
runtime_version=runtime_version,
reserved=is_tpu_reserved,
network=network,
subnetwork=subnetwork,
),
test_name=test_name,
set_up_cmds=set_up_cmds,
run_model_cmds=run_model_cmds,
time_out_in_min=time_out_in_min,
mailvijayasingh marked this conversation as resolved.
Show resolved Hide resolved
task_owner=test_owner.VIJAYA_S,
num_slices=num_slices,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/maxdiffusion",
)

job_metric_config = metric_config.MetricConfig(
json_lines=metric_config.JSONLinesConfig("metrics.json"),
use_runtime_generated_gcs_folder=True,
)

return task.TpuQueuedResourceTask(
task_test_config=job_test_config,
task_gcp_config=job_gcp_config,
task_metric_config=job_metric_config,
)
100 changes: 100 additions & 0 deletions dags/inference/maxdiffusion_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A DAG to run MaxText inference benchmarks with nightly version."""
mailvijayasingh marked this conversation as resolved.
Show resolved Hide resolved

import datetime
from airflow import models
from dags import composer_env, test_owner
from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion
from dags.inference.configs import maxdiffusion_inference_gce_config
from dags.multipod.configs.common import SetupMode, Platform


# Run once a day at 4 am UTC (8 pm PST)
SCHEDULED_TIME = "0 4 * * *" if composer_env.is_prod_env() else None


with models.DAG(
dag_id="maxdiffusion_inference",
schedule=SCHEDULED_TIME,
tags=["inference_team", "maxdiffusion", "nightly", "benchmark"],
start_date=datetime.datetime(2024, 1, 19),
catchup=False,
) as dag:
test_name_prefix = "maxdiffusion-inference"
test_models = {
"SDXL-Base-1": {
"sleep_time": 120,
"tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.V5P, 8)],
"maxdiffusion_logs": "gs://inference-benchmarks/models/SDXL-Base-1/2024-05-14-14-01/",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove since this is unused.

"per_device_batch_sizes": [2],
# "request_rate": 5,
mailvijayasingh marked this conversation as resolved.
Show resolved Hide resolved
},
}

for model, sweep_model_configs in test_models.items():
# tasks_per_model = []
for per_device_batch_size in sweep_model_configs["per_device_batch_sizes"]:
for tpu_version, tpu_cores in sweep_model_configs["tpu_version_cores"]:
model_configs = {}
model_configs["model_name"] = model
model_configs["sleep_time"] = sweep_model_configs["sleep_time"]
model_configs["maxdiffusion_logs"] = sweep_model_configs["maxdiffusion_logs"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: unused

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not all parameters, but some are used in next revision

model_configs["per_device_batch_size"] = per_device_batch_size
# model_configs["request_rate"] = sweep_model_configs["request_rate"]
mailvijayasingh marked this conversation as resolved.
Show resolved Hide resolved

if tpu_version == TpuVersion.V5E:
# v5e benchmarks
project_name = Project.TPU_PROD_ENV_AUTOMATED.value
zone = Zone.US_EAST1_C.value
network = V5_NETWORKS
subnetwork = V5E_SUBNETWORKS
runtime_version = RuntimeVersion.V2_ALPHA_TPUV5_LITE.value
elif tpu_version == TpuVersion.V5P:
zone = Zone.US_EAST5_A.value
runtime_version = RuntimeVersion.V2_ALPHA_TPUV5.value
project_name = Project.TPU_PROD_ENV_AUTOMATED.value
network = V5_NETWORKS
subnetwork = V5P_SUBNETWORKS

maxdiffusion_stable_1slice = maxdiffusion_inference_gce_config.get_maxdiffusion_inference_nightly_config(
tpu_version=tpu_version,
tpu_cores=tpu_cores,
tpu_zone=zone,
runtime_version=runtime_version,
project_name=project_name,
time_out_in_min=60,
is_tpu_reserved=True,
test_name=f"{test_name_prefix}-stable-{model}-per_device_batch_size-{per_device_batch_size}",
test_mode=SetupMode.STABLE,
network=network,
subnetwork=subnetwork,
model_configs=model_configs,
).run()
maxdiffusion_nightly_1slice = maxdiffusion_inference_gce_config.get_maxdiffusion_inference_nightly_config(
tpu_version=tpu_version,
tpu_cores=tpu_cores,
tpu_zone=zone,
runtime_version=runtime_version,
project_name=project_name,
time_out_in_min=60,
is_tpu_reserved=True,
test_name=f"{test_name_prefix}-nightly-{model}-per_device_batch_size-{per_device_batch_size}",
test_mode=SetupMode.NIGHTLY,
network=network,
subnetwork=subnetwork,
model_configs=model_configs,
).run()
maxdiffusion_stable_1slice >> maxdiffusion_nightly_1slice
1 change: 1 addition & 0 deletions dags/test_owner.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,4 @@ class Team(enum.Enum):

# Inference
ANDY_Y = "Andy Y."
VIJAYA_S = "Vijaya S."
Loading