
[Flux LoRA] fix issues in flux lora scripts #11111

Merged - 33 commits merged into huggingface:main from flux_lora_advanced on Apr 8, 2025

Conversation

@linoytsaban (Collaborator) commented Mar 18, 2025

Fix remaining pending issues from #10313 and #9476 in the Flux LoRA training scripts:

  • verify the optimizer is updating properly (transformer only ☑️, text encoder w/ CLIP ☑️, pivotal w/ CLIP, pivotal w/ CLIP & T5, TI) wip - see the sketch after this list
  • accelerate error when running on multiple GPUs
  • replace scheduler
  • log_validation with mixed precision
  • save intermediate embeddings when checkpointing is enabled
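
For the first item, a generic way to check that trainable parameters actually change across an optimizer step (a sketch, not the script's code):

import torch

def params_changed(model: torch.nn.Module, before: dict) -> bool:
    # Compare current trainable parameters against a snapshot taken before
    # optimizer.step(); any difference means the optimizer is updating them.
    return any(
        not torch.equal(p.detach().cpu(), before[name])
        for name, p in model.named_parameters()
        if p.requires_grad
    )

# usage:
# snapshot = {n: p.detach().cpu().clone() for n, p in model.named_parameters() if p.requires_grad}
# ...run one training step...
# assert params_changed(model, snapshot)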

code snippets and output examples:

  • for running log_validation with mixed precision:
import os
os.environ['MODEL_NAME'] = "black-forest-labs/FLUX.1-dev"
os.environ['DATASET_NAME'] = "dog"
os.environ['OUTPUT_DIR'] = "flux-test-1"

!accelerate launch train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$DATASET_NAME \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --guidance_scale=1 \
  --gradient_accumulation_steps=1 \
  --optimizer="prodigy" \
  --learning_rate=1. \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --checkpointing_steps=250 \
  --validation_prompt="a photo of sks dog in a bucket"\
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub

validation output at step 380:
[screenshot: validation sample image, 2025-03-18]


@linoytsaban linoytsaban changed the title [Flux LoRA] fix issues in advanced script [Flux LoRA] fix issues in flux lora scripts Mar 18, 2025
@linoytsaban linoytsaban requested a review from sayakpaul March 18, 2025 21:40
@linoytsaban linoytsaban added the bug (Something isn't working) and training labels
@luchaoqi (Contributor) commented Mar 19, 2025

Hi @linoytsaban , thanks for this prompt fix!

I believe the accelerator produces an error at the line here when running textual inversion specifically, following the blog here:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/playpen-nas-ssd/luchao/projects/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced_linoy.py", line 2576, in <module>
[rank0]:     main(args)
[rank0]:   File "/playpen-nas-ssd/luchao/projects/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced_linoy.py", line 2273, in main
[rank0]:     prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
[rank0]:                                                     ^^^^^^^^^^^^^^
[rank0]:   File "/playpen-nas-ssd/luchao/projects/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced_linoy.py", line 1446, in encode_prompt
[rank0]:     dtype = text_encoders[0].dtype
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/playpen-nas-ssd/luchao/software/miniconda3/envs/diffuser/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1928, in __getattr__
[rank0]:     raise AttributeError(
[rank0]: AttributeError: 'DistributedDataParallel' object has no attribute 'dtype'. Did you mean: 'type'?
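
For reference, DistributedDataParallel proxies forward() but not arbitrary module attributes like dtype, so a common fix (a sketch, assuming accelerate's standard unwrap_model API) is to unwrap first:

dtype = accelerator.unwrap_model(text_encoders[0]).dtype  # read dtype from the unwrapped module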

Also, is it possible to verify that textual inversion works in the sks dog case on your end as well? e.g. pure CLIP textual inversion as mentioned here:

  --train_text_encoder_ti \
  --train_text_encoder_ti_frac=1 \
  --train_transformer_frac=0
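
For reference, a full pure-TI run would presumably look like the launch command above with the transformer frozen - a sketch combining the flags from this thread with the earlier command (script name and remaining values assumed):

!accelerate launch train_dreambooth_lora_flux_advanced.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$DATASET_NAME \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --instance_prompt="a photo of sks dog" \
  --train_text_encoder_ti \
  --train_text_encoder_ti_frac=1 \
  --train_transformer_frac=0 \
  --optimizer="prodigy" \
  --learning_rate=1. \
  --max_train_steps=500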

@linoytsaban (Collaborator, Author)
Hey @luchaoqi! Yes, I'm currently testing multiple configurations - will definitely test pivotal tuning with CLIP and pure textual inversion with CLIP.

re: error with accelerator when running with multiple processes - adding it to the todo list :)

@sayakpaul (Member) left a comment

Initial comments.

Commit: …n, fix accelerator.accumulate call in advanced script
@linoytsaban (Collaborator, Author)
@sayakpaul I noticed that in the scripts we sometimes use accelerator.unwrap_model, and sometimes we use the unwrap_model helper:

    def unwrap_model(model):
        # strip the accelerate wrapper (e.g. DistributedDataParallel)
        model = accelerator.unwrap_model(model)
        # if the module was compiled with torch.compile, recover the original module
        model = model._orig_mod if is_compiled_module(model) else model
        return model

do you recall why it's not consistently one way or the other?

@sayakpaul (Member)

Using the unwrap_model() function works. The ones that don't should be updated to do something similar. We added unwrap_model() for more consistency in cases with torch.compile().
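
For context, a minimal sketch of why the extra unwrapping matters (illustrative only):

import torch

# torch.compile wraps a module in an OptimizedModule that keeps the original
# module under `_orig_mod`; accelerator.unwrap_model alone does not remove
# this layer, hence the extra `_orig_mod` step in the helper above.
net = torch.nn.Linear(4, 4)
compiled = torch.compile(net)
print(type(compiled).__name__)    # OptimizedModule
print(compiled._orig_mod is net)  # True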

@linoytsaban (Collaborator, Author)

Hey @luchaoqi! Could you please check whether the accelerator now works fine with distributed training? I think it should be resolved now.

@luchaoqi (Contributor)

Hi @linoytsaban, yes distributed training works as expected.

Pure textual inversion surfaces new problems:

03/19/2025 10:20:03 - INFO - __main__ - Running validation...
 Generating 4 images with prompt: a photo of <s0><s1> person at 50 years old.
Traceback (most recent call last):
  File "/playpen-nas-ssd/luchao/projects/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced_linoy.py", line 2408, in <module>
    main(args)
  File "/playpen-nas-ssd/luchao/projects/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced_linoy.py", line 2055, in main
    text_encoder_one.train()
    ^^^^^^^^^^^^^^^^
UnboundLocalError: cannot access local variable 'text_encoder_one' where it is not associated with a value
[rank0]: Traceback (most recent call last):
[rank0]:   File "/playpen-nas-ssd/luchao/projects/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced_linoy.py", line 2408, in <module>
[rank0]:     main(args)
[rank0]:   File "/playpen-nas-ssd/luchao/projects/diffusers/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced_linoy.py", line 2055, in main
[rank0]:     text_encoder_one.train()
[rank0]:     ^^^^^^^^^^^^^^^^
[rank0]: UnboundLocalError: cannot access local variable 'text_encoder_one' where it is not associated with a value
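
Presumably text_encoder_one only gets bound on the branch that trains the transformer/text-encoder LoRA, so the pure-TI configuration reaches .train() with the name unbound. A minimal, hypothetical reduction of this bug class (not the script's actual code):

def main(train_transformer_frac: float = 0.0) -> None:
    if train_transformer_frac > 0:
        text_encoder_one = object()  # in the script: the loaded CLIP text encoder

    # pure textual inversion uses train_transformer_frac=0, so the branch
    # above never runs and the local name is unbound here
    text_encoder_one.train()  # raises UnboundLocalError

main()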

@luchaoqi (Contributor)

Hi @linoytsaban, just wanted to follow up on the textual inversion part—do you anticipate it being fixed soon, or will it need a bit more time?

@linoytsaban (Collaborator, Author)

@luchaoqi yes should be done soon!

@linoytsaban linoytsaban marked this pull request as ready for review April 2, 2025 05:28
@linoytsaban (Collaborator, Author)

@luchaoqi if you want to give it a try, the current version should be fixed.

@luchaoqi (Contributor) commented Apr 4, 2025

@linoytsaban thanks! Would definitely try it out asap once I get some time.
Feel free to merge it if other reviewers agree, cheers!

@linoytsaban linoytsaban requested a review from sayakpaul April 8, 2025 08:31
@sayakpaul (Member) left a comment

Thanks! Left some comments. Let me know if they make sense.

@@ -228,10 +228,21 @@ def log_validation(

     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
-    autocast_ctx = nullcontext()
+    autocast_ctx = torch.autocast(accelerator.device.type)
@sayakpaul (Member)

I think this is only needed for the intermediate validation. Do we need to check for that?

@linoytsaban (Collaborator, Author) commented Apr 8, 2025

Yeah, I think you're right - tested it now and it seems to work as expected; changed it.
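
Presumably the change gates autocast on whether this is the final validation - a self-contained sketch of the idea (helper name hypothetical, not the merged code):

from contextlib import nullcontext
import torch

def make_autocast_ctx(device_type: str, is_final_validation: bool):
    # autocast only during intermediate validation; the final-validation
    # pipeline is already loaded in the target dtype, so it runs without it
    return nullcontext() if is_final_validation else torch.autocast(device_type)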

@linoytsaban linoytsaban requested a review from sayakpaul April 8, 2025 12:26
@sayakpaul (Member) left a comment

Thanks a ton for handling this!

@sayakpaul (Member)

@bot /style

github-actions bot (Contributor) commented Apr 8, 2025

Style fixes have been applied. View the workflow run here.

@linoytsaban (Collaborator, Author)

Failing test is unrelated

@linoytsaban linoytsaban merged commit 71f34fc into huggingface:main Apr 8, 2025
8 of 9 checks passed
@linoytsaban linoytsaban deleted the flux_lora_advanced branch April 9, 2025 06:41