[WIP][tests] add tests to check for graph breaks, recompilation, cuda syncs in pipelines during torch.compile() #11085

Draft · wants to merge 7 commits into base: main
tests/pipelines/test_pipelines_common.py · 17 additions & 0 deletions
@@ -56,6 +56,7 @@
     require_torch_gpu,
     require_transformers_version_greater,
     skip_mps,
+    slow,
     torch_device,
 )

@@ -1117,6 +1118,7 @@ def setUp(self):
     def tearDown(self):
         # clean up the VRAM after each test in case of CUDA runtime errors
         super().tearDown()
+        torch._dynamo.reset()
         gc.collect()
         backend_empty_cache(torch_device)

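The `torch._dynamo.reset()` call added to `tearDown` clears Dynamo's compiled-graph and guard caches, so compilation state from one test cannot leak into the next and skew the recompilation checks below. A minimal standalone sketch of what the reset does (the function `f` is an illustrative assumption, not from the PR):

```python
import torch

def f(x):
    return x.sin() + 1

compiled = torch.compile(f)
compiled(torch.randn(4))   # first call: Dynamo traces f and caches the compiled graph
compiled(torch.randn(4))   # same shape: cache hit, no new compilation

torch._dynamo.reset()      # drop all cached graphs and guards
compiled(torch.randn(4))   # compiles again from a cold cache
```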
@@ -2162,6 +2164,21 @@ def test_StableDiffusionMixin_component(self):
             )
         )

+    @require_torch_gpu
+    @slow
+    def test_torch_compile_recompilation_and_graph_break(self):
+        inputs = self.get_dummy_inputs(torch_device)
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components).to(torch_device)
+        if getattr(pipe, "unet", None) is not None:
+            pipe.unet = torch.compile(pipe.unet, fullgraph=True)
+        else:
+            pipe.transformer = torch.compile(pipe.transformer, fullgraph=True)
+
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            _ = pipe(**inputs)
+
     @require_hf_hub_version_greater("0.26.5")
     @require_transformers_version_greater("4.47.1")
     def test_save_load_dduf(self, atol=1e-4, rtol=1e-4):
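The new test exercises two failure modes at once: `fullgraph=True` makes Dynamo raise on any graph break instead of silently falling back to eager mode, and `error_on_recompile=True` makes it raise if a guard failure forces a second compilation during the pipeline call. A minimal standalone sketch of both mechanisms (the function `f` and the input shapes are illustrative assumptions, not from the PR):

```python
import torch

def f(x):
    # A data-dependent branch such as `if x.sum() > 0:` would introduce a
    # graph break here; under fullgraph=True Dynamo raises rather than
    # splitting the graph.
    return x * 2

compiled = torch.compile(f, fullgraph=True)

with torch._dynamo.config.patch(error_on_recompile=True):
    compiled(torch.randn(4))    # first call: compiles once
    compiled(torch.randn(4))    # same shape: cache hit, nothing recompiles
    # compiled(torch.randn(8))  # a new shape can invalidate guards and force a
    #                           # recompile, which error_on_recompile turns into an error
```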