Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Do Not Merge] Simple Maxdiffusion SDXL inference integration #299

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 54 additions & 18 deletions dags/inference/configs/maxdiffusion_inference_gce_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Utilities to construct configs for maxdiffusion inference DAG."""

import datetime
import json
from typing import Dict
from xlml.apis import gcp_config, metric_config, task, test_config
Expand All @@ -26,6 +27,14 @@
GCS_SUBFOLDER_PREFIX = test_owner.Team.INFERENCE.value


def _modify_save_metrics(metrics_file, model_configs):
metrics = json.loads(metrics_file)
for k, v in model_configs:
metrics["dimensions"][k] = str(v)
with open(metrics_file, "w") as f:
f.write(json.dumps(metrics))


def get_maxdiffusion_inference_nightly_config(
tpu_version: TpuVersion,
tpu_cores: int,
Expand All @@ -47,10 +56,13 @@ def get_maxdiffusion_inference_nightly_config(
dataset_name=metric_config.DatasetOption.BENCHMARK_DATASET,
)

per_device_bat_size = model_configs["per_device_batch_size"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
per_device_bat_size = model_configs["per_device_batch_size"]
per_device_batch_size = model_configs["per_device_batch_size"]

attention = model_configs["attention"]
model_name = model_configs["model_name"]
set_up_cmds = (
"pip install --upgrade pip",
# Download maxdiffusion
"git clone -b inference_utils https://github.com/google/maxdiffusion.git",
"git clone https://github.com/google/maxdiffusion.git"
# Create a python virtual environment
"sudo apt-get -y update",
"sudo apt-get -y install python3.10-venv",
Expand All @@ -60,25 +72,49 @@ def get_maxdiffusion_inference_nightly_config(
"cd maxdiffusion",
"pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html",
"pip3 install -r requirements.txt",
"pip3 install ."
"pip3 install .",
# dependency for controlnet
"apt-get install ffmpeg libsm6 libxext6 -y" "cd ..",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"apt-get install ffmpeg libsm6 libxext6 -y" "cd ..",
"apt-get install ffmpeg libsm6 libxext6 -y",
"cd ..",

)

additional_metadata_dict = {
"per_device_batch_size": f"{model_configs['per_device_batch_size']}",
}

run_model_cmds = (
# Start virtual environment
"source .env/bin/activate",
### Benchmark
"cd maxdiffusion",
# Configure flags
""" python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run" """,
# Give server time to start
f"sleep {model_configs['sleep_time']}",
f"gsutil cp metrics.json {metric_config.SshEnvVars.GCS_OUTPUT.value}",
)
if model_name == "SDXL-Base-1.0":
run_model_cmds = (
# Start virtual environment
"source .env/bin/activate",
### Benchmark
"cd maxdiffusion",
# Configure flags
"cd .."
f""" python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run" per_device_batch_size={per_device_bat_size} attention="{attention}" """,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: rename "my_run" to something more specific to "sdxl". Here, and below.

"cd ..",
f"gsutil cp metrics.json {metric_config.SshEnvVars.GCS_OUTPUT.value}",
)
if model_name == "SDXL-Lightning":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: change to elif here and below.

run_model_cmds = (
# Start virtual environment
"source .env/bin/activate",
### Benchmark
"cd maxdiffusion",
# Configure flags
"cd .."
f""" python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run" lightning_repo="ByteDance/SDXL-Lightning" lightning_ckpt="sdxl_lightning_4step_unet.safetensors" per_device_batch_size={per_device_bat_size} attention="{attention}" """,
"cd ..",
f"gsutil cp metrics.json {metric_config.SshEnvVars.GCS_OUTPUT.value}",
)
if model_name == "SDXL-ControlNet":
run_model_cmds = (
# Start virtual environment
"source .env/bin/activate",
### Benchmark
"cd maxdiffusion",
# Configure flags
"cd .."
f""" python src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py per_device_batch_size={per_device_bat_size} attention="{attention}" """,
"cd ..",
f"gsutil cp metrics.json {metric_config.SshEnvVars.GCS_OUTPUT.value}",
)

_modify_save_metrics("metrics.json", model_configs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if _modify_save_metrics will actually affect the metrics.json file. We're running the _modify_save_metrics function when we define the DAG, but we actually need the function to run in the TPU VM at run time.

We can either have bash commands to do something like _modify_save_metrics, which may be more difficult. OR, we could have a python script in the maxdiffusion repo (or another repo) to do this logic. OR, we could do this logic directly in the generate_sdxl.py file (with a flag to add the dimensions key).

job_test_config = test_config.TpuVmTest(
test_config.Tpu(
version=tpu_version,
Expand All @@ -91,7 +127,7 @@ def get_maxdiffusion_inference_nightly_config(
test_name=test_name,
set_up_cmds=set_up_cmds,
run_model_cmds=run_model_cmds,
time_out_in_min=time_out_in_min,
timeout=datetime.timedelta(minutes=time_out_in_min),
task_owner=test_owner.VIJAYA_S,
num_slices=num_slices,
gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/maxdiffusion",
Expand Down
53 changes: 41 additions & 12 deletions dags/inference/maxdiffusion_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""A DAG to run MaxText inference benchmarks with nightly version."""
"""A DAG to run Maxdiffusion inference benchmarks"""

import datetime
from airflow import models
Expand All @@ -35,25 +35,54 @@
) as dag:
test_name_prefix = "maxdiffusion-inference"
test_models = {
"SDXL-Base-1": {
"sleep_time": 120,
"tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.V5P, 8)],
"maxdiffusion_logs": "gs://inference-benchmarks/models/SDXL-Base-1/2024-05-14-14-01/",
"per_device_batch_sizes": [2],
# "request_rate": 5,
"SDXL-Base-1.0": {
"model_configs": [
(TpuVersion.V5E, 8, [1, 2], ["dot_attention", "flash"]),
(TpuVersion.V5E, 4, [1, 2], ["dot_attention", "flash"]),
(
TpuVersion.V5P,
8,
[2, 10, 20, 40, 80, 320],
"dot_attention",
"flash",
),
]
},
"SDXL-Lightning": {
"model_configs": [
(TpuVersion.V5E, 8, [1, 2], ["dot_attention", "flash"]),
(TpuVersion.V5E, 4, [1, 2], ["dot_attention", "flash"]),
(
TpuVersion.V5P,
8,
[2, 10, 20, 40, 80, 320],
["dot_attention", "flash"],
),
]
},
"SDXL-ContolNet": {
"model_configs": [
(TpuVersion.V5E, 8, [1], ["dot_attention", "flash"]),
(TpuVersion.V5E, 4, [1], ["dot_attention", "flash"]),
(
TpuVersion.V5P,
8,
[2, 10, 20, 40, 80],
["dot_attention", "flash"],
),
]
},
}

for model, sweep_model_configs in test_models.items():
# tasks_per_model = []
for per_device_batch_size in sweep_model_configs["per_device_batch_sizes"]:
for tpu_version, tpu_cores in sweep_model_configs["tpu_version_cores"]:
for tpu_version, tpu_cores, per_device_batch_sizes, attentions in sweep_model_configs["model_configs"]:
for per_device_batch_size in per_device_batch_sizes:
for attention in attentions:
model_configs = {}
model_configs["model_name"] = model
model_configs["sleep_time"] = sweep_model_configs["sleep_time"]
model_configs["maxdiffusion_logs"] = sweep_model_configs["maxdiffusion_logs"]
model_configs["per_device_batch_size"] = per_device_batch_size
# model_configs["request_rate"] = sweep_model_configs["request_rate"]
model_configs["attention"] = attention

if tpu_version == TpuVersion.V5E:
# v5e benchmarks
Expand Down
12 changes: 7 additions & 5 deletions xlml/apis/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,11 +457,13 @@ def setup_script(self) -> Optional[str]:
# TODO(wcromar): replace configmaps
@property
def test_script(self) -> str:
return '\n'.join([
'set -xue',
self.exports,
' '.join(shlex.quote(s) for s in self.test_command),
])
return '\n'.join(
[
'set -xue',
self.exports,
' '.join(shlex.quote(s) for s in self.test_command),
]
)


@attrs.define
Expand Down
12 changes: 7 additions & 5 deletions xlml/utils/gpu.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You didn't make these changes, right? Can we rebase onto master, please?

Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,13 @@ def create_resource_request(
image = get_image_from_family(project=image_project, family=image_family)
disk_type = f"zones/{gcp.zone}/diskTypes/pd-ssd"
disks = [disk_from_image(disk_type, 100, True, image.self_link)]
metadata = create_metadata({
"install-nvidia-driver": "False",
"proxy-mode": "project_editors",
"ssh-keys": f"cloud-ml-auto-solutions:{ssh_keys.public}",
})
metadata = create_metadata(
{
"install-nvidia-driver": "False",
"proxy-mode": "project_editors",
"ssh-keys": f"cloud-ml-auto-solutions:{ssh_keys.public}",
}
)

accelerators = [
compute_v1.AcceleratorConfig(
Expand Down
Loading