From de30cbadf24a6b2ee5a6476783d3e765714c38ae Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 17 Mar 2025 14:03:24 +0530 Subject: [PATCH 1/3] test for better torch.compile stuff. --- tests/pipelines/test_pipelines_common.py | 41 +++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index a98de5c9eaf9..0a9e6791b72c 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -13,6 +13,7 @@ import torch.nn as nn from huggingface_hub import ModelCard, delete_repo from huggingface_hub.utils import is_jinja_available +from torch._dynamo.utils import counters from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer import diffusers @@ -45,6 +46,7 @@ from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor from diffusers.utils.testing_utils import ( CaptureLogger, + backend_empty_cache, require_accelerate_version_greater, require_accelerator, require_hf_hub_version_greater, @@ -52,6 +54,7 @@ require_torch_gpu, require_transformers_version_greater, skip_mps, + slow, torch_device, ) @@ -1113,8 +1116,9 @@ def setUp(self): def tearDown(self): # clean up the VRAM after each test in case of CUDA runtime errors super().tearDown() + torch._dynamo.reset() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_save_load_local(self, expected_max_difference=5e-4): components = self.get_dummy_components() @@ -2153,6 +2157,41 @@ def test_StableDiffusionMixin_component(self): ) ) + @require_torch_gpu + @slow + def test_torch_compile_recompilation(self): + inputs = self.get_dummy_inputs() + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components).to(torch_device) + if getattr(pipe, "unet", None) is None: + pipe.unet = torch.compile(pipe.unet, fullgraph=True) + else: + pipe.transformer = torch.compile(pipe.transformer, fullgraph=True) + + with torch._dynamo.config.patch(error_on_recompile=True): + _ = pipe(**inputs) + + @require_torch_gpu + @slow + def test_torch_compile_graph_breaks(self): + # Inspired by: + # https://github.com/pytorch/pytorch/blob/916e8979d3e0d651a9091732ce3e59da32e72b0e/test/dynamo/test_higher_order_ops.py#L138 + counters.clear() + + inputs = self.get_dummy_inputs() + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components).to(torch_device) + if getattr(pipe, "unet", None) is None: + pipe.unet = torch.compile(pipe.unet, fullgraph=True) + else: + pipe.transformer = torch.compile(pipe.transformer, fullgraph=True) + + _ = pipe(**inputs) + num_graph_breaks = len(counters["graph_break"].keys()) + self.assertEqual(num_graph_breaks, 0) + @require_hf_hub_version_greater("0.26.5") @require_transformers_version_greater("4.47.1") def test_save_load_dduf(self, atol=1e-4, rtol=1e-4): From f389a4d5eb933249b7fc5349d081af3b3adf6ecc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 17 Mar 2025 14:15:31 +0530 Subject: [PATCH 2/3] fixes --- tests/pipelines/test_pipelines_common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 0a9e6791b72c..cf3eab43aa9b 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2160,11 +2160,11 @@ def test_StableDiffusionMixin_component(self): @require_torch_gpu @slow def test_torch_compile_recompilation(self): - inputs = self.get_dummy_inputs() + inputs = self.get_dummy_inputs(torch_device) components = self.get_dummy_components() pipe = self.pipeline_class(**components).to(torch_device) - if getattr(pipe, "unet", None) is None: + if getattr(pipe, "unet", None) is not None: pipe.unet = torch.compile(pipe.unet, fullgraph=True) else: pipe.transformer = torch.compile(pipe.transformer, fullgraph=True) @@ -2179,11 +2179,11 @@ def test_torch_compile_graph_breaks(self): # https://github.com/pytorch/pytorch/blob/916e8979d3e0d651a9091732ce3e59da32e72b0e/test/dynamo/test_higher_order_ops.py#L138 counters.clear() - inputs = self.get_dummy_inputs() + inputs = self.get_dummy_inputs(torch_device) components = self.get_dummy_components() pipe = self.pipeline_class(**components).to(torch_device) - if getattr(pipe, "unet", None) is None: + if getattr(pipe, "unet", None) is not None: pipe.unet = torch.compile(pipe.unet, fullgraph=True) else: pipe.transformer = torch.compile(pipe.transformer, fullgraph=True) From 6791037c6451409de879b3fa3c73dc2b04d94100 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 21 Mar 2025 08:53:04 +0530 Subject: [PATCH 3/3] recompilation and graph break. --- tests/pipelines/test_pipelines_common.py | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index cf3eab43aa9b..f51048f150d9 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -13,7 +13,6 @@ import torch.nn as nn from huggingface_hub import ModelCard, delete_repo from huggingface_hub.utils import is_jinja_available -from torch._dynamo.utils import counters from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer import diffusers @@ -2159,7 +2158,7 @@ def test_StableDiffusionMixin_component(self): @require_torch_gpu @slow - def test_torch_compile_recompilation(self): + def test_torch_compile_recompilation_and_graph_break(self): inputs = self.get_dummy_inputs(torch_device) components = self.get_dummy_components() @@ -2172,26 +2171,6 @@ def test_torch_compile_recompilation(self): with torch._dynamo.config.patch(error_on_recompile=True): _ = pipe(**inputs) - @require_torch_gpu - @slow - def test_torch_compile_graph_breaks(self): - # Inspired by: - # https://github.com/pytorch/pytorch/blob/916e8979d3e0d651a9091732ce3e59da32e72b0e/test/dynamo/test_higher_order_ops.py#L138 - counters.clear() - - inputs = self.get_dummy_inputs(torch_device) - components = self.get_dummy_components() - - pipe = self.pipeline_class(**components).to(torch_device) - if getattr(pipe, "unet", None) is not None: - pipe.unet = torch.compile(pipe.unet, fullgraph=True) - else: - pipe.transformer = torch.compile(pipe.transformer, fullgraph=True) - - _ = pipe(**inputs) - num_graph_breaks = len(counters["graph_break"].keys()) - self.assertEqual(num_graph_breaks, 0) - @require_hf_hub_version_greater("0.26.5") @require_transformers_version_greater("4.47.1") def test_save_load_dduf(self, atol=1e-4, rtol=1e-4):