From d2b488bc7725a67cf235b93af45766652d5924d3 Mon Sep 17 00:00:00 2001 From: qiujiawei Date: Thu, 2 Jan 2025 18:37:33 +0800 Subject: [PATCH 1/2] ppdiffusers add Flux pipline --- ppdiffusers/ppdiffusers/__init__.py | 5 + ppdiffusers/ppdiffusers/loaders/__init__.py | 2 +- ppdiffusers/ppdiffusers/models/__init__.py | 3 + .../ppdiffusers/models/attention_processor.py | 175 ++++ ppdiffusers/ppdiffusers/models/embeddings.py | 120 ++- .../ppdiffusers/models/normalization.py | 34 + .../ppdiffusers/models/transformer_flux.py | 589 ++++++++++++ ppdiffusers/ppdiffusers/pipelines/__init__.py | 6 + .../ppdiffusers/pipelines/flux/__init__.py | 69 ++ .../pipelines/flux/pipeline_flux.py | 903 ++++++++++++++++++ .../pipelines/flux/pipeline_output.py | 37 + 11 files changed, 1941 insertions(+), 2 deletions(-) create mode 100644 ppdiffusers/ppdiffusers/models/transformer_flux.py create mode 100644 ppdiffusers/ppdiffusers/pipelines/flux/__init__.py create mode 100644 ppdiffusers/ppdiffusers/pipelines/flux/pipeline_flux.py create mode 100644 ppdiffusers/ppdiffusers/pipelines/flux/pipeline_output.py diff --git a/ppdiffusers/ppdiffusers/__init__.py b/ppdiffusers/ppdiffusers/__init__.py index 80e2d6dc6..5778ac3c6 100644 --- a/ppdiffusers/ppdiffusers/__init__.py +++ b/ppdiffusers/ppdiffusers/__init__.py @@ -151,6 +151,8 @@ "SD3MultiControlNetModel", # new add "VCtrlModel", + # new add + "FluxTransformer2DModel", ] ) @@ -277,6 +279,7 @@ "CLIPImageProjection", "CogVideoXPipeline", "CycleDiffusionPipeline", + "FluxPipeline", "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", "IFInpaintingPipeline", @@ -506,6 +509,7 @@ ControlNetModel, DiTLLaMA2DModel, DiTLLaMAT2IModel, + FluxTransformer2DModel, GaussianDiffusion, GaussianDiffusion_SDEdit, Kandinsky3UNet, @@ -558,6 +562,7 @@ DDPMPipeline, DiffusionPipeline, DiTPipeline, + FluxPipeline, ImagePipelineOutput, ImgToVideoSDPipeline, KarrasVePipeline, diff --git a/ppdiffusers/ppdiffusers/loaders/__init__.py b/ppdiffusers/ppdiffusers/loaders/__init__.py index b99e9c33b..f8732635d 100644 --- a/ppdiffusers/ppdiffusers/loaders/__init__.py +++ b/ppdiffusers/ppdiffusers/loaders/__init__.py @@ -45,7 +45,7 @@ text_encoder_lora_state_dict, ) from .ip_adapter import IPAdapterMixin - from .lora import LoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin + from .lora import LoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, FluxLoraLoaderMixin from .single_file import FromCkptMixin, FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin else: diff --git a/ppdiffusers/ppdiffusers/models/__init__.py b/ppdiffusers/ppdiffusers/models/__init__.py index 6427eac96..cdb23b1da 100644 --- a/ppdiffusers/ppdiffusers/models/__init__.py +++ b/ppdiffusers/ppdiffusers/models/__init__.py @@ -62,6 +62,8 @@ # NOTE, new add _import_structure["vctrl"] = ["VCtrlModel"] _import_structure["cogvideox_transformer_3d_vctrl"] = ["CogVideoXTransformer3DVCtrlModel"] + # NOTE, new add + _import_structure["transformer_flux"] = ["FluxTransformer2DModel"] if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT: @@ -95,6 +97,7 @@ from .prior_transformer import PriorTransformer from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel + from .transformer_flux import FluxTransformer2DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel from .unet_1d import UNet1DModel diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py index 8894b1447..34f27106c 100644 --- a/ppdiffusers/ppdiffusers/models/attention_processor.py +++ b/ppdiffusers/ppdiffusers/models/attention_processor.py @@ -2149,6 +2149,179 @@ def __call__( return out +class FluxAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: paddle.Tensor = None, + attention_mask: Optional[paddle.Tensor] = None, + image_rotary_emb: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.reshape([batch_size, -1, attn.heads, head_dim]) + key = key.reshape([batch_size, -1, attn.heads, head_dim]) + value = value.reshape([batch_size, -1, attn.heads, head_dim]) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.reshape([batch_size, -1, attn.heads, head_dim]) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.reshape([batch_size, -1, attn.heads, head_dim]) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.reshape([batch_size, -1, attn.heads, head_dim]) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = paddle.concat([encoder_hidden_states_query_proj, query], axis=1) + key = paddle.concat([encoder_hidden_states_key_proj, key], axis=1) + value = paddle.concat([encoder_hidden_states_value_proj, value], axis=1) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention_( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.reshape([batch_size, -1, attn.heads * head_dim]) + hidden_states = hidden_states.astype(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class FusedFluxAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: paddle.FloatTensor, + encoder_hidden_states: paddle.Tensor = None, + attention_mask: Optional[paddle.Tensor] = None, + image_rotary_emb: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + qkv = attn.to_qkv(hidden_states) + # split_size = qkv.shape[-1] // 3 + query, key, value = paddle.split(qkv, 3, axis=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.reshape([batch_size, -1, attn.heads, head_dim]) + key = key.reshape([batch_size, -1, attn.heads, head_dim]) + value = value.reshape([batch_size, -1, attn.heads, head_dim]) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + # `context` projections. + if encoder_hidden_states is not None: + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + # split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = paddle.split(encoder_qkv, 3, dim=-1) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.reshape([batch_size, -1, attn.heads, head_dim]) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.reshape([batch_size, -1, attn.heads, head_dim]) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.reshape([batch_size, -1, attn.heads, head_dim]) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = paddle.concat([encoder_hidden_states_query_proj, query], axis=1) + key = paddle.concat([encoder_hidden_states_key_proj, key], axis=1) + value = paddle.concat([encoder_hidden_states_value_proj, value], axis=1) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = F.scaled_dot_product_attention_(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.reshape([batch_size, -1, attn.heads * head_dim]) + hidden_states = hidden_states.astype(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + class CogVideoXAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on @@ -2347,6 +2520,8 @@ def __call__( CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_5, + FluxAttnProcessor2_0, + FusedFluxAttnProcessor2_0, # deprecated LoRAAttnProcessor, LoRAAttnProcessor2_5, diff --git a/ppdiffusers/ppdiffusers/models/embeddings.py b/ppdiffusers/ppdiffusers/models/embeddings.py index 152422d01..a17c4a698 100644 --- a/ppdiffusers/ppdiffusers/models/embeddings.py +++ b/ppdiffusers/ppdiffusers/models/embeddings.py @@ -13,7 +13,7 @@ # limitations under the License. import math import warnings -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import paddle @@ -567,6 +567,29 @@ def forward(self, timestep, class_labels, hidden_dtype=None): return conditioning +class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward(self, timestep, guidance, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.astype(dtype=pooled_projection.dtype)) # (N, D) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.astype(dtype=pooled_projection.dtype)) # (N, D) + + time_guidance_emb = timesteps_emb + guidance_emb + + pooled_projections = self.text_embedder(pooled_projection) + conditioning = time_guidance_emb + pooled_projections + + return conditioning + class TextTimeEmbedding(nn.Layer): def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): super().__init__() @@ -1130,6 +1153,71 @@ def broadcast(tensors, dim=-1): return freqs_cis +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, + freqs_dtype=paddle.float32, # paddle.float32, paddle.float64 (flux) +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + freqs_dtype (`paddle.float32` or `paddle.float64`, *optional*, defaults to `paddle.float32`): + the dtype of the frequency tensor. + Returns: + `paddle.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = paddle.arange(pos) + if isinstance(pos, np.ndarray): + pos = paddle.to_tensor(pos) # type: ignore # [S] + + theta = theta * ntk_factor + freqs = ( + 1.0 + / (theta ** (paddle.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + / linear_factor + ) # [D/2] + freqs = paddle.outer(pos, freqs) # type: ignore # [S, D/2] + if use_real and repeat_interleave_real: + # flux, hunyuan-dit, cogvideox + freqs_cos = freqs.cos().repeat_interleave(2, axis=1).float() # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, axis=1).float() # [S, D] + return freqs_cos, freqs_sin + elif use_real: + # stable audio, allegro + freqs_cos = paddle.concat([freqs.cos(), freqs.cos()], axis=-1).float() # [S, D] + freqs_sin = paddle.concat([freqs.sin(), freqs.sin()], axis=-1).float() # [S, D] + return freqs_cos, freqs_sin + else: + # lumina + freqs_cis = paddle.polar(paddle.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + def apply_rotary_emb( x: paddle.Tensor, freqs_cis: Union[paddle.Tensor, Tuple[paddle.Tensor]], @@ -1169,3 +1257,33 @@ def apply_rotary_emb( freqs_cis = freqs_cis.unsqueeze(axis=2) x_out = paddle.as_real(x=x_rotated * freqs_cis).flatten(start_axis=3) return x_out.astype(dtype=x.dtype) + + +class FluxPosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: paddle.Tensor) -> paddle.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + freqs_dtype = paddle.float32 if is_mps else paddle.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = paddle.concat(cos_out, axis=-1) + freqs_sin = paddle.concat(sin_out, axis=-1) + return freqs_cos, freqs_sin diff --git a/ppdiffusers/ppdiffusers/models/normalization.py b/ppdiffusers/ppdiffusers/models/normalization.py index 4b676b8a2..ba7a8cede 100644 --- a/ppdiffusers/ppdiffusers/models/normalization.py +++ b/ppdiffusers/ppdiffusers/models/normalization.py @@ -162,6 +162,40 @@ def forward( return x, gate_msa, shift_mlp, scale_mlp, gate_mlp +class AdaLayerNormZeroSingle(nn.Layer): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias_attr=bias) + if norm_type == "layer_norm": + norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False) + self.norm = nn.LayerNorm(embedding_dim, epsilon=1e-6, **norm_elementwise_affine_kwargs) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def forward( + self, + x: paddle.Tensor, + emb: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa = emb.chunk(3, axis=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa + + + class AdaLayerNormSingle(nn.Layer): r""" Norm layer adaptive layer norm single (adaLN-single). diff --git a/ppdiffusers/ppdiffusers/models/transformer_flux.py b/ppdiffusers/ppdiffusers/models/transformer_flux.py new file mode 100644 index 000000000..8d44ef08e --- /dev/null +++ b/ppdiffusers/ppdiffusers/models/transformer_flux.py @@ -0,0 +1,589 @@ +# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +# from ..loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin +from ..models.attention import FeedForward +from ..models.attention_processor import ( + Attention, + AttentionProcessor, + FluxAttnProcessor2_0, + # FluxAttnProcessor2_0_NPU, + FusedFluxAttnProcessor2_0, +) +from ..models.modeling_utils import ModelMixin +from ..models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ..utils.paddle_utils import maybe_allow_in_graph +from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed +from .modeling_outputs import Transformer2DModelOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class FluxSingleTransformerBlock(nn.Layer): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + processor = FluxAttnProcessor2_0() + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + qk_norm="rms_norm", + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: paddle.Tensor, + temb: paddle.Tensor, + image_rotary_emb=None, + joint_attention_kwargs=None, + ): + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = paddle.concat([attn_output, mlp_hidden_states], axis=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == paddle.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return hidden_states + + +@maybe_allow_in_graph +class FluxTransformerBlock(nn.Layer): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://arxiv.org/abs/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + + self.norm1_context = AdaLayerNormZero(dim) + + if hasattr(F, "scaled_dot_product_attention"): + processor = FluxAttnProcessor2_0() + else: + raise ValueError( + "The current PyTorch version does not support the `scaled_dot_product_attention` function." + ) + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=processor, + qk_norm=qk_norm, + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, epsilon=1e-06, weight_attr=False, bias_attr=False) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, epsilon=1e-06, weight_attr=False, bias_attr=False) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def forward( + self, + hidden_states: paddle.Tensor, + encoder_hidden_states: paddle.Tensor, + temb: paddle.Tensor, + image_rotary_emb=None, + joint_attention_kwargs=None, + ): + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. + + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == paddle.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class FluxTransformer2DModel( + ModelMixin, ConfigMixin # , PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin +): + """ + The Transformer model introduced in Flux. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Parameters: + patch_size (`int`): Patch size to turn the input data into small patches. + in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. + num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. + joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. + guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + out_channels: Optional[int] = None, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, + axes_dims_rope: Tuple[int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim + ) + + self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim) + self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim) + + self.transformer_blocks = nn.LayerList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_layers) + ] + ) + + self.single_transformer_blocks = nn.LayerList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for i in range(self.config.num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: paddle.nn.Layer, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: paddle.nn.Layer, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedFluxAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: paddle.Tensor, + encoder_hidden_states: paddle.Tensor = None, + pooled_projections: paddle.Tensor = None, + timestep: paddle.Tensor = None, + img_ids: paddle.Tensor = None, + txt_ids: paddle.Tensor = None, + guidance: paddle.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[paddle.Tensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`paddle.Tensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`paddle.Tensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`paddle.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `paddle.Tensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `paddle.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = paddle.concat((txt_ids, img_ids), axis=0) + image_rotary_emb = self.pos_embed(ids) + + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + for index_block, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = paddle.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. + if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + hidden_states = paddle.concat([encoder_hidden_states, hidden_states], axis=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if self.is_training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = paddle.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( + hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + controlnet_single_block_samples[index_block // interval_control] + ) + + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/ppdiffusers/ppdiffusers/pipelines/__init__.py b/ppdiffusers/ppdiffusers/pipelines/__init__.py index 1437a0be5..377697211 100644 --- a/ppdiffusers/ppdiffusers/pipelines/__init__.py +++ b/ppdiffusers/ppdiffusers/pipelines/__init__.py @@ -132,6 +132,9 @@ "IFPipeline", "IFSuperResolutionPipeline", ] + _import_structure["flux"] = [ + "FluxPipeline", + ] _import_structure["kandinsky"] = [ "KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", @@ -407,6 +410,9 @@ IFPipeline, IFSuperResolutionPipeline, ) + from .flux import ( + FluxPipeline, + ) from .img_to_video import ImgToVideoSDPipeline, ImgToVideoSDPipelineOutput from .kandinsky import ( KandinskyCombinedPipeline, diff --git a/ppdiffusers/ppdiffusers/pipelines/flux/__init__.py b/ppdiffusers/ppdiffusers/pipelines/flux/__init__.py new file mode 100644 index 000000000..aef0f8a4b --- /dev/null +++ b/ppdiffusers/ppdiffusers/pipelines/flux/__init__.py @@ -0,0 +1,69 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_paddle_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + # _import_structure["modeling_flux"] = ["ReduxImageEncoder"] + _import_structure["pipeline_flux"] = ["FluxPipeline"] + # _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"] + # _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"] + # _import_structure["pipeline_flux_control_inpaint"] = ["FluxControlInpaintPipeline"] + # _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] + # _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"] + # _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"] + # _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"] + # _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] + # _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] + # _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_paddle_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_paddle_and_paddlenlp_objects import * # noqa F403 + else: + # from .modeling_flux import ReduxImageEncoder + from .pipeline_flux import FluxPipeline + # from .pipeline_flux_control import FluxControlPipeline + # from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline + # from .pipeline_flux_control_inpaint import FluxControlInpaintPipeline + # from .pipeline_flux_controlnet import FluxControlNetPipeline + # from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline + # from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline + # from .pipeline_flux_fill import FluxFillPipeline + # from .pipeline_flux_img2img import FluxImg2ImgPipeline + # from .pipeline_flux_inpaint import FluxInpaintPipeline + # from .pipeline_flux_prior_redux import FluxPriorReduxPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/ppdiffusers/ppdiffusers/pipelines/flux/pipeline_flux.py b/ppdiffusers/ppdiffusers/pipelines/flux/pipeline_flux.py new file mode 100644 index 000000000..ca2a673aa --- /dev/null +++ b/ppdiffusers/ppdiffusers/pipelines/flux/pipeline_flux.py @@ -0,0 +1,903 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import paddle +from ppdiffusers.transformers import ( # T5TokenizerFast, + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5Tokenizer +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin # FluxIPAdapterMixin, FluxLoraLoaderMixin +from ...models.autoencoder_kl import AutoencoderKL +from ...models.transformer_flux import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.paddle_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + +try: + # paddle.incubate.jit.inference is available in paddle develop but not in paddle 3.0beta, so we add a try except. + from paddle.incubate.jit import is_inference_mode +except: + + def is_inference_mode(func): + return False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import paddle + >>> from ppdiffusers import FluxPipeline + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", paddle_dtype=paddle.float16) + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[paddle.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +# FluxLoraLoaderMixin, FluxIPAdapterMixin +class FluxPipeline( + DiffusionPipeline, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The Flux pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5Tokenizer, + transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + dtype: Optional[paddle.dtype] = None, + ): + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not paddle.equal_all(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids, output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.astype(dtype=dtype) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + ): + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not paddle.equal_all(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids, output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.astype(dtype=self.text_encoder.dtype) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + num_images_per_prompt: int = 1, + prompt_embeds: Optional[paddle.FloatTensor] = None, + pooled_prompt_embeds: Optional[paddle.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`paddle.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`paddle.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + # TODO + # if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + # self._lora_scale = lora_scale + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = paddle.zeros(prompt_embeds.shape[1], 3).astype(dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def encode_image(self, image, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, paddle.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.astype(dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers + ): + single_image_embeds = self.encode_image(single_ip_adapter_image, 1) + + image_embeds.append(single_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = paddle.concat([single_image_embeds] * num_images_per_prompt, axis=0) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, dtype): + latent_image_ids = paddle.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + paddle.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + paddle.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.astype(dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, dtype) + return latents.astype(dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @paddle.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None, + latents: Optional[paddle.Tensor] = None, + prompt_embeds: Optional[paddle.Tensor] = None, + pooled_prompt_embeds: Optional[paddle.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[paddle.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[paddle.Tensor]] = None, + negative_prompt_embeds: Optional[paddle.Tensor] = None, + negative_pooled_prompt_embeds: Optional[paddle.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*): + One or a list of [paddle generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`paddle.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`paddle.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`paddle.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[paddle.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`List[paddle.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + _, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = paddle.full([1], guidance_scale, dtype=paddle.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).astype(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/ppdiffusers/ppdiffusers/pipelines/flux/pipeline_output.py b/ppdiffusers/ppdiffusers/pipelines/flux/pipeline_output.py new file mode 100644 index 000000000..8a8ea2382 --- /dev/null +++ b/ppdiffusers/ppdiffusers/pipelines/flux/pipeline_output.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image +import paddle + +from ...utils import BaseOutput + + +@dataclass +class FluxPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +@dataclass +class FluxPriorReduxPipelineOutput(BaseOutput): + """ + Output class for Flux Prior Redux pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + prompt_embeds: paddle.Tensor + pooled_prompt_embeds: paddle.Tensor From e435809b02f381273c72d35691f77b5565ba114c Mon Sep 17 00:00:00 2001 From: qiujiawei Date: Wed, 15 Jan 2025 17:07:23 +0800 Subject: [PATCH 2/2] ppdiffusers fix bugs for flux --- .../ppdiffusers/models/attention_processor.py | 41 ++-- ppdiffusers/ppdiffusers/models/embeddings.py | 24 +- .../ppdiffusers/models/normalization.py | 6 +- .../ppdiffusers/models/transformer_flux.py | 6 +- .../ppdiffusers/pipelines/flux/__init__.py | 6 +- .../pipelines/flux/pipeline_flux.py | 43 ++-- .../scheduling_flow_match_euler_discrete.py | 223 ++++++++++++++---- 7 files changed, 252 insertions(+), 97 deletions(-) diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py index 34f27106c..1fc28cbce 100644 --- a/ppdiffusers/ppdiffusers/models/attention_processor.py +++ b/ppdiffusers/ppdiffusers/models/attention_processor.py @@ -106,6 +106,7 @@ def __init__( processor: Optional["AttnProcessor"] = None, out_dim: int = None, context_pre_only=None, + pre_only=False, elementwise_affine: bool = True, ): super().__init__() @@ -124,6 +125,7 @@ def __init__( self.dropout = dropout self.out_dim = out_dim if out_dim is not None else query_dim self.context_pre_only = context_pre_only + self.pre_only = pre_only # we make use of this private variable to know whether this class is loaded # with an deprecated state dict so that we can convert it on the fly @@ -224,10 +226,12 @@ def __init__( self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) if self.context_pre_only is not None: self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) - - self.to_out = nn.LayerList([]) - self.to_out.append(linear_cls(self.inner_dim, query_dim, bias_attr=out_bias)) - self.to_out.append(nn.Dropout(dropout)) + if not self.pre_only: + self.to_out = nn.LayerList([]) + self.to_out.append(linear_cls(self.inner_dim, query_dim, bias_attr=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None if self.context_pre_only is not None and not self.context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias_attr=out_bias) @@ -2174,9 +2178,9 @@ def __call__( inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads - query = query.reshape([batch_size, -1, attn.heads, head_dim]) - key = key.reshape([batch_size, -1, attn.heads, head_dim]) - value = value.reshape([batch_size, -1, attn.heads, head_dim]) + query = query.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3]) + key = key.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3]) + value = value.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3]) if attn.norm_q is not None: query = attn.norm_q(query) @@ -2190,9 +2194,9 @@ def __call__( encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.reshape([batch_size, -1, attn.heads, head_dim]) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.reshape([batch_size, -1, attn.heads, head_dim]) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.reshape([batch_size, -1, attn.heads, head_dim]) + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3]) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3]) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3]) if attn.norm_added_q is not None: encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) @@ -2200,18 +2204,23 @@ def __call__( encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) # attention - query = paddle.concat([encoder_hidden_states_query_proj, query], axis=1) - key = paddle.concat([encoder_hidden_states_key_proj, key], axis=1) - value = paddle.concat([encoder_hidden_states_value_proj, value], axis=1) + query = paddle.concat([encoder_hidden_states_query_proj, query], axis=2) + key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2) + value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2) if image_rotary_emb is not None: from .embeddings import apply_rotary_emb query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - + hidden_states = F.scaled_dot_product_attention_( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query.transpose([0, 2, 1, 3]), + key.transpose([0, 2, 1, 3]), + value.transpose([0, 2, 1, 3]), + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False ) hidden_states = hidden_states.reshape([batch_size, -1, attn.heads * head_dim]) hidden_states = hidden_states.astype(query.dtype) @@ -2246,7 +2255,7 @@ def __init__(self): def __call__( self, attn: Attention, - hidden_states: paddle.FloatTensor, + hidden_states: paddle.Tensor, encoder_hidden_states: paddle.Tensor = None, attention_mask: Optional[paddle.Tensor] = None, image_rotary_emb: Optional[paddle.Tensor] = None, diff --git a/ppdiffusers/ppdiffusers/models/embeddings.py b/ppdiffusers/ppdiffusers/models/embeddings.py index a17c4a698..dbfd8ac88 100644 --- a/ppdiffusers/ppdiffusers/models/embeddings.py +++ b/ppdiffusers/ppdiffusers/models/embeddings.py @@ -567,7 +567,7 @@ def forward(self, timestep, class_labels, hidden_dtype=None): return conditioning -class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): +class CombinedTimestepGuidanceTextProjEmbeddings(nn.Layer): def __init__(self, embedding_dim, pooled_projection_dim): super().__init__() @@ -1199,22 +1199,24 @@ def get_1d_rotary_pos_embed( theta = theta * ntk_factor freqs = ( 1.0 - / (theta ** (paddle.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + / (theta ** (paddle.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor ) # [D/2] + pos = pos.astype(freqs_dtype) freqs = paddle.outer(pos, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: # flux, hunyuan-dit, cogvideox - freqs_cos = freqs.cos().repeat_interleave(2, axis=1).float() # [S, D] - freqs_sin = freqs.sin().repeat_interleave(2, axis=1).float() # [S, D] + freqs_cos = freqs.cos().repeat_interleave(2, axis=1).astype(dtype="float32") # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, axis=1).astype(dtype="float32") # [S, D] return freqs_cos, freqs_sin elif use_real: # stable audio, allegro - freqs_cos = paddle.concat([freqs.cos(), freqs.cos()], axis=-1).float() # [S, D] - freqs_sin = paddle.concat([freqs.sin(), freqs.sin()], axis=-1).float() # [S, D] + freqs_cos = paddle.concat([freqs.cos(), freqs.cos()], axis=-1).astype(dtype="float32") # [S, D] + freqs_sin = paddle.concat([freqs.sin(), freqs.sin()], axis=-1).astype(dtype="float32") # [S, D] return freqs_cos, freqs_sin else: # lumina + # paddle.complex(abs * paddle.cos(angle), abs * paddle.sin(angle)) freqs_cis = paddle.polar(paddle.ones_like(freqs), freqs) # complex64 # [S, D/2] return freqs_cis @@ -1250,7 +1252,7 @@ def apply_rotary_emb( x_rotated = paddle.concat(x=[-x_imag, x_real], axis=-1) else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") - out = (x.astype(dtype="float32") * cos + x_rotated.astype(dtype="float32") * sin).to(x.dtype) + out = (x.astype(dtype="float32") * cos + x_rotated.astype(dtype="float32") * sin).astype(x.dtype) return out else: x_rotated = paddle.as_complex(x=x.astype(dtype="float32").reshape(*tuple(x.shape)[:-1], -1, 2)) @@ -1259,7 +1261,7 @@ def apply_rotary_emb( return x_out.astype(dtype=x.dtype) -class FluxPosEmbed(nn.Module): +class FluxPosEmbed(nn.Layer): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]): super().__init__() @@ -1270,8 +1272,10 @@ def forward(self, ids: paddle.Tensor) -> paddle.Tensor: n_axes = ids.shape[-1] cos_out = [] sin_out = [] - pos = ids.float() - is_mps = ids.device.type == "mps" + pos = ids.astype('float32') + # TODO + # is_mps = ids.device.type == "mps" + is_mps = False freqs_dtype = paddle.float32 if is_mps else paddle.float64 for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( diff --git a/ppdiffusers/ppdiffusers/models/normalization.py b/ppdiffusers/ppdiffusers/models/normalization.py index ba7a8cede..09c43da42 100644 --- a/ppdiffusers/ppdiffusers/models/normalization.py +++ b/ppdiffusers/ppdiffusers/models/normalization.py @@ -174,7 +174,7 @@ class AdaLayerNormZeroSingle(nn.Layer): def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): super().__init__() - self.silu = nn.SiLU() + self.silu = nn.Silu() self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias_attr=bias) if norm_type == "layer_norm": norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False) @@ -326,13 +326,13 @@ def __init__(self, dim, epsilon: float, elementwise_affine: bool = True): else: self.weight = None - def forward(self, hidden_states, begin_norm_axis=2): + def forward(self, hidden_states, begin_norm_axis=None): return paddle.incubate.nn.functional.fused_rms_norm( x=hidden_states, norm_weight=self.weight, norm_bias=None, epsilon=self.epsilon, - begin_norm_axis=begin_norm_axis, + begin_norm_axis=len(hidden_states.shape)-1 if begin_norm_axis is None else begin_norm_axis, )[0] diff --git a/ppdiffusers/ppdiffusers/models/transformer_flux.py b/ppdiffusers/ppdiffusers/models/transformer_flux.py index 8d44ef08e..a3a2ad5a0 100644 --- a/ppdiffusers/ppdiffusers/models/transformer_flux.py +++ b/ppdiffusers/ppdiffusers/models/transformer_flux.py @@ -62,7 +62,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): self.norm = AdaLayerNormZeroSingle(dim) self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) - self.act_mlp = nn.GELU(approximate="tanh") + self.act_mlp = nn.GELU(approximate=True) self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) processor = FluxAttnProcessor2_0() @@ -292,7 +292,7 @@ def __init__( ) self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels) self.gradient_checkpointing = False @@ -537,7 +537,7 @@ def custom_forward(*inputs): hidden_states = paddle.concat([encoder_hidden_states, hidden_states], axis=1) for index_block, block in enumerate(self.single_transformer_blocks): - if self.is_training and self.gradient_checkpointing: + if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/ppdiffusers/ppdiffusers/pipelines/flux/__init__.py b/ppdiffusers/ppdiffusers/pipelines/flux/__init__.py index aef0f8a4b..d20001c58 100644 --- a/ppdiffusers/ppdiffusers/pipelines/flux/__init__.py +++ b/ppdiffusers/ppdiffusers/pipelines/flux/__init__.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING from ...utils import ( - DIFFUSERS_SLOW_IMPORT, + PPDIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, @@ -15,7 +15,7 @@ _import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]} try: - if not (is_transformers_available() and is_torch_available()): + if not (is_transformers_available() and is_paddle_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 @@ -34,7 +34,7 @@ # _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] # _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] # _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"] -if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: +if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_paddle_available()): raise OptionalDependencyNotAvailable() diff --git a/ppdiffusers/ppdiffusers/pipelines/flux/pipeline_flux.py b/ppdiffusers/ppdiffusers/pipelines/flux/pipeline_flux.py index ca2a673aa..695b004e2 100644 --- a/ppdiffusers/ppdiffusers/pipelines/flux/pipeline_flux.py +++ b/ppdiffusers/ppdiffusers/pipelines/flux/pipeline_flux.py @@ -229,10 +229,10 @@ def _get_t5_prompt_embeds( truncation=True, return_length=False, return_overflowing_tokens=False, - return_tensors="pt", + return_tensors="pd", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pd").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not paddle.equal_all(text_input_ids, untruncated_ids): removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) @@ -249,8 +249,8 @@ def _get_t5_prompt_embeds( _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.tile([1, num_images_per_prompt, 1]) + prompt_embeds = prompt_embeds.reshape([batch_size * num_images_per_prompt, seq_len, -1]) return prompt_embeds @@ -273,11 +273,11 @@ def _get_clip_prompt_embeds( truncation=True, return_overflowing_tokens=False, return_length=False, - return_tensors="pt", + return_tensors="pd", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pd").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not paddle.equal_all(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( @@ -291,8 +291,8 @@ def _get_clip_prompt_embeds( prompt_embeds = prompt_embeds.astype(dtype=self.text_encoder.dtype) # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + prompt_embeds = prompt_embeds.tile([1, num_images_per_prompt]) + prompt_embeds = prompt_embeds.reshape([batch_size * num_images_per_prompt, -1]) return prompt_embeds @@ -301,8 +301,8 @@ def encode_prompt( prompt: Union[str, List[str]], prompt_2: Union[str, List[str]], num_images_per_prompt: int = 1, - prompt_embeds: Optional[paddle.FloatTensor] = None, - pooled_prompt_embeds: Optional[paddle.FloatTensor] = None, + prompt_embeds: Optional[paddle.Tensor] = None, + pooled_prompt_embeds: Optional[paddle.Tensor] = None, max_sequence_length: int = 512, lora_scale: Optional[float] = None, ): @@ -349,7 +349,7 @@ def encode_prompt( ) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype - text_ids = paddle.zeros(prompt_embeds.shape[1], 3).astype(dtype=dtype) + text_ids = paddle.zeros([prompt_embeds.shape[1], 3]).astype(dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids @@ -357,11 +357,11 @@ def encode_image(self, image, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, paddle.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values + image = self.feature_extractor(image, return_tensors="pd").pixel_values image = image.astype(dtype=dtype) image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, axis=0) return image_embeds def prepare_ip_adapter_image_embeds( @@ -473,23 +473,23 @@ def check_inputs( @staticmethod def _prepare_latent_image_ids(batch_size, height, width, dtype): - latent_image_ids = paddle.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + paddle.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + paddle.arange(width)[None, :] + latent_image_ids = paddle.zeros([height, width, 3], dtype=dtype) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + paddle.arange(height, dtype=dtype)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + paddle.arange(width, dtype=dtype)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels + [latent_image_id_height * latent_image_id_width, latent_image_id_channels] ) return latent_image_ids.astype(dtype=dtype) @staticmethod def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.reshape([batch_size, num_channels_latents, height // 2, 2, width // 2, 2]) latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + latents = latents.reshape([batch_size, (height // 2) * (width // 2), num_channels_latents * 4]) return latents @@ -502,10 +502,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor): height = 2 * (int(height) // (vae_scale_factor * 2)) width = 2 * (int(width) // (vae_scale_factor * 2)) - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.reshape([batch_size, height // 2, width // 2, channels // 4, 2, 2]) latents = latents.permute(0, 3, 1, 4, 2, 5) - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + latents = latents.reshape([batch_size, channels // (2 * 2), height, width]) return latents @@ -613,6 +613,7 @@ def __call__( negative_ip_adapter_image_embeds: Optional[List[paddle.Tensor]] = None, negative_prompt_embeds: Optional[paddle.Tensor] = None, negative_pooled_prompt_embeds: Optional[paddle.Tensor] = None, + text_ids: Optional[paddle.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, diff --git a/ppdiffusers/ppdiffusers/schedulers/scheduling_flow_match_euler_discrete.py b/ppdiffusers/ppdiffusers/schedulers/scheduling_flow_match_euler_discrete.py index 56ea8074b..40c9bbbac 100644 --- a/ppdiffusers/ppdiffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/ppdiffusers/ppdiffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -13,17 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import paddle from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging -from ..utils.paddle_utils import randn_tensor +from ..utils import BaseOutput, is_scipy_available, logging from .scheduling_utils import SchedulerMixin +if is_scipy_available(): + import scipy.stats + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -66,22 +69,49 @@ def __init__( self, num_train_timesteps: int = 1000, shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, ): + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = paddle.to_tensor(timesteps).astype(dtype=paddle.float32) sigmas = timesteps / num_train_timesteps - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) self.timesteps = sigmas * num_train_timesteps self._step_index = None self._begin_index = None + self._shift = shift + self.sigmas = sigmas # to avoid too much CPU/GPU communication self.sigma_min = self.sigmas[-1].item() self.sigma_max = self.sigmas[0].item() + @property + def shift(self): + """ + The value used for shifting. + """ + return self._shift + @property def step_index(self): """ @@ -107,6 +137,9 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index + def set_shift(self, shift: float): + self._shift = shift + def scale_noise( self, sample: paddle.Tensor, @@ -126,10 +159,26 @@ def scale_noise( `paddle.Tensor`: A scaled input sample. """ - if self.step_index is None: - self._init_step_index(timestep) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.astype(dtype=sample.dtype) + + # TODO + schedule_timesteps = self.timesteps + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) - sigma = self.sigmas[self.step_index] sample = sigma * noise + (1.0 - sigma) * sample return sample @@ -137,7 +186,36 @@ def scale_noise( def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps - def set_timesteps(self, num_inference_steps: int): + def time_shift(self, mu: float, sigma: float, t: paddle.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def stretch_shift_to_terminal(self, t: paddle.Tensor) -> paddle.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + def set_timesteps( + self, + num_inference_steps: int = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -145,23 +223,56 @@ def set_timesteps(self, num_inference_steps: int): num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. """ + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + + sigmas = timesteps / self.config.num_train_timesteps + else: + sigmas = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas) self.num_inference_steps = num_inference_steps + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.linspace( self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps ) - - sigmas = timesteps / self.config.num_train_timesteps - sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = paddle.to_tensor(sigmas).astype(dtype=paddle.float32) - timesteps = sigmas * self.config.num_train_timesteps - self.timesteps = timesteps - self.sigmas = paddle.concat([sigmas, paddle.zeros(shape=[1])]) + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = paddle.concat([sigmas, paddle.ones(1)]) + else: + sigmas = paddle.concat([sigmas, paddle.zeros(1)]) + + self.timesteps = timesteps + self.sigmas = sigmas self._step_index = None self._begin_index = None - + def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -244,43 +355,73 @@ def step( sample = sample.cast(paddle.float32) sigma = self.sigmas[self.step_index] - # NOTE:(changwenbin & zhoukangkang) when s_churn == 0.0,not need to compute gamma, Can avoid cuda synchronization - if s_churn == 0.0: - gamma = 0.0 - else: - gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + sigma_next = self.sigmas[self.step_index + 1] + + prev_sample = sample + (sigma_next - sigma) * model_output + # Cast sample back to model compatible dtype + prev_sample = prev_sample.cast(model_output.dtype) - noise = randn_tensor(model_output.shape, dtype=model_output.dtype, generator=generator) + # upon completion increase step index by one + self._step_index += 1 - eps = noise * s_noise - sigma_hat = sigma * (gamma + 1) + if not return_dict: + return (prev_sample,) - if gamma > 0: - sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - # NOTE: "original_sample" should not be an expected prediction_type but is left in for - # backwards compatibility + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: paddle.Tensor, num_inference_steps: int) -> paddle.Tensor: + """Constructs an exponential noise schedule.""" - # if self.config.prediction_type == "vector_field": + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None - denoised = sample - model_output * sigma - # 2. Convert to an ODE derivative - derivative = (sample - denoised) / sigma_hat + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None - dt = self.sigmas[self.step_index + 1] - sigma_hat + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - prev_sample = sample + derivative * dt - # Cast sample back to model compatible dtype - prev_sample = prev_sample.cast(model_output.dtype) + sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) + return sigmas - # upon completion increase step index by one - self._step_index += 1 + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: paddle.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> paddle.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" - if not return_dict: - return (prev_sample,) + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None - return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas def __len__(self): return self.config.num_train_timesteps