diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index ffe460d72de8..5758b3348051 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -999,6 +999,10 @@ def main(args): transformer, num_extra_conditioning_channels=args.num_extra_conditioning_channels ) + if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False: + pos_embed = controlnet._get_pos_embed_from_transformer(transformer) + controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device) + transformer.requires_grad_(False) vae.requires_grad_(False) text_encoder_one.requires_grad_(False) @@ -1036,8 +1040,14 @@ def load_model_hook(models, input_dir): # load diffusers style into model load_model = SD3ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") model.register_to_config(**load_model.config) - - model.load_state_dict(load_model.state_dict()) + + if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False: + model.load_state_dict(load_model.state_dict(),strict=False) + pos_embed = controlnet._get_pos_embed_from_transformer(transformer) + controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device) + else: + model.load_state_dict(load_model.state_dict()) + del load_model accelerator.register_save_state_pre_hook(save_model_hook) @@ -1291,11 +1301,29 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): controlnet_image = vae.encode(controlnet_image).latent_dist.sample() controlnet_image = controlnet_image * vae.config.scaling_factor + controlnet_config = ( + controlnet.config + if isinstance(controlnet, SD3ControlNetModel) + else controlnet.nets[0].config + ) + if controlnet_config.force_zeros_for_pooled_projection: + # instantx sd3 controlnet used zero pooled projection + controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds) + else: + controlnet_pooled_projections = pooled_prompt_embeds + + if controlnet_config.joint_attention_dim is not None: + controlnet_encoder_hidden_states = prompt_embeds + else: + # SD35 official 8b controlnet does not use encoder_hidden_states + controlnet_encoder_hidden_states = None + + control_block_res_samples = controlnet( hidden_states=noisy_model_input, timestep=timesteps, - encoder_hidden_states=prompt_embeds, - pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=controlnet_encoder_hidden_states, + pooled_projections=controlnet_pooled_projections, controlnet_cond=controlnet_image, return_dict=False, )[0]