Fixes #11068 Add extended support for SD3 ControlNet (Stability AI) #11084

Draft · andjoer wants to merge 1 commit into main
Conversation

@andjoer (Contributor) commented Mar 17, 2025

Summary

This PR adds support for ControlNet models from Stability AI (SD3 variants), specifically:

  • Conditional application of pos_embed for ControlNets configured with use_pos_embed = False.
  • Modified checkpoint loading so training resumes properly from a saved state.
  • Adjusted training loop to support models that require force_zeros_for_pooled_projection or omit encoder_hidden_states (see the sketch after this list).
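
For orientation, here is a minimal sketch (not the actual diff) of how the conditional pos_embed application and the pooled-projection handling might look around the training step. The flags `use_pos_embed` and `force_zeros_for_pooled_projection` are the config attributes named in this PR; the helper name and tensor arguments are illustrative only.

```python
import torch


def prepare_controlnet_kwargs(controlnet, pooled_projections, encoder_hidden_states):
    """Illustrative helper, not the actual training-script code."""
    # Stability AI's SD3.5 ControlNets skip the positional embedding on the
    # control branch; their config carries use_pos_embed = False, so the flag
    # decides whether pos_embed is applied at all.
    apply_pos_embed = getattr(controlnet.config, "use_pos_embed", True)

    # Some variants also expect a zeroed pooled text projection and take no
    # encoder_hidden_states at all.
    if getattr(controlnet.config, "force_zeros_for_pooled_projection", False):
        pooled_projections = torch.zeros_like(pooled_projections)
        encoder_hidden_states = None

    return apply_pos_embed, pooled_projections, encoder_hidden_states
```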

Tested on: stabilityai/stable-diffusion-3.5-large-controlnet-blur, stabilityai/stable-diffusion-3.5-large

Motivation

These changes are required to enable training and inference with SD3 ControlNet variants contributed by Stability AI.

Points to Discuss

  • The script initializes the optimizer according to the default configuration of a ControlNet. As a result, when resuming training from a checkpoint, the initial pretrained ControlNet that was being fine-tuned still has to be specified. I left this part of the script unchanged, but it could be changed so that, when training is resumed from a checkpoint, the model is initialized from the checkpoint's configuration rather than the default one.
  • As the authors of the SD3 ControlNet pipeline have pointed out, the underlying issue comes from the conversion script that made these ControlNet models compatible with HF: the base model gets loaded and the embedding weights are initialized from the transformer. When a checkpoint is loaded, those weights are present in the checkpoint but not in the reinitialized model (since the config states that the embeddings are not used). I solved this by setting strict=False when loading the checkpoint; a sketch follows below.
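
A minimal sketch of the non-strict load described in the last point, assuming a resume hook in the style of the diffusers training scripts; the hook structure and the "controlnet" subfolder name are assumptions, only the strict=False call reflects the change described here.

```python
from diffusers import SD3ControlNetModel


def load_model_hook(models, input_dir):
    """Sketch of a resume hook; mirrors the approach described above, not the exact PR diff."""
    while models:
        model = models.pop()
        # Reload the checkpointed ControlNet weights and config from disk.
        loaded = SD3ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
        model.register_to_config(**loaded.config)
        # strict=False: the checkpoint still carries the embedding weights that were
        # copied from the base transformer, while the freshly initialized model omits
        # them (its config says the embeddings are unused), so a strict load would
        # fail on the unexpected keys.
        model.load_state_dict(loaded.state_dict(), strict=False)
        del loaded
```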
