4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -284,6 +284,10 @@ enable_profiler: False
skip_first_n_steps_for_profiler: 5
profiler_steps: 10

# Enable JAX named scopes for detailed profiling and debugging.
# When enabled, named scopes are added around key operations in the transformer and attention layers.
enable_jax_named_scopes: False

# Generation parameters
prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
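For orientation, here is a minimal, self-contained sketch (not code from this PR) of the pattern the new flag controls: `jax.named_scope` attaches a name to every operation traced inside it, so those ops show up grouped under that name in HLO metadata and profiler traces, while `contextlib.nullcontext()` turns the wrapper into a no-op when the flag is off.

```python
import contextlib

import jax
import jax.numpy as jnp

ENABLE_JAX_NAMED_SCOPES = True  # mirrors the new config entry


def conditional_named_scope(name: str):
    # Real named scope when profiling annotations are wanted, no-op context otherwise.
    return jax.named_scope(name) if ENABLE_JAX_NAMED_SCOPES else contextlib.nullcontext()


@jax.jit
def attention_like(q, k, v):
    # Ops traced inside this block carry the "attn_compute" annotation in a profile.
    with conditional_named_scope("attn_compute"):
        scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(q.shape[-1])
        return jax.nn.softmax(scores, axis=-1) @ v


x = jnp.ones((8, 64))
print(attention_like(x, x, x).shape)  # (8, 64)
```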
45 changes: 32 additions & 13 deletions src/maxdiffusion/models/attention_flax.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import math
from typing import Optional, Callable, Tuple
@@ -805,6 +806,7 @@ def __init__(
is_self_attention: bool = True,
mask_padding_tokens: bool = True,
residual_checkpoint_name: str | None = None,
enable_jax_named_scopes: bool = False,
):
if attention_kernel == "cudnn_flash_te":
raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -820,6 +822,7 @@
self.key_axis_names = key_axis_names
self.value_axis_names = value_axis_names
self.out_axis_names = out_axis_names
self.enable_jax_named_scopes = enable_jax_named_scopes

if is_self_attention:
axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV)
@@ -952,6 +955,10 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup

return xq_out, xk_out

def conditional_named_scope(self, name: str):
"""Return a JAX named scope if enabled, otherwise a null context."""
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

def __call__(
self,
hidden_states: jax.Array,
@@ -966,29 +973,41 @@ def __call__(
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states

query_proj = self.query(hidden_states)
key_proj = self.key(encoder_hidden_states)
value_proj = self.value(encoder_hidden_states)
with self.conditional_named_scope("attn_qkv_proj"):
with self.conditional_named_scope("proj_query"):
query_proj = self.query(hidden_states)
with self.conditional_named_scope("proj_key"):
key_proj = self.key(encoder_hidden_states)
with self.conditional_named_scope("proj_value"):
value_proj = self.value(encoder_hidden_states)

if self.qk_norm:
query_proj = self.norm_q(query_proj)
key_proj = self.norm_k(key_proj)
with self.conditional_named_scope("attn_q_norm"):
query_proj = self.norm_q(query_proj)
with self.conditional_named_scope("attn_k_norm"):
key_proj = self.norm_k(key_proj)

if rotary_emb is not None:
query_proj = _unflatten_heads(query_proj, self.heads)
key_proj = _unflatten_heads(key_proj, self.heads)
value_proj = _unflatten_heads(value_proj, self.heads)
# output of _unflatten_heads is (batch, heads, seq_len, head_dim)
query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
with self.conditional_named_scope("attn_rope"):
query_proj = _unflatten_heads(query_proj, self.heads)
key_proj = _unflatten_heads(key_proj, self.heads)
value_proj = _unflatten_heads(value_proj, self.heads)
# output of _unflatten_heads is (batch, heads, seq_len, head_dim)
query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)

query_proj = checkpoint_name(query_proj, "query_proj")
key_proj = checkpoint_name(key_proj, "key_proj")
value_proj = checkpoint_name(value_proj, "value_proj")
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

with self.conditional_named_scope("attn_compute"):
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

attn_output = attn_output.astype(dtype=dtype)
attn_output = checkpoint_name(attn_output, "attn_output")
hidden_states = self.proj_attn(attn_output)
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)

with self.conditional_named_scope("attn_out_proj"):
hidden_states = self.proj_attn(attn_output)
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
return hidden_states


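To check that the annotations actually reach a profile, a short sketch (again not from this repo; the trace directory is arbitrary): run any jitted function that uses `jax.named_scope` under `jax.profiler.trace` and open the dump in TensorBoard's profiler or Perfetto, where scope names such as `attn_qkv_proj` and `attn_compute` appear on the corresponding ops.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x, w):
    with jax.named_scope("attn_qkv_proj"):
        q = x @ w
    with jax.named_scope("attn_compute"):
        return jax.nn.softmax(q, axis=-1)


with jax.profiler.trace("/tmp/jax-trace"):  # illustrative output directory
    step(jnp.ones((16, 32)), jnp.ones((32, 32))).block_until_ready()
# The ops inside each `with` block are annotated with the scope name in the trace.
```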
142 changes: 92 additions & 50 deletions src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -15,6 +15,7 @@
"""

from typing import Tuple, Optional, Dict, Union, Any
import contextlib
import math
import jax
import jax.numpy as jnp
@@ -205,11 +206,13 @@ def __init__(
dtype: jnp.dtype = jnp.float32,
weights_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.Precision = None,
enable_jax_named_scopes: bool = False,
):
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim

self.enable_jax_named_scopes = enable_jax_named_scopes
self.act_fn = nnx.data(None)
if activation_fn == "gelu-approximate":
self.act_fn = ApproximateGELU(
@@ -236,11 +239,17 @@
),
)

def conditional_named_scope(self, name: str):
"""Return a JAX named scope if enabled, otherwise a null context."""
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824)
hidden_states = checkpoint_name(hidden_states, "ffn_activation")
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
return self.proj_out(hidden_states) # output is (4, 75600, 5120)
with self.conditional_named_scope("mlp_up_proj_and_gelu"):
hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824)
hidden_states = checkpoint_name(hidden_states, "ffn_activation")
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
with self.conditional_named_scope("mlp_down_proj"):
return self.proj_out(hidden_states) # output is (4, 75600, 5120)


class WanTransformerBlock(nnx.Module):
@@ -265,8 +274,11 @@
attention: str = "dot_product",
dropout: float = 0.0,
mask_padding_tokens: bool = True,
enable_jax_named_scopes: bool = False,
):

self.enable_jax_named_scopes = enable_jax_named_scopes

# 1. Self-attention
self.norm1 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
self.attn1 = FlaxWanAttention(
@@ -287,6 +299,7 @@
is_self_attention=True,
mask_padding_tokens=mask_padding_tokens,
residual_checkpoint_name="self_attn",
enable_jax_named_scopes=enable_jax_named_scopes,
)

# 1. Cross-attention
@@ -308,6 +321,7 @@
is_self_attention=False,
mask_padding_tokens=mask_padding_tokens,
residual_checkpoint_name="cross_attn",
enable_jax_named_scopes=enable_jax_named_scopes,
)
assert cross_attn_norm is True
self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -322,6 +336,7 @@
weights_dtype=weights_dtype,
precision=precision,
dropout=dropout,
enable_jax_named_scopes=enable_jax_named_scopes,
)
self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)

@@ -330,6 +345,10 @@ def __init__(
jax.random.normal(key, (1, 6, dim)) / dim**0.5,
)

def conditional_named_scope(self, name: str):
"""Return a JAX named scope if enabled, otherwise a null context."""
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

def __call__(
self,
hidden_states: jax.Array,
@@ -339,45 +358,59 @@ def __call__(
deterministic: bool = True,
rngs: nnx.Rngs = None,
):
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
)
hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
hidden_states = checkpoint_name(hidden_states, "hidden_states")
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))

# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
hidden_states.dtype
)
attn_output = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
rotary_emb=rotary_emb,
deterministic=deterministic,
rngs=rngs,
)
hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)

# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
attn_output = self.attn2(
hidden_states=norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
deterministic=deterministic,
rngs=rngs,
)
hidden_states = hidden_states + attn_output

# 3. Feed-forward
norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
hidden_states.dtype
)
ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
hidden_states.dtype
)
return hidden_states
with self.conditional_named_scope("transformer_block"):
with self.conditional_named_scope("adaln"):
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
)
hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
hidden_states = checkpoint_name(hidden_states, "hidden_states")
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))

# 1. Self-attention
with self.conditional_named_scope("self_attn"):
with self.conditional_named_scope("self_attn_norm"):
norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
hidden_states.dtype
)
with self.conditional_named_scope("self_attn_attn"):
attn_output = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
rotary_emb=rotary_emb,
deterministic=deterministic,
rngs=rngs,
)
with self.conditional_named_scope("self_attn_residual"):
hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)

# 2. Cross-attention
with self.conditional_named_scope("cross_attn"):
with self.conditional_named_scope("cross_attn_norm"):
norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
with self.conditional_named_scope("cross_attn_attn"):
attn_output = self.attn2(
hidden_states=norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
deterministic=deterministic,
rngs=rngs,
)
with self.conditional_named_scope("cross_attn_residual"):
hidden_states = hidden_states + attn_output

# 3. Feed-forward
with self.conditional_named_scope("mlp"):
with self.conditional_named_scope("mlp_norm"):
norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
hidden_states.dtype
)
with self.conditional_named_scope("mlp_ffn"):
ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
with self.conditional_named_scope("mlp_residual"):
hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
hidden_states.dtype
)
return hidden_states


class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
@@ -416,11 +449,13 @@ def __init__(
names_which_can_be_offloaded: list = [],
mask_padding_tokens: bool = True,
scan_layers: bool = True,
enable_jax_named_scopes: bool = False,
):
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
self.num_layers = num_layers
self.scan_layers = scan_layers
self.enable_jax_named_scopes = enable_jax_named_scopes

# 1. Patch & position embedding
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -472,6 +507,7 @@ def init_block(rngs):
attention=attention,
dropout=dropout,
mask_padding_tokens=mask_padding_tokens,
enable_jax_named_scopes=enable_jax_named_scopes,
)

self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -497,6 +533,7 @@ def init_block(rngs):
weights_dtype=weights_dtype,
precision=precision,
attention=attention,
enable_jax_named_scopes=enable_jax_named_scopes,
)
blocks.append(block)
self.blocks = blocks
@@ -517,6 +554,10 @@ def init_block(rngs):
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
)

def conditional_named_scope(self, name: str):
"""Return a JAX named scope if enabled, otherwise a null context."""
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

def __call__(
self,
hidden_states: jax.Array,
@@ -536,14 +577,15 @@ def __call__(
post_patch_width = width // p_w

hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
rotary_emb = self.rope(hidden_states)
with jax.named_scope("PatchEmbedding"):
with self.conditional_named_scope("rotary_embedding"):
rotary_emb = self.rope(hidden_states)
with self.conditional_named_scope("patch_embedding"):
hidden_states = self.patch_embedding(hidden_states)
hidden_states = jax.lax.collapse(hidden_states, 1, -1)

temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image
)
hidden_states = jax.lax.collapse(hidden_states, 1, -1)
with self.conditional_named_scope("condition_embedder"):
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image
)
timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)

if encoder_hidden_states_image is not None:
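Because `jax.named_scope` pushes onto JAX's name stack, the nested `with` blocks above should compose into a hierarchy such as `transformer_block/self_attn/self_attn_attn/attn_qkv_proj` in the trace, mirroring the module structure. A tiny sketch of that composition (toy ops, not the real block):

```python
import jax
import jax.numpy as jnp


@jax.jit
def block(x):
    # Nested scopes concatenate on the name stack, so the ops below are expected to
    # appear as "transformer_block/self_attn/..." and "transformer_block/mlp/..."
    # in HLO metadata and profiler traces.
    with jax.named_scope("transformer_block"):
        with jax.named_scope("self_attn"):
            x = jax.nn.softmax(x @ x.T, axis=-1) @ x
        with jax.named_scope("mlp"):
            x = jax.nn.gelu(x)
    return x


print(block(jnp.ones((4, 8))).shape)  # (4, 8)
```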
1 change: 1 addition & 0 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -114,6 +114,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
wan_config["dropout"] = config.dropout
wan_config["mask_padding_tokens"] = config.mask_padding_tokens
wan_config["scan_layers"] = config.scan_layers
wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes

# 2. eval_shape - will not use flops or create weights on device
# thus not using HBM memory.