
Commit 21fcfd5

implement 'local' caption upsampling for Flux.2
1 parent fac57fd commit 21fcfd5

File tree

docs/diffusers/api/pipelines/flux2.md
mindone/diffusers/pipelines/flux2/image_processor.py
mindone/diffusers/pipelines/flux2/pipeline_flux2.py
mindone/diffusers/pipelines/flux2/system_messages.py

4 files changed: +248 -24 lines changed

docs/diffusers/api/pipelines/flux2.md

Lines changed: 6 additions & 0 deletions
@@ -23,4 +23,10 @@ Original model checkpoints for Flux can be found [here](https://huggingface.co/b
 !!! tip
     Flux2 can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more.
 
+## Caption upsampling
+
+Flux.2 can potentially generate better outputs with better prompts. We can "upsample"
+an input prompt by setting the `caption_upsample_temperature` argument when calling the pipeline.
+The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L140) recommends setting this value to 0.15.
+
 ::: mindone.diffusers.Flux2Pipeline
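
A minimal usage sketch of the new argument (the checkpoint id and the `[0][0]` output indexing are assumptions following the usual mindone.diffusers conventions, not part of this commit):

```python
import mindspore as ms
from mindone.diffusers import Flux2Pipeline

# Hypothetical checkpoint id; substitute the Flux.2 repo id you actually use.
pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", mindspore_dtype=ms.bfloat16)

image = pipe(
    prompt="a photo of a lighthouse at dusk",
    caption_upsample_temperature=0.15,  # rewrites the prompt before encoding
)[0][0]
image.save("flux2.png")
```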

mindone/diffusers/pipelines/flux2/image_processor.py

Lines changed: 42 additions & 2 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from typing import Tuple
+from typing import List
 
 import PIL.Image
 
@@ -96,7 +96,7 @@ def check_image_input(
         return image
 
     @staticmethod
-    def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> Tuple[int, int]:
+    def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image:
         image_width, image_height = image.size
 
         scale = math.sqrt(target_area / (image_width * image_height))
@@ -105,6 +105,14 @@ def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 102
 
         return image.resize((width, height), PIL.Image.Resampling.LANCZOS)
 
+    @staticmethod
+    def _resize_if_exceeds_area(image, target_area=1024 * 1024) -> PIL.Image.Image:
+        image_width, image_height = image.size
+        pixel_count = image_width * image_height
+        if pixel_count <= target_area:
+            return image
+        return Flux2ImageProcessor._resize_to_target_area(image, target_area)
+
     def _resize_and_crop(
         self,
         image: PIL.Image.Image,
@@ -134,3 +142,35 @@ def _resize_and_crop(
         bottom = top + height
 
         return image.crop((left, top, right, bottom))
+
+    # Taken from
+    # https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L310C1-L339C19
+    @staticmethod
+    def concatenate_images(images: List[PIL.Image.Image]) -> PIL.Image.Image:
+        """
+        Concatenate a list of PIL images horizontally with center alignment and white background.
+        """
+
+        # If only one image, return a copy of it
+        if len(images) == 1:
+            return images[0].copy()
+
+        # Convert all images to RGB if not already
+        images = [img.convert("RGB") if img.mode != "RGB" else img for img in images]
+
+        # Calculate dimensions for horizontal concatenation
+        total_width = sum(img.width for img in images)
+        max_height = max(img.height for img in images)
+
+        # Create new image with white background
+        background_color = (255, 255, 255)
+        new_img = PIL.Image.new("RGB", (total_width, max_height), background_color)
+
+        # Paste images with center alignment
+        x_offset = 0
+        for img in images:
+            y_offset = (max_height - img.height) // 2
+            new_img.paste(img, (x_offset, y_offset))
+            x_offset += img.width
+
+        return new_img
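
To make the two new helpers concrete, here is a small standalone sketch (toy images; the exact downscaled size follows `_resize_to_target_area`'s rounding):

```python
import PIL.Image

from mindone.diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor

# Two differently sized reference images.
a = PIL.Image.new("RGB", (640, 480), (200, 30, 30))
b = PIL.Image.new("RGB", (512, 768), (30, 30, 200))

# Widths are summed, the max height is kept, and shorter images are centered
# vertically on a white background.
combined = Flux2ImageProcessor.concatenate_images([a, b])
print(combined.size)  # (1152, 768)

# Downscales only when the pixel count exceeds target_area; otherwise a no-op.
capped = Flux2ImageProcessor._resize_if_exceeds_area(combined, target_area=768**2)
assert capped.size[0] * capped.size[1] < combined.size[0] * combined.size[1]
```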

mindone/diffusers/pipelines/flux2/pipeline_flux2.py

Lines changed: 167 additions & 22 deletions
@@ -33,6 +33,7 @@
 from ..pipeline_utils import DiffusionPipeline
 from .image_processor import Flux2ImageProcessor
 from .pipeline_output import Flux2PipelineOutput
+from .system_messages import SYSTEM_MESSAGE, SYSTEM_MESSAGE_UPSAMPLING_I2I, SYSTEM_MESSAGE_UPSAMPLING_T2I
 
 XLA_AVAILABLE = False
 
@@ -54,25 +55,105 @@
 ```
 """
 
+UPSAMPLING_MAX_IMAGE_SIZE = 768**2
 
-def format_text_input(prompts: List[str], system_message: str = None):
+
+# Adapted from
+# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68
+def format_input(
+    prompts: List[str],
+    system_message: str = SYSTEM_MESSAGE,
+    images: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]] = None,
+):
+    """
+    Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images
+    to the input.
+
+    Args:
+        prompts: List of text prompts
+        system_message: System message to use (default: CREATIVE_SYSTEM_MESSAGE)
+        images (optional): List of images to add to the input.
+
+    Returns:
+        List of conversations, where each conversation is a list of message dicts
+    """
     # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
     # when truncation is enabled. The processor counts [IMG] tokens and fails
     # if the count changes after truncation.
     cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
 
-    return [
-        [
-            {
-                "role": "system",
-                "content": [{"type": "text", "text": system_message}],
-            },
-            {"role": "user", "content": [{"type": "text", "text": prompt}]},
+    if images is None or len(images) == 0:
+        return [
+            [
+                {
+                    "role": "system",
+                    "content": [{"type": "text", "text": system_message}],
+                },
+                {"role": "user", "content": [{"type": "text", "text": prompt}]},
+            ]
+            for prompt in cleaned_txt
         ]
-        for prompt in cleaned_txt
+    else:
+        assert len(images) == len(prompts), "Number of images must match number of prompts"
+        messages = [
+            [
+                {
+                    "role": "system",
+                    "content": [{"type": "text", "text": system_message}],
+                },
+            ]
+            for _ in cleaned_txt
+        ]
+
+        for i, (el, images) in enumerate(zip(messages, images)):
+            # optionally add the images per batch element.
+            if images is not None:
+                el.append(
+                    {
+                        "role": "user",
+                        "content": [{"type": "image", "image": image_obj} for image_obj in images],
+                    }
+                )
+            # add the text.
+            el.append(
+                {
+                    "role": "user",
+                    "content": [{"type": "text", "text": cleaned_txt[i]}],
+                }
+            )
+
+        return messages
+
+
+# Adapted from
+# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L49C5-L66C19
+def _validate_and_process_images(
+    images: List[List[PIL.Image.Image]] | List[PIL.Image.Image],
+    image_processor: Flux2ImageProcessor,
+    upsampling_max_image_size: int,
+) -> List[List[PIL.Image.Image]]:
+    # Simple validation: ensure it's a list of PIL images or list of lists of PIL images
+    if not images:
+        return []
+
+    # Check if it's a list of lists or a list of images
+    if isinstance(images[0], PIL.Image.Image):
+        # It's a list of images, convert to list of lists
+        images = [[im] for im in images]
+
+    # potentially concatenate multiple images to reduce the size
+    images = [[image_processor.concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in images]
+
+    # cap the pixels
+    images = [
+        [image_processor._resize_if_exceeds_area(img_i, upsampling_max_image_size) for img_i in img_i]
+        for img_i in images
     ]
+    return images
 
 
+# Taken from
+# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251
 def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
     a1, b1 = 8.73809524e-05, 1.89833333
     a2, b2 = 0.00016927, 0.45666666
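
For reference, the conversation structure `format_input` returns for a single text-only prompt (a sketch of the behavior of the code above):

```python
from mindone.diffusers.pipelines.flux2.pipeline_flux2 import format_input

conversations = format_input(prompts=["a red fox in fresh snow"])
# [
#     [
#         {"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
#         {"role": "user", "content": [{"type": "text", "text": "a red fox in fresh snow"}]},
#     ]
# ]
```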
@@ -209,9 +290,10 @@ def __init__(
         self.tokenizer_max_length = 512
         self.default_sample_size = 128
 
-        # fmt: off
-        self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."  # noqa
-        # fmt: on
+        self.system_message = SYSTEM_MESSAGE
+        self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I
+        self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I
+        self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE
 
     @staticmethod
     def _get_mistral_3_small_prompt_embeds(
@@ -220,18 +302,15 @@ def _get_mistral_3_small_prompt_embeds(
         prompt: Union[str, List[str]],
         dtype: Optional[ms.Type] = None,
         max_sequence_length: int = 512,
-        # fmt: off
-        system_message: str = "You are an AI that reasons about image descriptions. " \
-            "You give structured responses focusing on object relationships, object attribution and actions without speculation.",
-        # fmt: on
+        system_message: str = SYSTEM_MESSAGE,
         hidden_states_layers: List[int] = (10, 20, 30),
     ):
         dtype = text_encoder.dtype if dtype is None else dtype
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
 
         # Format input messages
-        messages_batch = format_text_input(prompts=prompt, system_message=system_message)
+        messages_batch = format_input(prompts=prompt, system_message=system_message)
 
         # Process all messages at once
         inputs = tokenizer.apply_chat_template(
@@ -421,6 +500,66 @@ def _unpack_latents_with_ids(x: ms.Tensor, x_ids: ms.Tensor) -> list[ms.Tensor]:
 
         return mint.stack(x_list, dim=0)
 
+    def upsample_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        images: Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]] = None,
+        temperature: float = 0.15,
+    ) -> List[str]:
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        # Set system message based on whether images are provided
+        if images is None or len(images) == 0 or images[0] is None:
+            system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
+        else:
+            system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
+
+        # Validate and process the input images
+        if images:
+            images = _validate_and_process_images(images, self.image_processor, self.upsampling_max_image_size)
+
+        # Format input messages
+        messages_batch = format_input(prompts=prompt, system_message=system_message, images=images)
+
+        # Process all messages at once. With image processing, a too-short
+        # max length can throw an error here.
+        inputs = self.tokenizer.apply_chat_template(
+            messages_batch,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="np",
+            padding="max_length",
+            truncation=True,
+            max_length=2048,
+        )
+
+        # Convert inputs to MindSpore tensors
+        inputs["input_ids"] = ms.tensor(inputs["input_ids"])
+        inputs["attention_mask"] = ms.tensor(inputs["attention_mask"])
+
+        if "pixel_values" in inputs:
+            inputs["pixel_values"] = ms.tensor(inputs["pixel_values"]).to(self.text_encoder.dtype)
+
+        # Generate text using the model's generate method
+        generated_ids = self.text_encoder.generate(
+            **inputs,
+            max_new_tokens=512,
+            do_sample=True,
+            temperature=temperature,
+            use_cache=True,
+        )
+
+        # Decode only the newly generated tokens (skip the input tokens)
+        input_length = inputs["input_ids"].shape[1]
+        generated_tokens = generated_ids[:, input_length:]
+
+        upsampled_prompt = self.tokenizer.tokenizer.batch_decode(
+            generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
+        )
+        return upsampled_prompt
+
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
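
`upsample_prompt` can also be called on its own; a sketch, assuming `pipe` is a loaded `Flux2Pipeline` and `ref_image` a `PIL.Image.Image`:

```python
# Text-to-image rewriting (SYSTEM_MESSAGE_UPSAMPLING_T2I is selected).
upsampled = pipe.upsample_prompt("a cozy cabin in the woods", temperature=0.15)
print(upsampled[0])  # a more descriptive rewrite of the prompt

# Passing reference images switches to the I2I editing system message.
upsampled = pipe.upsample_prompt("make it winter", images=[ref_image], temperature=0.15)
```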
@@ -605,6 +744,7 @@ def __call__(
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
         text_encoder_out_layers: Tuple[int] = (10, 20, 30),
+        caption_upsample_temperature: float = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -620,11 +760,11 @@
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
             guidance_scale (`float`, *optional*, defaults to 1.0):
-                Guidance scale as defined in [Classifier-Free Diffusion
-                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
-                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
-                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
-                the text `prompt`, usually at the expense of lower image quality.
+                Embedded guidance scale is enabled by setting `guidance_scale` > 1. A higher `guidance_scale`
+                encourages the model to generate images more aligned with `prompt` at the expense of image quality.
+
+                Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
+                the [paper](https://huggingface.co/papers/2210.03142) to learn more.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -669,6 +809,9 @@
             max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
             text_encoder_out_layers (`Tuple[int]`):
                 Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
+            caption_upsample_temperature (`float`, *optional*):
+                When specified, caption upsampling is performed at this sampling temperature for potentially
+                improved outputs. We recommend a value of 0.15.
 
         Examples:
@@ -701,6 +844,8 @@
             batch_size = prompt_embeds.shape[0]
 
         # 3. prepare text embeddings
+        if caption_upsample_temperature:
+            prompt = self.upsample_prompt(prompt, images=image, temperature=caption_upsample_temperature)
         prompt_embeds, text_ids = self.encode_prompt(
             prompt=prompt,
             prompt_embeds=prompt_embeds,
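
With this hook in place, an image-conditioned call forwards its reference image(s) to the upsampler before encoding. A hypothetical sketch (whether `image` accepts a list depends on the pipeline's input validation):

```python
out = pipe(
    prompt="turn the sky into a vivid sunset",
    image=[ref_image],  # also forwarded to upsample_prompt as `images`
    caption_upsample_temperature=0.15,
)
```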
mindone/diffusers/pipelines/flux2/system_messages.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# docstyle-ignore
+"""
+These system prompts come from:
+https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/system_messages.py#L54
+"""
+
+# docstyle-ignore
+SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
+attribution and actions without speculation."""
+
+# docstyle-ignore
+SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent.
+
+Guidelines:
+1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs.
+2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context.
+3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish.
+
+Output only the revised prompt and nothing else."""
+
+# docstyle-ignore
+SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests).
+
+Rules:
+- Single instruction only, no commentary
+- Use clear, analytical language (avoid "whimsical," "cascading," etc.)
+- Specify what changes AND what stays the same (face, lighting, composition)
+- Reference actual image elements
+- Turn negatives into positives ("don't change X" → "keep X")
+- Make abstractions concrete ("futuristic" → "glowing cyan neon, metallic panels")
+- Keep content PG-13
+
+Output only the final instruction in plain text and nothing else."""