[WIP][tests] add tests to check for graph breaks, recompilation, cuda syncs in pipelines during torch.compile() #11085

Draft: wants to merge 9 commits into main
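For context, the two Dynamo signals these tests build on can be exercised in isolation. A minimal sketch, assuming nothing beyond stock PyTorch (the toy Linear module is purely illustrative):

import torch
from torch._dynamo.utils import counters

# Illustrative toy module; any compiled nn.Module behaves the same way.
model = torch.compile(torch.nn.Linear(4, 4))

# Graph breaks: Dynamo records each break in counters["graph_break"] while it
# traces, so an empty counter after the first call means the graph stayed whole.
counters.clear()
model(torch.randn(2, 4))
assert len(counters["graph_break"]) == 0

# Recompilation: with error_on_recompile=True, Dynamo raises instead of
# silently recompiling when a guard fails (e.g. on a new input shape).
with torch._dynamo.config.patch(error_on_recompile=True):
    model(torch.randn(2, 4))  # same shape -> cached graph, no error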
tests/pipelines/test_pipelines_common.py (40 additions, 1 deletion)
@@ -13,6 +13,7 @@
 import torch.nn as nn
 from huggingface_hub import ModelCard, delete_repo
 from huggingface_hub.utils import is_jinja_available
+from torch._dynamo.utils import counters
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 import diffusers
@@ -45,13 +46,15 @@
 from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    backend_empty_cache,
     require_accelerate_version_greater,
     require_accelerator,
     require_hf_hub_version_greater,
     require_torch,
     require_torch_gpu,
     require_transformers_version_greater,
     skip_mps,
+    slow,
     torch_device,
 )

@@ -1113,8 +1116,9 @@ def setUp(self):
     def tearDown(self):
         # clean up the VRAM after each test in case of CUDA runtime errors
         super().tearDown()
+        torch._dynamo.reset()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_save_load_local(self, expected_max_difference=5e-4):
         components = self.get_dummy_components()
@@ -2153,6 +2157,41 @@ def test_StableDiffusionMixin_component(self):
             )
         )

+    @require_torch_gpu
+    @slow
+    def test_torch_compile_recompilation(self):
+        inputs = self.get_dummy_inputs(torch_device)
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components).to(torch_device)
+        if getattr(pipe, "unet", None) is not None:
+            pipe.unet = torch.compile(pipe.unet, fullgraph=True)
+        else:
+            pipe.transformer = torch.compile(pipe.transformer, fullgraph=True)
+
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            _ = pipe(**inputs)
+
+    @require_torch_gpu
+    @slow
+    def test_torch_compile_graph_breaks(self):
+        # Inspired by:
+        # https://github.com/pytorch/pytorch/blob/916e8979d3e0d651a9091732ce3e59da32e72b0e/test/dynamo/test_higher_order_ops.py#L138
+        counters.clear()
+
+        inputs = self.get_dummy_inputs(torch_device)
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components).to(torch_device)
+        if getattr(pipe, "unet", None) is not None:
+            pipe.unet = torch.compile(pipe.unet, fullgraph=True)
+        else:
+            pipe.transformer = torch.compile(pipe.transformer, fullgraph=True)
+
+        _ = pipe(**inputs)
+        num_graph_breaks = len(counters["graph_break"].keys())
+        self.assertEqual(num_graph_breaks, 0)
+
     @require_hf_hub_version_greater("0.26.5")
     @require_transformers_version_greater("4.47.1")
     def test_save_load_dduf(self, atol=1e-4, rtol=1e-4):
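The CUDA-sync check named in the title is not in this diff yet (the PR is a WIP). One way it could be done, sketched here as an assumption rather than the PR's actual approach, is PyTorch's sync debug mode, which turns any blocking CUDA synchronization into a warning or an error; the helper name below is hypothetical:

import torch

def assert_no_cuda_syncs(pipe, inputs):
    # "error" makes any implicit device synchronization raise; restore the
    # default afterwards so later tests are unaffected.
    torch.cuda.set_sync_debug_mode("error")
    try:
        _ = pipe(**inputs)
    finally:
        torch.cuda.set_sync_debug_mode("default")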