Commit c2e0f15

eltsai committed
Adding name scopes for easier XProf profiling of WAN 2.1
1 parent f5f212f commit c2e0f15
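
For context on what the change does: jax.named_scope pushes a name onto JAX's name stack, so the ops traced inside the with-block show up grouped under that name in an XProf / TensorBoard profile, which makes the WAN 2.1 attention and MLP blocks easy to find in the trace. Below is a minimal sketch, not part of this commit; the toy function, shapes, and log directory are illustrative only.

import jax
import jax.numpy as jnp


@jax.jit
def toy_attention(q, k, v):
  # Ops below are grouped under "attn_compute" in the profile.
  with jax.named_scope("attn_compute"):
    scores = q @ k.T / (q.shape[-1] ** 0.5)
    weights = jax.nn.softmax(scores, axis=-1)
  # Grouped separately under "attn_out_proj".
  with jax.named_scope("attn_out_proj"):
    return weights @ v


q = k = v = jnp.ones((128, 64))
# Capture a trace that XProf / TensorBoard's profile plugin can display.
with jax.profiler.trace("/tmp/jax-trace"):
  toy_attention(q, k, v).block_until_ready()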

File tree
2 files changed: +89 −60 lines

src/maxdiffusion/models/attention_flax.py

Lines changed: 25 additions & 14 deletions
@@ -850,30 +850,41 @@ def __call__(
     dtype = hidden_states.dtype
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
-
-    query_proj = self.query(hidden_states)
-    key_proj = self.key(encoder_hidden_states)
-    value_proj = self.value(encoder_hidden_states)
+
+    with jax.named_scope("attn_qkv_proj"):
+      with jax.named_scope("proj_query"):
+        query_proj = self.query(hidden_states)
+      with jax.named_scope("proj_key"):
+        key_proj = self.key(encoder_hidden_states)
+      with jax.named_scope("proj_value"):
+        value_proj = self.value(encoder_hidden_states)

     if self.qk_norm:
-      query_proj = self.norm_q(query_proj)
-      key_proj = self.norm_k(key_proj)
+      with jax.named_scope("attn_q_norm"):
+        query_proj = self.norm_q(query_proj)
+      with jax.named_scope("attn_k_norm"):
+        key_proj = self.norm_k(key_proj)
     if rotary_emb is not None:
-      query_proj = _unflatten_heads(query_proj, self.heads)
-      key_proj = _unflatten_heads(key_proj, self.heads)
-      value_proj = _unflatten_heads(value_proj, self.heads)
-      # output of _unflatten_heads Batch, heads, seq_len, head_dim
-      query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
+      with jax.named_scope("attn_rope"):
+        query_proj = _unflatten_heads(query_proj, self.heads)
+        key_proj = _unflatten_heads(key_proj, self.heads)
+        value_proj = _unflatten_heads(value_proj, self.heads)
+        # output of _unflatten_heads Batch, heads, seq_len, head_dim
+        query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)

     query_proj = checkpoint_name(query_proj, "query_proj")
     key_proj = checkpoint_name(key_proj, "key_proj")
     value_proj = checkpoint_name(value_proj, "value_proj")
-    attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+
+    with jax.named_scope("attn_compute"):
+      attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")
-    hidden_states = self.proj_attn(attn_output)
-    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+
+    with jax.named_scope("attn_out_proj"):
+      hidden_states = self.proj_attn(attn_output)
+      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
     return hidden_states

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 64 additions & 46 deletions
@@ -236,10 +236,12 @@ def __init__(
     )

   def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
-    hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
-    hidden_states = checkpoint_name(hidden_states, "ffn_activation")
-    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
-    return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
+    with jax.named_scope("mlp_up_proj_and_gelu"):
+      hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
+      hidden_states = checkpoint_name(hidden_states, "ffn_activation")
+      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+    with jax.named_scope("mlp_down_proj"):
+      return self.proj_out(hidden_states)  # output is (4, 75600, 5120)


 class WanTransformerBlock(nnx.Module):
@@ -331,41 +333,55 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ):
-    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-        (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
-    )
-    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
-    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
-
-    # 1. Self-attention
-    norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
-        hidden_states.dtype
-    )
-    attn_output = self.attn1(
-        hidden_states=norm_hidden_states,
-        encoder_hidden_states=norm_hidden_states,
-        rotary_emb=rotary_emb,
-        deterministic=deterministic,
-        rngs=rngs,
-    )
-    hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
-
-    # 2. Cross-attention
-    norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
-    attn_output = self.attn2(
-        hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
-    )
-    hidden_states = hidden_states + attn_output
-
-    # 3. Feed-forward
-    norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
-        hidden_states.dtype
-    )
-    ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-    hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
-        hidden_states.dtype
-    )
-    return hidden_states
+    with jax.named_scope("transformer_block"):
+      with jax.named_scope("adaln"):
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+            (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
+        )
+      hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+      encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
+
+      # 1. Self-attention
+      with jax.named_scope("self_attention"):
+        with jax.named_scope("self_attention_norm"):
+          norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
+              hidden_states.dtype
+          )
+        with jax.named_scope("self_attention_attn"):
+          attn_output = self.attn1(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=norm_hidden_states,
+              rotary_emb=rotary_emb,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with jax.named_scope("self_attention_residual"):
+          hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
+
+      # 2. Cross-attention
+      with jax.named_scope("cross_attention"):
+        with jax.named_scope("cross_attention_norm"):
+          norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
+        with jax.named_scope("cross_attention_attn"):
+          attn_output = self.attn2(
+              hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
+          )
+        with jax.named_scope("cross_attention_residual"):
+          hidden_states = hidden_states + attn_output
+
+      # 3. Feed-forward
+      with jax.named_scope("mlp"):
+        with jax.named_scope("mlp_norm"):
+          norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
+              hidden_states.dtype
+          )
+        with jax.named_scope("mlp_ffn"):
+          ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
+        with jax.named_scope("mlp_residual"):
+          hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
+              hidden_states.dtype
+          )
+      return hidden_states


 class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
@@ -522,13 +538,15 @@ def __call__(

     hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
     rotary_emb = self.rope(hidden_states)
-
-    hidden_states = self.patch_embedding(hidden_states)
-    hidden_states = jax.lax.collapse(hidden_states, 1, -1)
-
-    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
-        timestep, encoder_hidden_states, encoder_hidden_states_image
-    )
+
+    with jax.named_scope("patch_embedding"):
+      hidden_states = self.patch_embedding(hidden_states)
+      hidden_states = jax.lax.collapse(hidden_states, 1, -1)
+
+    with jax.named_scope("condition_embedding"):
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+          timestep, encoder_hidden_states, encoder_hidden_states_image
+      )
     timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)

     if encoder_hidden_states_image is not None:
