[Flux] Add batched inference #1227
base: main
Conversation
@@ -177,9 +202,9 @@ def denoise(
     # create positional encodings
     POSITION_DIM = 3
     latent_pos_enc = create_position_encoding_for_latents(
-        bsz, latent_height, latent_width, POSITION_DIM
+        1, latent_height, latent_width, POSITION_DIM
QQ: Why do we change `bsz` to 1 here and later, given that we take `bsz` prompts as input?
There are two parts to this. The reason we can do it is that, for the particular tensors where I'm using 1 on the first dimension, the values are identical for all samples, so we only need them repeated across the batch. Thanks to torch broadcasting, whenever such a tensor is used in an operation with another tensor, the size-1 dimension is expanded to match whatever the other tensor needs (torch effectively makes it whatever the batch size is).
The reason we want to do it is twofold (see the sketch after this list):
- It saves some memory: we don't have to carry around all these repeated tensors, and can let torch broadcast during operations instead.
- If we don't do it, then whenever we use classifier-free guidance we have to manually double the size of these tensors. This way we don't have to worry about it, as they broadcast correctly.
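A minimal illustration of the broadcasting being relied on; the shapes below are made up for the example and are not the real Flux dimensions:

import torch

bsz, seq_len, dim = 4, 64, 3  # example sizes only

# Positional encoding computed once, with a batch dimension of 1.
pos_enc = torch.randn(1, seq_len, dim)

# Latents carry the real batch size (or 2 * bsz under classifier-free guidance).
latents = torch.randn(2 * bsz, seq_len, dim)

# Broadcasting expands the size-1 batch dimension automatically, so the same
# encoding is applied to every sample without materializing repeated copies.
out = latents + pos_enc
assert out.shape == (2 * bsz, seq_len, dim)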
Ah, the reasoning makes sense to me. It feels to me that, if the result of `latent_pos_enc` is always identical for all samples in a batch, we should probably just remove `bsz` as an input arg and not worry about batch at all until it's broadcast, instead of hardcoding `bsz=1` in multiple places.
position_encoding = position_encoding.repeat(bsz, 1, 1)
@@ -203,9 +231,12 @@ def denoise(
     if enable_classifer_free_guidance:
         pred_u, pred_c = pred.chunk(2)
         pred = pred_u + classifier_free_guidance_scale * (pred_c - pred_u)
         pred = pred.repeat(2, 1, 1)
And QQ: Why do we need to repeat the first dimension of `pred`?
My logic is as follows:
Previously, since we were dealing with just one input, we didn't have to do this: `pred` would end up with a batch size of 1, while under classifier-free guidance `latents` would have a batch size of 2. Since `pred` had a batch size of 1, torch would broadcast it and everything worked.
Now that we support batch sizes > 1, `pred` ends up with a batch size equal to half the batch size of `latents`, which in general is not 1. So we can no longer rely on broadcasting and have to do this repeat manually.
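A hedged sketch of that shape logic (sizes are only for illustration):

import torch

n_prompts, seq_len, dim = 3, 64, 16  # example sizes only

# Under classifier-free guidance the batch holds unconditional + conditional copies.
latents = torch.randn(2 * n_prompts, seq_len, dim)
pred = torch.randn(2 * n_prompts, seq_len, dim)  # model output on the doubled batch

pred_u, pred_c = pred.chunk(2)  # each is (n_prompts, seq_len, dim)
guidance_scale = 4.0
pred = pred_u + guidance_scale * (pred_c - pred_u)

# With n_prompts == 1 the next addition would broadcast (1 -> 2). With
# n_prompts > 1, shapes (n_prompts, ...) and (2 * n_prompts, ...) do not
# broadcast, so the repeat is required.
pred = pred.repeat(2, 1, 1)
latents = latents + 0.1 * pred  # stand-in for the Euler update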
pil_images = [torch_to_pil(img) for img in all_images]
if config.inference.save_path:
    path = Path(config.job.dump_folder, config.inference.save_path)
Can we reuse the `save_image()` function in `sampling.py` here?
Thank you so much @CarlosGomes98 for splitting the diff and making such a clear inference script; this is a really good feature for FLUX!
If `trainer` makes some assumptions about the dataset, we can rethink the requirements of `trainer`. It's time to make `trainer` inference/eval friendly as well. cc @wwwjn
return results


if __name__ == "__main__":
It's better to put the following logic in a separate function (like `main()`). This will allow easier logic reuse (e.g., for unit tests).
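A minimal sketch of that structure; the body is elided and the names are only illustrative:

def main() -> None:
    # Parse the config, build the trainer, run inference, and save the outputs.
    # Keeping this in a function lets unit tests import and call it directly.
    ...


if __name__ == "__main__":
    main()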
3,
256,
256,
Can we turn these magic numbers into named constants, for readability?
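For example (the constant names below are hypothetical and only illustrate the suggestion):

# Hypothetical names for the 3, 256, 256 arguments above.
NUM_CHANNELS = 3
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256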
config = config_manager.parse_args()
trainer = FluxTrainer(config)
world_size = int(os.environ["WORLD_SIZE"])
global_id = int(os.environ["RANK"])
Can we use `rank` to match the convention?
)
clip_tokenizer = FluxTokenizer(config.encoder.clip_encoder, max_length=77)

if global_id == 0:
This is not required; log verbosity should be controlled by the torchrun configuration, and the TorchTitan run scripts default to logging on rank 0 only.
]

# Gather images from all processes
torch.distributed.all_gather(gathered_images, images)
This should be `gather()`, not `all_gather()`, since you are not using the results on the other ranks. I don't know if there will be any measurable performance gain, but `gather()` produces less total network traffic.
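A rough sketch of the `gather()`-based variant, assuming the backend in use supports `gather()`; the helper name is illustrative:

import torch
import torch.distributed as dist


def gather_images_to_rank0(images: torch.Tensor) -> list[torch.Tensor] | None:
    """Collect per-rank image tensors on rank 0 only.

    Unlike all_gather(), only the destination rank receives data, so the
    other ranks send their tensors without allocating receive buffers.
    """
    if dist.get_rank() == 0:
        gathered = [torch.empty_like(images) for _ in range(dist.get_world_size())]
        dist.gather(images, gather_list=gathered, dst=0)
        return gathered
    dist.gather(images, dst=0)
    return None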
torch.distributed.all_gather(gathered_images, images)

# re-order the images to match the original ordering of prompts
if global_id == 0:
Another good reason to move this logic into `main()`: you could do an early return here and remove one level of indentation.
img.save(
    path / f"img_{i}.png", exif=exif_data, quality=95, subsampling=0
)
torch.distributed.destroy_process_group()
We should do something like:
try:
main()
finally:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
Nice progress! I left some comments on the organization of the code.
# e.g.
# LOG_RANK=0,1 NGPU=4 ./torchtitan/experiments/flux/run_inference.sh

if [ -z "${JOB_FOLDER}" ]; then
Can we default to `job.dump_folder`'s default?
I would argue that making it explicitly required makes using the script clearer and less error-prone, at the cost of some user friendliness. But I understand that point as well; I'll make the change.
exif_data[ExifTags.Base.Model] = "Schnell"
exif_data[ExifTags.Base.ImageDescription] = original_prompts[i]
img.save(
    path / f"img_{i}.png", exif=exif_data, quality=95, subsampling=0
If we eventually save individual image files anyway, why do we even perform a gather / all-gather? We could save images into the same folder from different ranks, just with unique names, e.g. including the rank in the `.png` filename.
That's true. The tricky part is placing all the tensors back in the same order so we can match them up with the prompts. Because of the padding involved it's not super straightforward, but I'm sure there's a way to do it while having each rank write its own images. It just might take some more thought.
An easier way to bypass padding is to require the prompts file to have a length divisible by the DP degree, or the world size. Users of this script can manually add empty rows if needed.
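A sketch of that idea done in code instead of asking users to edit the file; the helper name is illustrative:

def pad_prompts(prompts: list[str], world_size: int) -> list[str]:
    """Pad the prompt list with empty strings so it splits evenly across ranks.

    The padded entries are always appended at the end, so they can be dropped
    again after generation without disturbing the original prompt order.
    """
    remainder = len(prompts) % world_size
    if remainder:
        prompts = prompts + [""] * (world_size - remainder)
    return prompts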
@record
def inference(
It's not obvious to me why this is worth a standalone function -- can we just call `generate_image` in the main script?
The only difference here is handling the batching. This could be handled by a method in the trainer, but I previously had it that way and refactored it out after discussions in #1205.
I think both are valid; it's a matter of a design decision for torchtitan.
I see. Sounds OK to me.
Functionality-wise, this seems similar to `torchtitan/experiments/flux/tests/test_generate_image.py`, but with a parallelized model. I think we can make `test_generate_image` a unit test, if not remove it entirely once this multi-GPU generation lands. @wwwjn
To group files a bit more logically, can we put `run_inference.sh`, `prompts.txt`, and `infer.py` under the `flux/inference` folder? We can leave `sampling.py` outside, as it's also used by the evaluation in `train.py`.
latents = latents + (t_prev - t_curr) * pred

if enable_classifer_free_guidance:
    latents = latents.chunk(2)[1]
Yeah, this code, together with the `pred = pred.repeat(2, 1, 1)` above, looks very obscure.
If it is necessary, please add adequate comments.
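For instance, comments along these lines might help; this is a sketch annotating the code from the diff, not the exact final version:

if enable_classifer_free_guidance:
    # The batch is laid out as [unconditional half | conditional half],
    # so chunk(2) recovers the two predictions.
    pred_u, pred_c = pred.chunk(2)
    # Classifier-free guidance: move the prediction away from the
    # unconditional result and toward the conditional one.
    pred = pred_u + classifier_free_guidance_scale * (pred_c - pred_u)
    # `latents` still has the doubled batch, while the guided `pred` covers
    # only one half; repeat it so the Euler update lines up. (With a single
    # prompt, broadcasting hid this; with bsz > 1 it does not.)
    pred = pred.repeat(2, 1, 1)

latents = latents + (t_prev - t_curr) * pred

# Keep only the conditional half of the doubled batch once denoising is done.
if enable_classifer_free_guidance:
    latents = latents.chunk(2)[1]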
@fegin btw, for Llama we have dedicated scripts for multi-GPU generation / inference, which don't reuse the trainer.
@tianyu-l Naming is just one minor issue. IMO, it depends on how much code and logic are shared. If the components are the most common pieces but the trainer is not, then I agree we shouldn't worry too much about this. However, if there is largely duplicated logic, especially performance-critical parts (e.g., GC), we should either refactor the trainer to make it more generic or evaluate the common logic in the trainer to see if we can put this logic in
This PR adds the ability to perform batched and multi-GPU inference on the Flux model, following from #1205:
- Updated the functions in `sampling.py`, allowing them to take several prompts and infer with them in a batch.
- Added an `infer.py` script which performs inference. This leverages the Trainer, which is great for reusability, but not so great because it inherits all the requirements of the trainer (such as a train dataset, which doesn't really make sense here).
- Added a `run_inference.sh` script and documented it in README.md.
- Added an example `prompts.txt`.