Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Do Not Merge] Simple Maxdiffusion SDXL inference integration #299

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions dags/inference/configs/maxdiffusion_inference_gce_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to construct configs for maxdiffusion inference DAG."""

import datetime
import json
from typing import Dict
from xlml.apis import gcp_config, metric_config, task, test_config
from dags import test_owner
from dags.multipod.configs import common
from dags.vm_resource import TpuVersion, Project, RuntimeVersion

PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value
RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value
GCS_SUBFOLDER_PREFIX = test_owner.Team.INFERENCE.value


def _modify_save_metrics(metrics_file, model_configs):
metrics = json.loads(metrics_file)
for k, v in model_configs:
metrics["dimensions"][k] = str(v)
with open(metrics_file, "w") as f:
f.write(json.dumps(metrics))


def get_maxdiffusion_inference_nightly_config(
    tpu_version: TpuVersion,
    tpu_cores: int,
    tpu_zone: str,
    time_out_in_min: int,
    test_name: str,
    test_mode: common.SetupMode,
    project_name: str = PROJECT_NAME,
    runtime_version: str = RUNTIME_IMAGE,
    network: str = "default",
    subnetwork: str = "default",
    is_tpu_reserved: bool = True,
    num_slices: int = 1,
    model_configs: Dict = None,
) -> task.TpuQueuedResourceTask:
  """Builds a nightly MaxDiffusion SDXL inference benchmark task on GCE TPUs.

  Args:
    tpu_version: TPU generation to run on (e.g. v5e, v5p).
    tpu_cores: Number of TPU cores to request.
    tpu_zone: GCE zone hosting the TPU.
    time_out_in_min: Task timeout in minutes.
    test_name: Name identifying this benchmark task.
    test_mode: Setup mode (stable/nightly). Currently unused by the commands
      but kept for interface compatibility with sibling configs.
    project_name: GCP project to run in.
    runtime_version: TPU VM runtime image version.
    network: GCE network name.
    subnetwork: GCE subnetwork name.
    is_tpu_reserved: Whether to use reserved TPU capacity.
    num_slices: Number of TPU slices.
    model_configs: Required keys: "model_name", "per_device_batch_size",
      and "attention".

  Returns:
    A TpuQueuedResourceTask that sets up MaxDiffusion, runs the benchmark,
    and uploads metrics.json to the run's GCS output folder.

  Raises:
    ValueError: If model_configs["model_name"] is not a supported model.
  """
  # Avoid a shared mutable default argument.
  model_configs = model_configs or {}

  job_gcp_config = gcp_config.GCPConfig(
      project_name=project_name,
      zone=tpu_zone,
      dataset_name=metric_config.DatasetOption.BENCHMARK_DATASET,
  )

  per_device_batch_size = model_configs["per_device_batch_size"]
  attention = model_configs["attention"]
  model_name = model_configs["model_name"]

  set_up_cmds = (
      "pip install --upgrade pip",
      # Download maxdiffusion. (Bug fix: a missing trailing comma used to
      # concatenate this string with the next command.)
      "git clone https://github.com/google/maxdiffusion.git",
      # Create a python virtual environment.
      "sudo apt-get -y update",
      "sudo apt-get -y install python3.10-venv",
      "python -m venv .env",
      "source .env/bin/activate",
      # Set up MaxDiffusion.
      "cd maxdiffusion",
      "pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html",
      "pip3 install -r requirements.txt",
      "pip3 install .",
      # Dependency for ControlNet. (Bug fix: this string was accidentally
      # concatenated with "cd .." — they are two separate commands.)
      "apt-get install ffmpeg libsm6 libxext6 -y",
      "cd ..",
  )

  # NOTE(review): each python command below runs from inside maxdiffusion/,
  # while `gsutil cp metrics.json` runs after `cd ..` — confirm where
  # generate_sdxl actually writes metrics.json.
  if model_name == "SDXL-Base-1.0":
    run_model_cmds = (
        # Start virtual environment.
        "source .env/bin/activate",
        ### Benchmark
        "cd maxdiffusion",
        f""" python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="sdxl-base-inference" per_device_batch_size={per_device_batch_size} attention="{attention}" """,
        "cd ..",
        f"gsutil cp metrics.json {metric_config.SshEnvVars.GCS_OUTPUT.value}",
    )
  elif model_name == "SDXL-Lightning":
    run_model_cmds = (
        # Start virtual environment.
        "source .env/bin/activate",
        ### Benchmark
        "cd maxdiffusion",
        f""" python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="sdxl-lightning-inference" lightning_repo="ByteDance/SDXL-Lightning" lightning_ckpt="sdxl_lightning_4step_unet.safetensors" per_device_batch_size={per_device_batch_size} attention="{attention}" """,
        "cd ..",
        f"gsutil cp metrics.json {metric_config.SshEnvVars.GCS_OUTPUT.value}",
    )
  elif model_name == "SDXL-ControlNet":
    run_model_cmds = (
        # Start virtual environment.
        "source .env/bin/activate",
        ### Benchmark
        "cd maxdiffusion",
        f""" python src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py per_device_batch_size={per_device_batch_size} attention="{attention}" """,
        "cd ..",
        f"gsutil cp metrics.json {metric_config.SshEnvVars.GCS_OUTPUT.value}",
    )
  else:
    # Also guarantees run_model_cmds is never unbound below.
    raise ValueError(f"Unsupported model_name: {model_name}")

  # TODO(review): the original called _modify_save_metrics("metrics.json",
  # model_configs) here, but that executes at DAG-definition time on the
  # Airflow scheduler (where metrics.json does not exist), not on the TPU VM
  # at run time. The dimension-injection logic must instead run on the VM —
  # e.g. via a run_model_cmds step or a flag in generate_sdxl.py.

  job_test_config = test_config.TpuVmTest(
      test_config.Tpu(
          version=tpu_version,
          cores=tpu_cores,
          runtime_version=runtime_version,
          reserved=is_tpu_reserved,
          network=network,
          subnetwork=subnetwork,
      ),
      test_name=test_name,
      set_up_cmds=set_up_cmds,
      run_model_cmds=run_model_cmds,
      timeout=datetime.timedelta(minutes=time_out_in_min),
      task_owner=test_owner.VIJAYA_S,
      num_slices=num_slices,
      gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/maxdiffusion",
  )

  job_metric_config = metric_config.MetricConfig(
      json_lines=metric_config.JSONLinesConfig("metrics.json"),
      use_runtime_generated_gcs_folder=True,
  )

  return task.TpuQueuedResourceTask(
      task_test_config=job_test_config,
      task_gcp_config=job_gcp_config,
      task_metric_config=job_metric_config,
  )
129 changes: 129 additions & 0 deletions dags/inference/maxdiffusion_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A DAG to run Maxdiffusion inference benchmarks"""

import datetime
from airflow import models
from dags import composer_env, test_owner
from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion
from dags.inference.configs import maxdiffusion_inference_gce_config
from dags.multipod.configs.common import SetupMode, Platform


# Run once a day at 4 am UTC (8 pm PST).
SCHEDULED_TIME = "0 4 * * *" if composer_env.is_prod_env() else None


with models.DAG(
    dag_id="maxdiffusion_inference",
    schedule=SCHEDULED_TIME,
    tags=["inference_team", "maxdiffusion", "nightly", "benchmark"],
    start_date=datetime.datetime(2024, 1, 19),
    catchup=False,
) as dag:
  test_name_prefix = "maxdiffusion-inference"
  # Sweep definition per model. Each entry of "model_configs" is a 4-tuple:
  # (tpu_version, tpu_cores, per_device_batch_sizes, attentions).
  test_models = {
      "SDXL-Base-1.0": {
          "model_configs": [
              (TpuVersion.V5E, 8, [1, 2], ["dot_attention", "flash"]),
              (TpuVersion.V5E, 4, [1, 2], ["dot_attention", "flash"]),
              # Bug fix: the two attention strings were separate tuple
              # elements, making this a 5-tuple that broke the 4-way
              # unpacking in the loop below with a ValueError.
              (
                  TpuVersion.V5P,
                  8,
                  [2, 10, 20, 40, 80, 320],
                  ["dot_attention", "flash"],
              ),
          ]
      },
      "SDXL-Lightning": {
          "model_configs": [
              (TpuVersion.V5E, 8, [1, 2], ["dot_attention", "flash"]),
              (TpuVersion.V5E, 4, [1, 2], ["dot_attention", "flash"]),
              (
                  TpuVersion.V5P,
                  8,
                  [2, 10, 20, 40, 80, 320],
                  ["dot_attention", "flash"],
              ),
          ]
      },
      # Bug fix: this key was misspelled "SDXL-ContolNet", so it never
      # matched the "SDXL-ControlNet" branch in the config module.
      "SDXL-ControlNet": {
          "model_configs": [
              (TpuVersion.V5E, 8, [1], ["dot_attention", "flash"]),
              (TpuVersion.V5E, 4, [1], ["dot_attention", "flash"]),
              (
                  TpuVersion.V5P,
                  8,
                  [2, 10, 20, 40, 80],
                  ["dot_attention", "flash"],
              ),
          ]
      },
  }

  for model, sweep_model_configs in test_models.items():
    for (
        tpu_version,
        tpu_cores,
        per_device_batch_sizes,
        attentions,
    ) in sweep_model_configs["model_configs"]:
      for per_device_batch_size in per_device_batch_sizes:
        for attention in attentions:
          model_configs = {
              "model_name": model,
              "per_device_batch_size": per_device_batch_size,
              "attention": attention,
          }

          if tpu_version == TpuVersion.V5E:
            # v5e benchmarks.
            project_name = Project.TPU_PROD_ENV_AUTOMATED.value
            zone = Zone.US_EAST1_C.value
            network = V5_NETWORKS
            subnetwork = V5E_SUBNETWORKS
            runtime_version = RuntimeVersion.V2_ALPHA_TPUV5_LITE.value
          elif tpu_version == TpuVersion.V5P:
            zone = Zone.US_EAST5_A.value
            runtime_version = RuntimeVersion.V2_ALPHA_TPUV5.value
            project_name = Project.TPU_PROD_ENV_AUTOMATED.value
            network = V5_NETWORKS
            subnetwork = V5P_SUBNETWORKS

          # Bug fix: include attention in the test name, otherwise the
          # tasks generated for different attention values share the same
          # name and Airflow rejects the duplicate task ids.
          test_name_suffix = (
              f"{model}-per_device_batch_size-{per_device_batch_size}"
              f"-attention-{attention}"
          )

          maxdiffusion_stable_1slice = maxdiffusion_inference_gce_config.get_maxdiffusion_inference_nightly_config(
              tpu_version=tpu_version,
              tpu_cores=tpu_cores,
              tpu_zone=zone,
              runtime_version=runtime_version,
              project_name=project_name,
              time_out_in_min=60,
              is_tpu_reserved=True,
              test_name=f"{test_name_prefix}-stable-{test_name_suffix}",
              test_mode=SetupMode.STABLE,
              network=network,
              subnetwork=subnetwork,
              model_configs=model_configs,
          ).run()
          maxdiffusion_nightly_1slice = maxdiffusion_inference_gce_config.get_maxdiffusion_inference_nightly_config(
              tpu_version=tpu_version,
              tpu_cores=tpu_cores,
              tpu_zone=zone,
              runtime_version=runtime_version,
              project_name=project_name,
              time_out_in_min=60,
              is_tpu_reserved=True,
              test_name=f"{test_name_prefix}-nightly-{test_name_suffix}",
              test_mode=SetupMode.NIGHTLY,
              network=network,
              subnetwork=subnetwork,
              model_configs=model_configs,
          ).run()
          # Run the stable benchmark before the nightly one.
          maxdiffusion_stable_1slice >> maxdiffusion_nightly_1slice
1 change: 1 addition & 0 deletions dags/test_owner.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,4 @@ class Team(enum.Enum):

# Inference
ANDY_Y = "Andy Y."
VIJAYA_S = "Vijaya S."
12 changes: 7 additions & 5 deletions xlml/apis/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,11 +457,13 @@ def setup_script(self) -> Optional[str]:
# TODO(wcromar): replace configmaps
@property
def test_script(self) -> str:
return '\n'.join([
'set -xue',
self.exports,
' '.join(shlex.quote(s) for s in self.test_command),
])
return '\n'.join(
[
'set -xue',
self.exports,
' '.join(shlex.quote(s) for s in self.test_command),
]
)


@attrs.define
Expand Down
12 changes: 7 additions & 5 deletions xlml/utils/gpu.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You didn't make these changes right? Can we rebase master please?

Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,13 @@ def create_resource_request(
image = get_image_from_family(project=image_project, family=image_family)
disk_type = f"zones/{gcp.zone}/diskTypes/pd-ssd"
disks = [disk_from_image(disk_type, 100, True, image.self_link)]
metadata = create_metadata({
"install-nvidia-driver": "False",
"proxy-mode": "project_editors",
"ssh-keys": f"cloud-ml-auto-solutions:{ssh_keys.public}",
})
metadata = create_metadata(
{
"install-nvidia-driver": "False",
"proxy-mode": "project_editors",
"ssh-keys": f"cloud-ml-auto-solutions:{ssh_keys.public}",
}
)

accelerators = [
compute_v1.AcceleratorConfig(
Expand Down
Loading