Skip to content

Commit d704b3b

Browse files
tuanh123789 and anhnct8 authored
add PAG support sd15 controlnet (#8820)
* add pag support sd15 controlnet * fix quality import * remove unnecessary import * remove if state * fix tests * remove useless function * add sd1.5 controlnet pag docs --------- Co-authored-by: anhnct8 <[email protected]>
1 parent 9f963e7 commit d704b3b

File tree

8 files changed

+1605
-0
lines changed

8 files changed

+1605
-0
lines changed

docs/source/en/api/pipelines/pag.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ The abstract from the paper is:
2525
- all
2626
- __call__
2727

28+
## StableDiffusionControlNetPAGPipeline
29+
[[autodoc]] StableDiffusionControlNetPAGPipeline
30+
- all
31+
- __call__
32+
2833
## StableDiffusionXLPAGPipeline
2934
[[autodoc]] StableDiffusionXLPAGPipeline
3035
- all

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@
302302
"StableDiffusionAttendAndExcitePipeline",
303303
"StableDiffusionControlNetImg2ImgPipeline",
304304
"StableDiffusionControlNetInpaintPipeline",
305+
"StableDiffusionControlNetPAGPipeline",
305306
"StableDiffusionControlNetPipeline",
306307
"StableDiffusionControlNetXSPipeline",
307308
"StableDiffusionDepth2ImgPipeline",
@@ -713,6 +714,7 @@
713714
StableDiffusionAttendAndExcitePipeline,
714715
StableDiffusionControlNetImg2ImgPipeline,
715716
StableDiffusionControlNetInpaintPipeline,
717+
StableDiffusionControlNetPAGPipeline,
716718
StableDiffusionControlNetPipeline,
717719
StableDiffusionControlNetXSPipeline,
718720
StableDiffusionDepth2ImgPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@
142142
_import_structure["pag"].extend(
143143
[
144144
"StableDiffusionPAGPipeline",
145+
"StableDiffusionControlNetPAGPipeline",
145146
"StableDiffusionXLPAGPipeline",
146147
"StableDiffusionXLPAGInpaintPipeline",
147148
"StableDiffusionXLControlNetPAGPipeline",
@@ -514,6 +515,7 @@
514515
)
515516
from .musicldm import MusicLDMPipeline
516517
from .pag import (
518+
StableDiffusionControlNetPAGPipeline,
517519
StableDiffusionPAGPipeline,
518520
StableDiffusionXLControlNetPAGPipeline,
519521
StableDiffusionXLPAGImg2ImgPipeline,

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
4848
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
4949
from .pag import (
50+
StableDiffusionControlNetPAGPipeline,
5051
StableDiffusionPAGPipeline,
5152
StableDiffusionXLControlNetPAGPipeline,
5253
StableDiffusionXLPAGImg2ImgPipeline,
@@ -90,6 +91,7 @@
9091
("pixart-alpha", PixArtAlphaPipeline),
9192
("pixart-sigma", PixArtSigmaPipeline),
9293
("stable-diffusion-pag", StableDiffusionPAGPipeline),
94+
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGPipeline),
9395
("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline),
9496
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline),
9597
]

src/diffusers/pipelines/pag/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
2424
else:
25+
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
2526
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
2627
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
2728
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
@@ -36,6 +37,7 @@
3637
except OptionalDependencyNotAvailable:
3738
from ...utils.dummy_torch_and_transformers_objects import *
3839
else:
40+
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
3941
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
4042
from .pipeline_pag_sd import StableDiffusionPAGPipeline
4143
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline

src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py

Lines changed: 1329 additions & 0 deletions
Large diffs are not rendered by default.

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,21 @@ def from_pretrained(cls, *args, **kwargs):
11421142
requires_backends(cls, ["torch", "transformers"])
11431143

11441144

1145+
class StableDiffusionControlNetPAGPipeline(metaclass=DummyObject):
    # Placeholder class: the constructor and both factory classmethods do nothing
    # except call `requires_backends` with the "torch" and "transformers" backends.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # Hand over a fresh list so the class-level declaration is never shared with the callee.
        requires_backends(self, list(self._backends))

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, list(cls._backends))

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, list(cls._backends))
1158+
1159+
11451160
class StableDiffusionControlNetPipeline(metaclass=DummyObject):
11461161
_backends = ["torch", "transformers"]
11471162

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import inspect
17+
import unittest
18+
19+
import numpy as np
20+
import torch
21+
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
22+
23+
from diffusers import (
24+
AutoencoderKL,
25+
ControlNetModel,
26+
DDIMScheduler,
27+
StableDiffusionControlNetPAGPipeline,
28+
StableDiffusionControlNetPipeline,
29+
UNet2DConditionModel,
30+
)
31+
from diffusers.utils.testing_utils import (
32+
enable_full_determinism,
33+
)
34+
from diffusers.utils.torch_utils import randn_tensor
35+
36+
from ..pipeline_params import (
37+
TEXT_TO_IMAGE_BATCH_PARAMS,
38+
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
39+
TEXT_TO_IMAGE_IMAGE_PARAMS,
40+
TEXT_TO_IMAGE_PARAMS,
41+
)
42+
from ..test_pipelines_common import (
43+
IPAdapterTesterMixin,
44+
PipelineFromPipeTesterMixin,
45+
PipelineLatentTesterMixin,
46+
PipelineTesterMixin,
47+
)
48+
49+
50+
enable_full_determinism()
51+
52+
53+
class StableDiffusionControlNetPAGPipelineFastTests(
    PipelineTesterMixin,
    IPAdapterTesterMixin,
    PipelineLatentTesterMixin,
    PipelineFromPipeTesterMixin,
    unittest.TestCase,
):
    """Tests for `StableDiffusionControlNetPAGPipeline` using small dummy components.

    The shared tester mixins are configured through the class attributes below. The
    three `test_pag_*` methods additionally compare the PAG pipeline against the plain
    `StableDiffusionControlNetPipeline` and against hard-coded reference slices.
    """

    pipeline_class = StableDiffusionControlNetPAGPipeline
    # The text-to-image parameter set extended with the two PAG-specific call arguments.
    params = TEXT_TO_IMAGE_PARAMS.union({"pag_scale", "pag_adaptive_scale"})
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})

    def get_dummy_components(self, time_cond_proj_dim=None):
        """Build the component dict used to instantiate the pipelines under test.

        Args:
            time_cond_proj_dim: Forwarded unchanged to `UNet2DConditionModel`.

        Returns:
            A dict with the keys "unet", "controlnet", "scheduler", "vae", "text_encoder",
            "tokenizer", "safety_checker", "feature_extractor" and "image_encoder"; the
            last three are set to `None`.
        """
        # Copied from tests.pipelines.controlnet.test_controlnet_sdxl.StableDiffusionXLControlNetPipelineFastTests.get_dummy_components
        # The seed is reset before each component is constructed.
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(4, 8),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=8,
            time_cond_proj_dim=time_cond_proj_dim,
            norm_num_groups=2,
        )
        torch.manual_seed(0)
        controlnet = ControlNetModel(
            block_out_channels=(4, 8),
            layers_per_block=2,
            in_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            conditioning_embedding_out_channels=(2, 4),
            cross_attention_dim=8,
            norm_num_groups=2,
        )
        torch.manual_seed(0)
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[4, 8],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            norm_num_groups=2,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=8,
            intermediate_size=16,
            layer_norm_eps=1e-05,
            num_attention_heads=2,
            num_hidden_layers=2,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "controlnet": controlnet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
            "image_encoder": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        """Build the call kwargs shared by the tests.

        Args:
            device: Device on which the generator and the control image are created.
            seed: Seed for the generator.

        Returns:
            A dict of pipeline call arguments, including a random control image of shape
            (1, 3, 64, 64) and the seeded generator that produced it.
        """
        # For "mps" the global generator returned by `torch.manual_seed` is used instead
        # of a device-bound `torch.Generator`.
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        controlnet_embedder_scale_factor = 2
        image = randn_tensor(
            (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
            generator=generator,
            device=torch.device(device),
        )

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "pag_scale": 3.0,
            "output_type": "np",
            "image": image,
        }

        return inputs

    def test_pag_disable_enable(self):
        """Check that `pag_scale=0.0` matches the base pipeline and that enabling PAG changes the output."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()

        # base pipeline (expect same output when pag is disabled)
        pipe_sd = StableDiffusionControlNetPipeline(**components)
        pipe_sd = pipe_sd.to(device)
        pipe_sd.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        del inputs["pag_scale"]
        # The message is only evaluated when the assertion fails, so the attribute name
        # must be spelled correctly (`__class__`) or the failure would surface as an
        # AttributeError instead of this explanation.
        assert (
            "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
        ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
        out = pipe_sd(**inputs).images[0, -3:, -3:, -1]

        # pag disabled with pag_scale=0.0
        pipe_pag = self.pipeline_class(**components)
        pipe_pag = pipe_pag.to(device)
        pipe_pag.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["pag_scale"] = 0.0
        out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]

        # pag enabled
        pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
        pipe_pag = pipe_pag.to(device)
        pipe_pag.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]

        assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
        assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3

    def test_pag_cfg(self):
        """Compare the output for the default dummy inputs against a stored reference slice."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()

        pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
        pipe_pag = pipe_pag.to(device)
        pipe_pag.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe_pag(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (
            1,
            64,
            64,
            3,
        ), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}"
        expected_slice = np.array(
            [0.45505235, 0.2785938, 0.16334778, 0.79689944, 0.53095645, 0.40135607, 0.7052706, 0.69065094, 0.41548574]
        )

        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
        assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}"

    def test_pag_uncond(self):
        """Compare the output with `guidance_scale` set to 0.0 against a stored reference slice."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()

        pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
        pipe_pag = pipe_pag.to(device)
        pipe_pag.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        # NOTE(review): presumably exercises PAG without classifier-free guidance; confirm against the pipeline.
        inputs["guidance_scale"] = 0.0
        image = pipe_pag(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (
            1,
            64,
            64,
            3,
        ), f"the shape of the output image should be (1, 64, 64, 3) but got {image.shape}"
        expected_slice = np.array(
            [0.45127502, 0.2797252, 0.15970308, 0.7993157, 0.5414344, 0.40160775, 0.7114598, 0.69803864, 0.4217583]
        )

        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
        assert max_diff < 1e-3, f"output is different from expected, {image_slice.flatten()}"

0 commit comments

Comments
 (0)