4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -284,6 +284,10 @@ enable_profiler: False
skip_first_n_steps_for_profiler: 5
profiler_steps: 10

# Enable JAX named scopes for detailed profiling and debugging
# When enabled, adds named scopes around key operations in transformer and attention layers
enable_jax_named_scopes: False

# Generation parameters
prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
45 changes: 32 additions & 13 deletions src/maxdiffusion/models/attention_flax.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import math
from typing import Optional, Callable, Tuple
@@ -805,6 +806,7 @@ def __init__(
is_self_attention: bool = True,
mask_padding_tokens: bool = True,
residual_checkpoint_name: str | None = None,
enable_jax_named_scopes: bool = False,
):
if attention_kernel == "cudnn_flash_te":
raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -820,6 +822,7 @@
self.key_axis_names = key_axis_names
self.value_axis_names = value_axis_names
self.out_axis_names = out_axis_names
self.enable_jax_named_scopes = enable_jax_named_scopes

if is_self_attention:
axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV)
@@ -952,6 +955,10 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup

return xq_out, xk_out

def conditional_named_scope(self, name: str):
"""Return a JAX named scope if enabled, otherwise a null context."""
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

def __call__(
self,
hidden_states: jax.Array,
@@ -966,29 +973,41 @@ def __call__(
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states

query_proj = self.query(hidden_states)
key_proj = self.key(encoder_hidden_states)
value_proj = self.value(encoder_hidden_states)
with self.conditional_named_scope("attn_qkv_proj"):
with self.conditional_named_scope("proj_query"):
query_proj = self.query(hidden_states)
with self.conditional_named_scope("proj_key"):
key_proj = self.key(encoder_hidden_states)
with self.conditional_named_scope("proj_value"):
value_proj = self.value(encoder_hidden_states)

if self.qk_norm:
query_proj = self.norm_q(query_proj)
key_proj = self.norm_k(key_proj)
with self.conditional_named_scope("attn_q_norm"):
query_proj = self.norm_q(query_proj)
with self.conditional_named_scope("attn_k_norm"):
key_proj = self.norm_k(key_proj)

if rotary_emb is not None:
query_proj = _unflatten_heads(query_proj, self.heads)
key_proj = _unflatten_heads(key_proj, self.heads)
value_proj = _unflatten_heads(value_proj, self.heads)
# output of _unflatten_heads Batch, heads, seq_len, head_dim
query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
with self.conditional_named_scope("attn_rope"):
query_proj = _unflatten_heads(query_proj, self.heads)
key_proj = _unflatten_heads(key_proj, self.heads)
value_proj = _unflatten_heads(value_proj, self.heads)
# output of _unflatten_heads Batch, heads, seq_len, head_dim
query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)

query_proj = checkpoint_name(query_proj, "query_proj")
key_proj = checkpoint_name(key_proj, "key_proj")
value_proj = checkpoint_name(value_proj, "value_proj")
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

with self.conditional_named_scope("attn_compute"):
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

attn_output = attn_output.astype(dtype=dtype)
attn_output = checkpoint_name(attn_output, "attn_output")
hidden_states = self.proj_attn(attn_output)
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)

with self.conditional_named_scope("attn_out_proj"):
hidden_states = self.proj_attn(attn_output)
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
return hidden_states
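The scopes above nest, so with the option enabled a single attention call should show up in traces as a small hierarchy (for example attn_qkv_proj/proj_query on the JAX name stack), while with the option disabled conditional_named_scope returns a null context and the call path is unchanged. A short sketch of that behaviour, using a stand-in class rather than FlaxWanAttention itself:

import contextlib

import jax


class ScopedModule:
  """Stand-in that models only the scoping helper added in this PR."""

  def __init__(self, enable_jax_named_scopes: bool):
    self.enable_jax_named_scopes = enable_jax_named_scopes

  def conditional_named_scope(self, name: str):
    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()


# Disabled: the helper hands back a no-op context manager.
assert isinstance(ScopedModule(False).conditional_named_scope("attn_compute"), contextlib.nullcontext)

# Enabled: nested scopes push onto the name stack, giving hierarchical labels
# such as "attn_qkv_proj/proj_query" in profiler traces.
m = ScopedModule(True)
with m.conditional_named_scope("attn_qkv_proj"):
  with m.conditional_named_scope("proj_query"):
    pass  # the real module runs self.query(hidden_states) here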


149 changes: 96 additions & 53 deletions src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -15,6 +15,7 @@
"""

from typing import Tuple, Optional, Dict, Union, Any
import contextlib
import math
import jax
import jax.numpy as jnp
@@ -205,11 +206,13 @@ def __init__(
dtype: jnp.dtype = jnp.float32,
weights_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.Precision = None,
enable_jax_named_scopes: bool = False,
):
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim

self.enable_jax_named_scopes = enable_jax_named_scopes
self.act_fn = nnx.data(None)
if activation_fn == "gelu-approximate":
self.act_fn = ApproximateGELU(
@@ -236,11 +239,17 @@
),
)

def conditional_named_scope(self, name: str):
"""Return a JAX named scope if enabled, otherwise a null context."""
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824)
hidden_states = checkpoint_name(hidden_states, "ffn_activation")
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
return self.proj_out(hidden_states) # output is (4, 75600, 5120)
with self.conditional_named_scope("mlp_up_proj_and_gelu"):
hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824)
hidden_states = checkpoint_name(hidden_states, "ffn_activation")
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
with self.conditional_named_scope("mlp_down_proj"):
return self.proj_out(hidden_states) # output is (4, 75600, 5120)


class WanTransformerBlock(nnx.Module):
@@ -265,8 +274,11 @@
attention: str = "dot_product",
dropout: float = 0.0,
mask_padding_tokens: bool = True,
enable_jax_named_scopes: bool = False,
):

self.enable_jax_named_scopes = enable_jax_named_scopes

# 1. Self-attention
self.norm1 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
self.attn1 = FlaxWanAttention(
@@ -287,6 +299,7 @@ def __init__(
is_self_attention=True,
mask_padding_tokens=mask_padding_tokens,
residual_checkpoint_name="self_attn",
enable_jax_named_scopes=enable_jax_named_scopes,
)

# 1. Cross-attention
@@ -308,6 +321,7 @@
is_self_attention=False,
mask_padding_tokens=mask_padding_tokens,
residual_checkpoint_name="cross_attn",
enable_jax_named_scopes=enable_jax_named_scopes,
)
assert cross_attn_norm is True
self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -322,6 +336,7 @@ def __init__(
weights_dtype=weights_dtype,
precision=precision,
dropout=dropout,
enable_jax_named_scopes=enable_jax_named_scopes,
)
self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)

@@ -330,6 +345,10 @@ def __init__(
jax.random.normal(key, (1, 6, dim)) / dim**0.5,
)

def conditional_named_scope(self, name: str):
"""Return a JAX named scope if enabled, otherwise a null context."""
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

def __call__(
self,
hidden_states: jax.Array,
@@ -339,45 +358,59 @@ def __call__(
deterministic: bool = True,
rngs: nnx.Rngs = None,
):
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
)
hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
hidden_states = checkpoint_name(hidden_states, "hidden_states")
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))

# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
hidden_states.dtype
)
attn_output = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
rotary_emb=rotary_emb,
deterministic=deterministic,
rngs=rngs,
)
hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)

# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
attn_output = self.attn2(
hidden_states=norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
deterministic=deterministic,
rngs=rngs,
)
hidden_states = hidden_states + attn_output

# 3. Feed-forward
norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
hidden_states.dtype
)
ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
hidden_states.dtype
)
return hidden_states
with self.conditional_named_scope("transformer_block"):
with self.conditional_named_scope("adaln"):
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
)
hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
hidden_states = checkpoint_name(hidden_states, "hidden_states")
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))

# 1. Self-attention
with self.conditional_named_scope("self_attn"):
with self.conditional_named_scope("self_attn_norm"):
norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
hidden_states.dtype
)
with self.conditional_named_scope("self_attn_attn"):
attn_output = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
rotary_emb=rotary_emb,
deterministic=deterministic,
rngs=rngs,
)
with self.conditional_named_scope("self_attn_residual"):
hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)

# 2. Cross-attention
with self.conditional_named_scope("cross_attn"):
with self.conditional_named_scope("cross_attn_norm"):
norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
with self.conditional_named_scope("cross_attn_attn"):
attn_output = self.attn2(
hidden_states=norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
deterministic=deterministic,
rngs=rngs,
)
with self.conditional_named_scope("cross_attn_residual"):
hidden_states = hidden_states + attn_output

# 3. Feed-forward
with self.conditional_named_scope("mlp"):
with self.conditional_named_scope("mlp_norm"):
norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
hidden_states.dtype
)
with self.conditional_named_scope("mlp_ffn"):
ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
with self.conditional_named_scope("mlp_residual"):
hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
hidden_states.dtype
)
return hidden_states


class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
@@ -416,11 +449,13 @@ def __init__(
names_which_can_be_offloaded: list = [],
mask_padding_tokens: bool = True,
scan_layers: bool = True,
enable_jax_named_scopes: bool = False,
):
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
self.num_layers = num_layers
self.scan_layers = scan_layers
self.enable_jax_named_scopes = enable_jax_named_scopes

# 1. Patch & position embedding
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -472,6 +507,7 @@ def init_block(rngs):
attention=attention,
dropout=dropout,
mask_padding_tokens=mask_padding_tokens,
enable_jax_named_scopes=enable_jax_named_scopes,
)

self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -497,6 +533,7 @@ def init_block(rngs):
weights_dtype=weights_dtype,
precision=precision,
attention=attention,
enable_jax_named_scopes=enable_jax_named_scopes,
)
blocks.append(block)
self.blocks = blocks
@@ -517,6 +554,10 @@ def init_block(rngs):
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
)

def conditional_named_scope(self, name: str):
"""Return a JAX named scope if enabled, otherwise a null context."""
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

def __call__(
self,
hidden_states: jax.Array,
@@ -536,14 +577,15 @@ def __call__(
post_patch_width = width // p_w

hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
rotary_emb = self.rope(hidden_states)
with jax.named_scope("PatchEmbedding"):
with self.conditional_named_scope("rotary_embedding"):
rotary_emb = self.rope(hidden_states)
with self.conditional_named_scope("patch_embedding"):
hidden_states = self.patch_embedding(hidden_states)
hidden_states = jax.lax.collapse(hidden_states, 1, -1)

temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image
)
hidden_states = jax.lax.collapse(hidden_states, 1, -1)
with self.conditional_named_scope("condition_embedder"):
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image
)
timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)

if encoder_hidden_states_image is not None:
@@ -583,9 +625,10 @@ def layer_forward(hidden_states):
hidden_states = rematted_layer_forward(hidden_states)

shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)

hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
hidden_states = self.proj_out(hidden_states)
with self.conditional_named_scope("output_norm"):
hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
with self.conditional_named_scope("output_proj"):
hidden_states = self.proj_out(hidden_states)

hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
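Taken together, one WanTransformerBlock call should emit roughly the following scope tree when the option is on (names are copied from the diff above; the nesting shown is a sketch, outer jit/scan prefixes are omitted, and the attn_q_norm/attn_k_norm and attn_rope scopes only appear when qk_norm and rotary embeddings are in use):

# Sketch of the per-block scope hierarchy assembled from the names above.
EXPECTED_BLOCK_SCOPES = {
  "transformer_block": {
    "adaln": {},
    "self_attn": {"self_attn_norm": {}, "self_attn_attn": {}, "self_attn_residual": {}},
    "cross_attn": {"cross_attn_norm": {}, "cross_attn_attn": {}, "cross_attn_residual": {}},
    "mlp": {"mlp_norm": {}, "mlp_ffn": {}, "mlp_residual": {}},
  }
}
# Inside "self_attn_attn" and "cross_attn_attn", FlaxWanAttention contributes
# attn_qkv_proj (with proj_query/proj_key/proj_value), attn_q_norm, attn_k_norm,
# attn_rope, attn_compute and attn_out_proj; inside "mlp_ffn", the feed-forward
# module contributes mlp_up_proj_and_gelu and mlp_down_proj.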
1 change: 1 addition & 0 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -114,6 +114,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
wan_config["dropout"] = config.dropout
wan_config["mask_padding_tokens"] = config.mask_padding_tokens
wan_config["scan_layers"] = config.scan_layers
wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes

# 2. eval_shape - will not use flops or create weights on device
# thus not using HBM memory.
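A minimal sketch of how the flag travels from the YAML config into the model; the standalone YAML loading here is only illustrative, since the pipeline itself reads config.enable_jax_named_scopes from its parsed config as shown above.

import yaml

# Illustrative only: load the config shipped in this PR and pick up the flag.
with open("src/maxdiffusion/configs/base_wan_14b.yml") as f:
  cfg = yaml.safe_load(f)

wan_config = {"enable_jax_named_scopes": cfg.get("enable_jax_named_scopes", False)}
# ... the real create_model() also fills dropout, mask_padding_tokens, scan_layers, etc. ...
# WanModel then forwards the flag into every WanTransformerBlock and FlaxWanAttention
# it constructs, so a single config switch toggles all named scopes in the model.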