
[Flux] Add batched inference #1227


Open
wants to merge 6 commits into main

Conversation

CarlosGomes98
Contributor

@CarlosGomes98 CarlosGomes98 commented May 27, 2025

This PR adds the ability to perform batched and multi-GPU inference on the Flux model, following from #1205.

  • The main modification is to the functions in sampling.py, allowing them to take several prompts and run inference on them as a batch.
  • Added an infer.py script which performs inference. This leverages the Trainer, which is great for reusability, but not so great because it inherits all the requirements of the trainer (such as a train dataset, which doesn't really make sense here).
  • Added a run_inference.sh script and documented it in README.md.
  • Added an example prompts file, prompts.txt.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 27, 2025
@CarlosGomes98
Contributor Author

@wwwjn @tianyu-l Here is the split-up inference bit :)

@@ -177,9 +202,9 @@ def denoise(
     # create positional encodings
     POSITION_DIM = 3
     latent_pos_enc = create_position_encoding_for_latents(
-        bsz, latent_height, latent_width, POSITION_DIM
+        1, latent_height, latent_width, POSITION_DIM
Contributor

QQ: Why do we change bsz to 1 here (and later), given that we are taking bsz prompts as input?

Contributor Author

There are two parts to this. The reason we can do this is that, for the particular tensors where I'm using 1 on the first dimension, the values are identical for all samples, so we just want to repeat them across the batch. Due to torch broadcasting, whenever such a tensor is used in an operation with another tensor, this dimension is expanded to match whatever the other tensor requires (torch automatically makes it whatever the batch size is).

The reason we want to do this is twofold:

  1. It saves some memory to not carry around all these repeated tensors, and to let torch broadcast during operations instead.
  2. If we don't do it, we would have to manually double the size of these tensors whenever we are doing classifier-free guidance. This way we don't have to worry about it, as everything broadcasts correctly (see the sketch below).
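
A minimal sketch of the broadcasting behavior being described (the shapes and names here are illustrative, not the actual Flux code):

import torch

bsz, seq_len, dim = 4, 16, 8

# Position encoding built once with a leading dimension of 1.
pos_enc = torch.randn(1, seq_len, dim)

# Latents carry the real batch size (doubled under classifier-free guidance).
latents = torch.randn(2 * bsz, seq_len, dim)

# Broadcasting expands the size-1 leading dimension to 2 * bsz automatically,
# so the same pos_enc works for any batch size without an explicit repeat().
out = latents + pos_enc
assert out.shape == (2 * bsz, seq_len, dim)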

Contributor

Ah, the reasoning makes sense to me.

It feels to me that, if the result of latent_pos_enc is always identical for all samples in a batch, we should probably just remove bsz as an input arg and not worry about the batch dimension at all until it is broadcast, instead of hardcoding bsz=1 in multiple places.

position_encoding = position_encoding.repeat(bsz, 1, 1)

@@ -203,9 +231,12 @@ def denoise(
        if enable_classifer_free_guidance:
            pred_u, pred_c = pred.chunk(2)
            pred = pred_u + classifier_free_guidance_scale * (pred_c - pred_u)

            pred = pred.repeat(2, 1, 1)
Contributor

And QQ: why do we need to repeat the first dimension of pred?

Contributor Author

My logic is as follows:

Previously, since we were dealing with just one input, we didn't have to do this: pred would end up with a bsz of 1, while in the classifier_free_guidance case latents would have a bsz of 2. Since pred had a bsz of 1, torch would broadcast it and everything worked.

However, now that we support batch sizes > 1, pred ends up with a batch size equal to half the bsz of latents, which in general is not 1. So we can no longer rely on broadcasting and have to do this repeat manually ourselves.
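
A rough sketch of the shape bookkeeping being described (the shapes, names, and guidance scale are illustrative, not the actual denoise loop):

import torch

bsz, seq_len, dim = 4, 16, 8
guidance_scale = 3.5

# Under classifier-free guidance, latents hold the unconditional and
# conditional copies back to back, so their leading dimension is 2 * bsz.
latents = torch.randn(2 * bsz, seq_len, dim)
model_out = torch.randn(2 * bsz, seq_len, dim)  # stand-in for the model call

# Combine the two halves into a single guided prediction with leading dim bsz.
pred_u, pred_c = model_out.chunk(2)
pred = pred_u + guidance_scale * (pred_c - pred_u)

# With bsz == 1, pred (1, ...) would still broadcast against latents (2, ...),
# but for bsz > 1 the shapes (bsz vs. 2 * bsz) no longer broadcast, so the
# prediction has to be repeated explicitly before updating the latents.
pred = pred.repeat(2, 1, 1)
latents = latents + 0.1 * pred  # 0.1 stands in for (t_prev - t_curr)
assert latents.shape == (2 * bsz, seq_len, dim)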


pil_images = [torch_to_pil(img) for img in all_images]
if config.inference.save_path:
    path = Path(config.job.dump_folder, config.inference.save_path)
Contributor

Can we reuse the save_image() function in sampling.py here?

Contributor

@wwwjn wwwjn left a comment

Thank you so much @CarlosGomes98 for splitting the diff and making such a clear inference script. This is a really good feature for FLUX!

Contributor

@fegin fegin left a comment

If the trainer makes some assumptions about the dataset, we can rethink the requirements of the trainer. It's time to make the trainer inference/eval friendly as well. cc @wwwjn

return results


if __name__ == "__main__":
Contributor

It's better to put the following logic in a separate function (like main()). This will allow easier logic reuse (e.g., for unit tests).
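
For instance, roughly (a sketch of the suggested structure, not the actual script; the body is elided):

def main() -> None:
    # All of the logic currently at module scope moves here, so unit tests
    # can simply import the module and call main() (or smaller helpers).
    ...


if __name__ == "__main__":
    main()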

Comment on lines +98 to +100
3,
256,
256,
Contributor

Can we turn these magic numbers into named constants, for readability?
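
For example, something along these lines (the constant names are just a suggestion, not from the PR):

# Hypothetical names for the image-shape magic numbers above.
NUM_CHANNELS = 3
IMG_HEIGHT = 256
IMG_WIDTH = 256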

config = config_manager.parse_args()
trainer = FluxTrainer(config)
world_size = int(os.environ["WORLD_SIZE"])
global_id = int(os.environ["RANK"])
Contributor

Can we call this rank to match the convention?

)
clip_tokenizer = FluxTokenizer(config.encoder.clip_encoder, max_length=77)

if global_id == 0:
Contributor

This is not required; logging should be controlled by the torchrun configuration, which the TorchTitan run scripts default to rank 0 only.

]

# Gather images from all processes
torch.distributed.all_gather(gathered_images, images)
Contributor

This should be gather(), not all_gather(), since you are not using the results on the other ranks. I don't know whether there will be any measurable performance gain, but gather() produces less total network traffic.
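
A minimal sketch of the suggested change, reusing the images, world_size, and global_id variables from the script (the exact call site may differ):

# Only rank 0 needs the gathered tensors, so gather() to a destination rank
# instead of all_gather() to every rank.
gather_list = (
    [torch.empty_like(images) for _ in range(world_size)] if global_id == 0 else None
)
torch.distributed.gather(images, gather_list=gather_list, dst=0)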

torch.distributed.all_gather(gathered_images, images)

# re-order the images to match the original ordering of prompts
if global_id == 0:
Contributor

Another good motivation for putting the logic in main(): you can do an early return here to remove one level of indentation.

img.save(
path / f"img_{i}.png", exif=exif_data, quality=95, subsampling=0
)
torch.distributed.destroy_process_group()
Contributor

We should do something like:

try:
    main()
finally:
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

Contributor

@tianyu-l tianyu-l left a comment

Nice progress! I left some comments on the organization of the code.

# e.g.
# LOG_RANK=0,1 NGPU=4 ./torchtitan/experiments/flux/run_inference.sh

if [ -z "${JOB_FOLDER}" ]; then
Contributor

Can we default to job.dump_folder's default?

Contributor Author

I would argue that making it explicitly required makes using the script clearer and less error-prone, at the cost of some user-friendliness. But I understand that point as well. I'll make the change.

exif_data[ExifTags.Base.Model] = "Schnell"
exif_data[ExifTags.Base.ImageDescription] = original_prompts[i]
img.save(
    path / f"img_{i}.png", exif=exif_data, quality=95, subsampling=0
Contributor

If eventually we are saving individual image files, why do we even perform gather / all-gather? We could save images into the same folder from different ranks, just with unique names, e.g. rank_{i} in the .png name.
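
Something like this, roughly (reusing the img / exif_data / path variables from the snippet above; the filename scheme is just an example):

# Each rank writes its own images under a rank-unique filename, no gather needed.
img.save(
    path / f"rank_{global_id}_img_{i}.png", exif=exif_data, quality=95, subsampling=0
)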

Contributor Author

That's true. The tricky part is placing all the tensors back in the same order so we can match them up with the prompts. Because of the padding involved it's not super straightforward, but I'm sure there's a way to do it while having each rank write its own images. It just might take some more thought.

Contributor

An easier way to bypass the padding is to require the prompts file to have a length divisible by the DP degree, or by the world size. Users of this script can manually add empty rows if needed.
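
A small sketch of what that check/padding could look like (a hypothetical helper, not part of the PR):

def pad_prompts(prompts: list[str], world_size: int) -> list[str]:
    """Pad the prompt list with empty prompts so it splits evenly across ranks."""
    remainder = len(prompts) % world_size
    if remainder:
        prompts = prompts + [""] * (world_size - remainder)
    return prompts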



@record
def inference(
Contributor

It's not obvious to me why this is worth a standalone function. Can we just call generate_image in the main script?

Contributor Author

The only difference here is handling the batching. This could be handled by a method in the trainer, but I previously had it that way and refactored it out after discussions in #1205.

I think both are valid; it's a matter of a design decision for torchtitan.

Contributor

I see. Sounds OK to me.

Contributor

Functionality-wise, this seems similar to torchtitan/experiments/flux/tests/test_generate_image.py, but with a parallelized model.
I think we can make test_generate_image a unit test, if not remove it, once this multi-GPU generation lands. @wwwjn

Contributor

To group files a bit more logically, can we put run_inference.sh, prompts.txt, and infer.py under the flux/inference folder? We can leave sampling.py outside, as it's also used by the evaluation in train.py.

latents = latents + (t_prev - t_curr) * pred

if enable_classifer_free_guidance:
    latents = latents.chunk(2)[1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this code, together with the pred = pred.repeat(2, 1, 1) above, looks very obscure.
If they are necessary, please add adequate comments.

@tianyu-l
Contributor

tianyu-l commented May 28, 2025

If the trainer makes some assumptions about the dataset, we can rethink the requirements of the trainer. It's time to make the trainer inference/eval friendly as well.

@fegin
Interesting... I had thought the Trainer, as its name suggests, should in principle be used for training but not eval/inference? (Validation is part of training.)

BTW, for Llama we have a dedicated script for multi-GPU generation/inference, which doesn't reuse the trainer:
https://github.com/pytorch/torchtitan/blob/main/scripts/generate/test_generate.py
Actually it couldn't reuse it, because Sequence Parallel doesn't work with odd-length language sequences.

@fegin
Contributor

fegin commented May 28, 2025

@tianyu-l Naming is just a minor issue. IMO, it depends on how much code and logic are shared. If the components are the common pieces but the trainer is not, then I agree we shouldn't worry too much about this. However, if there is largely duplicated logic, especially for performance-critical parts (e.g., GC), we should either refactor the trainer to make it more generic or evaluate the common logic in the trainer to see if we can move it into the components.
