Skip to content

Fixes #11068 Add extended support for SD3 ControlNet (Stability AI) #11084

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions examples/controlnet/train_controlnet_sd3.py
Original file line number Diff line number Diff line change
@@ -999,6 +999,10 @@ def main(args):
transformer, num_extra_conditioning_channels=args.num_extra_conditioning_channels
)

if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False:
pos_embed = controlnet._get_pos_embed_from_transformer(transformer)
controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device)

transformer.requires_grad_(False)
vae.requires_grad_(False)
text_encoder_one.requires_grad_(False)
@@ -1036,8 +1040,14 @@ def load_model_hook(models, input_dir):
# load diffusers style into model
load_model = SD3ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
model.register_to_config(**load_model.config)

model.load_state_dict(load_model.state_dict())

if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False:
model.load_state_dict(load_model.state_dict(),strict=False)
pos_embed = controlnet._get_pos_embed_from_transformer(transformer)
controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device)
else:
model.load_state_dict(load_model.state_dict())

del load_model

accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1291,11 +1301,29 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
controlnet_image = controlnet_image * vae.config.scaling_factor

controlnet_config = (
controlnet.config
if isinstance(controlnet, SD3ControlNetModel)
else controlnet.nets[0].config
)
if controlnet_config.force_zeros_for_pooled_projection:
# instantx sd3 controlnet used zero pooled projection
controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds)
else:
controlnet_pooled_projections = pooled_prompt_embeds

if controlnet_config.joint_attention_dim is not None:
controlnet_encoder_hidden_states = prompt_embeds
else:
# SD35 official 8b controlnet does not use encoder_hidden_states
controlnet_encoder_hidden_states = None


control_block_res_samples = controlnet(
hidden_states=noisy_model_input,
timestep=timesteps,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=controlnet_encoder_hidden_states,
pooled_projections=controlnet_pooled_projections,
controlnet_cond=controlnet_image,
return_dict=False,
)[0]