Skip to content

[tests] add tests to check for graph breaks and recompilation in pipelines during torch.compile() #11085

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Apr 28, 2025
Merged
Changes from 4 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
de30cba
test for better torch.compile stuff.
sayakpaul Mar 17, 2025
f389a4d
fixes
sayakpaul Mar 17, 2025
6b05db6
Merge branch 'main' into test-better-torch-compile
sayakpaul Mar 18, 2025
e5543dc
Merge branch 'main' into test-better-torch-compile
sayakpaul Mar 20, 2025
6791037
recompilation and graph break.
sayakpaul Mar 21, 2025
abd1f6c
Merge branch 'main' into test-better-torch-compile
sayakpaul Mar 21, 2025
1f797b4
Merge branch 'main' into test-better-torch-compile
sayakpaul Mar 27, 2025
d669340
Merge branch 'main' into test-better-torch-compile
sayakpaul Apr 9, 2025
c49a855
Merge branch 'main' into test-better-torch-compile
sayakpaul Apr 9, 2025
c060ba0
Merge branch 'main' into test-better-torch-compile
sayakpaul Apr 14, 2025
e75a9de
Merge branch 'main' into test-better-torch-compile
sayakpaul Apr 14, 2025
c7f153a
clear compilation cache.
sayakpaul Apr 14, 2025
c74c9a8
Merge branch 'main' into test-better-torch-compile
sayakpaul Apr 14, 2025
1a934b2
Merge branch 'main' into test-better-torch-compile
sayakpaul Apr 14, 2025
e0566e6
change to modeling level test.
sayakpaul Apr 14, 2025
38c1d0d
Merge branch 'main' into test-better-torch-compile
sayakpaul Apr 15, 2025
87d957d
allow running compilation tests during nightlies.
sayakpaul Apr 15, 2025
a8184ef
Merge branch 'main' into test-better-torch-compile
sayakpaul Apr 15, 2025
fae8b6c
Merge branch 'main' into test-better-torch-compile
sayakpaul Apr 18, 2025
1749955
Merge branch 'main' into test-better-torch-compile
sayakpaul Apr 21, 2025
a07c63b
Merge branch 'main' into test-better-torch-compile
sayakpaul Apr 25, 2025
f71c8f6
Merge branch 'main' into test-better-torch-compile
sayakpaul Apr 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch.nn as nn
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
from torch._dynamo.utils import counters
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

import diffusers
Expand Down Expand Up @@ -45,13 +46,15 @@
from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
from diffusers.utils.testing_utils import (
CaptureLogger,
backend_empty_cache,
require_accelerate_version_greater,
require_accelerator,
require_hf_hub_version_greater,
require_torch,
require_torch_gpu,
require_transformers_version_greater,
skip_mps,
slow,
torch_device,
)

Expand Down Expand Up @@ -1113,8 +1116,9 @@ def setUp(self):
def tearDown(self):
    """Per-test cleanup: reset compilation state and free accelerator memory
    so artifacts from one test (compiled graphs, cached allocations) cannot
    leak into the next.
    """
    # clean up the VRAM after each test in case of CUDA runtime errors
    super().tearDown()
    # Drop all torch.compile caches and guards so subsequent tests start from
    # a clean Dynamo state (avoids spurious cross-test recompilations).
    torch._dynamo.reset()
    gc.collect()
    # Device-agnostic cache clearing (CUDA/XPU/...) instead of cuda-only
    # torch.cuda.empty_cache(); torch_device selects the active backend.
    backend_empty_cache(torch_device)

def test_save_load_local(self, expected_max_difference=5e-4):
components = self.get_dummy_components()
Expand Down Expand Up @@ -2153,6 +2157,41 @@ def test_StableDiffusionMixin_component(self):
)
)

@require_torch_gpu
@slow
def test_torch_compile_recompilation(self):
    """Compile the pipeline's denoiser with fullgraph=True and assert that a
    single pipeline call triggers no Dynamo recompilation.
    """
    sample_inputs = self.get_dummy_inputs(torch_device)
    pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)

    # Compile whichever denoiser backbone this pipeline carries
    # (UNet-based pipelines expose `unet`, DiT-based ones `transformer`).
    target = "unet" if getattr(pipe, "unet", None) is not None else "transformer"
    setattr(pipe, target, torch.compile(getattr(pipe, target), fullgraph=True))

    # Under error_on_recompile, Dynamo raises if any guard failure forces a
    # recompile during the call, failing the test.
    with torch._dynamo.config.patch(error_on_recompile=True):
        _ = pipe(**sample_inputs)

@require_torch_gpu
@slow
def test_torch_compile_graph_breaks(self):
    """Compile the pipeline's denoiser and assert the compiled call records
    zero graph breaks.

    Inspired by:
    https://github.com/pytorch/pytorch/blob/916e8979d3e0d651a9091732ce3e59da32e72b0e/test/dynamo/test_higher_order_ops.py#L138
    """
    # Start from a clean slate so breaks recorded by earlier tests don't
    # bleed into this measurement.
    counters.clear()

    sample_inputs = self.get_dummy_inputs(torch_device)
    pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)

    # Compile whichever denoiser backbone this pipeline carries
    # (UNet-based pipelines expose `unet`, DiT-based ones `transformer`).
    target = "unet" if getattr(pipe, "unet", None) is not None else "transformer"
    setattr(pipe, target, torch.compile(getattr(pipe, target), fullgraph=True))

    _ = pipe(**sample_inputs)

    # Dynamo keeps one counter entry per distinct graph-break reason; an
    # empty counter means the forward traced into a single graph.
    self.assertEqual(len(counters["graph_break"]), 0)

@require_hf_hub_version_greater("0.26.5")
@require_transformers_version_greater("4.47.1")
def test_save_load_dduf(self, atol=1e-4, rtol=1e-4):
Expand Down
Loading