Commit cda888d

Author: eltsai (committed)
Commit message: Added named_scope for WAN 2.1 Xprof profiling
1 parent: ba041cb

2 files changed: +94, -65 lines
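
The change wraps the hot paths of the WAN 2.1 attention and transformer code in jax.named_scope blocks. jax.named_scope is a context manager that pushes a name onto JAX's tracing name stack, so every operation created inside it carries that prefix in the lowered HLO metadata and, therefore, in XProf / TensorBoard traces; it does not change the computation itself. A minimal sketch of the pattern (toy function and shapes of my own choosing, not MaxDiffusion code):

```python
import jax
import jax.numpy as jnp


@jax.jit
def toy_attention(q, k, v):
  # Ops traced inside a named_scope are tagged with that name in profiles.
  with jax.named_scope("attn_compute"):
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)
  with jax.named_scope("attn_out"):
    out = weights @ v
  return out


q = k = v = jnp.ones((8, 16))
print(toy_attention(q, k, v).shape)  # (8, 16)
```

The two files below apply the same pattern, nested, to the QKV projections, the attention computation, and the transformer block's feed-forward path.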

src/maxdiffusion/models/attention_flax.py

Lines changed: 26 additions & 14 deletions
@@ -965,30 +965,42 @@ def __call__(
     dtype = hidden_states.dtype
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
-
-    query_proj = self.query(hidden_states)
-    key_proj = self.key(encoder_hidden_states)
-    value_proj = self.value(encoder_hidden_states)
+
+    with jax.named_scope("attn_qkv_proj"):
+      with jax.named_scope("proj_query"):
+        query_proj = self.query(hidden_states)
+      with jax.named_scope("proj_key"):
+        key_proj = self.key(encoder_hidden_states)
+      with jax.named_scope("proj_value"):
+        value_proj = self.value(encoder_hidden_states)

     if self.qk_norm:
-      query_proj = self.norm_q(query_proj)
-      key_proj = self.norm_k(key_proj)
+      with jax.named_scope("attn_q_norm"):
+        query_proj = self.norm_q(query_proj)
+      with jax.named_scope("attn_k_norm"):
+        key_proj = self.norm_k(key_proj)
+
     if rotary_emb is not None:
-      query_proj = _unflatten_heads(query_proj, self.heads)
-      key_proj = _unflatten_heads(key_proj, self.heads)
-      value_proj = _unflatten_heads(value_proj, self.heads)
-      # output of _unflatten_heads Batch, heads, seq_len, head_dim
-      query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
+      with jax.named_scope("attn_rope"):
+        query_proj = _unflatten_heads(query_proj, self.heads)
+        key_proj = _unflatten_heads(key_proj, self.heads)
+        value_proj = _unflatten_heads(value_proj, self.heads)
+        # output of _unflatten_heads Batch, heads, seq_len, head_dim
+        query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)

     query_proj = checkpoint_name(query_proj, "query_proj")
     key_proj = checkpoint_name(key_proj, "key_proj")
     value_proj = checkpoint_name(value_proj, "value_proj")
-    attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+
+    with jax.named_scope("attn_compute"):
+      attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")
-    hidden_states = self.proj_attn(attn_output)
-    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+
+    with jax.named_scope("attn_out_proj"):
+      hidden_states = self.proj_attn(attn_output)
+      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
     return hidden_states

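To confirm that scopes such as attn_qkv_proj/proj_query are actually attached to the generated operations, you can inspect the compiled HLO without running a profiler; each op's metadata records the name-stack path, and that is the string XProf later displays. A self-contained sketch with toy projection weights (not the attention module above; the exact metadata formatting varies with the JAX/XLA version):

```python
import jax
import jax.numpy as jnp


def qkv_proj(x, w_q, w_k, w_v):
  # Same scope layout as the diff above: one outer scope, one per projection.
  with jax.named_scope("attn_qkv_proj"):
    with jax.named_scope("proj_query"):
      q = x @ w_q
    with jax.named_scope("proj_key"):
      k = x @ w_k
    with jax.named_scope("proj_value"):
      v = x @ w_v
  return q, k, v


x = jnp.ones((4, 32))
w = jnp.ones((32, 32))
hlo = jax.jit(qkv_proj).lower(x, w, w, w).compile().as_text()
# Print the ops whose metadata mentions the outer scope, e.g. op_name entries
# like ".../attn_qkv_proj/proj_query/dot_general".
for line in hlo.splitlines():
  if "attn_qkv_proj" in line:
    print(line.strip())
```
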
src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 68 additions & 51 deletions
@@ -237,10 +237,12 @@ def __init__(
     )

   def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
-    hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
-    hidden_states = checkpoint_name(hidden_states, "ffn_activation")
-    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
-    return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
+    with jax.named_scope("mlp_up_proj_and_gelu"):
+      hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
+      hidden_states = checkpoint_name(hidden_states, "ffn_activation")
+      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+    with jax.named_scope("mlp_down_proj"):
+      return self.proj_out(hidden_states)  # output is (4, 75600, 5120)


 class WanTransformerBlock(nnx.Module):
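
One property worth keeping in mind when reading these hunks: jax.named_scope only adds metadata, so wrapping the feed-forward path as above cannot change numerics or sharding. A quick self-contained check with toy weights (not the WAN feed-forward module itself):

```python
import jax
import jax.numpy as jnp


def ffn_plain(x, w1, w2):
  return jax.nn.gelu(x @ w1) @ w2


def ffn_scoped(x, w1, w2):
  # Same math, wrapped in profiling scopes as in the hunk above.
  with jax.named_scope("mlp_up_proj_and_gelu"):
    h = jax.nn.gelu(x @ w1)
  with jax.named_scope("mlp_down_proj"):
    return h @ w2


x = jnp.ones((2, 8))
w1 = jnp.ones((8, 32))
w2 = jnp.ones((32, 8))
a = jax.jit(ffn_plain)(x, w1, w2)
b = jax.jit(ffn_scoped)(x, w1, w2)
print(jnp.allclose(a, b))  # True: the scopes only affect profiler/HLO metadata
```
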
@@ -339,45 +341,59 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ):
-    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-        (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
-    )
-    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
-    hidden_states = checkpoint_name(hidden_states, "hidden_states")
-    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
-
-    # 1. Self-attention
-    norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
-        hidden_states.dtype
-    )
-    attn_output = self.attn1(
-        hidden_states=norm_hidden_states,
-        encoder_hidden_states=norm_hidden_states,
-        rotary_emb=rotary_emb,
-        deterministic=deterministic,
-        rngs=rngs,
-    )
-    hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
-
-    # 2. Cross-attention
-    norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
-    attn_output = self.attn2(
-        hidden_states=norm_hidden_states,
-        encoder_hidden_states=encoder_hidden_states,
-        deterministic=deterministic,
-        rngs=rngs,
-    )
-    hidden_states = hidden_states + attn_output
-
-    # 3. Feed-forward
-    norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
-        hidden_states.dtype
-    )
-    ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-    hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
-        hidden_states.dtype
-    )
-    return hidden_states
+    with jax.named_scope("transformer_block"):
+      with jax.named_scope("adaln"):
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+            (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
+        )
+      hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+      hidden_states = checkpoint_name(hidden_states, "hidden_states")
+      encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
+
+      # 1. Self-attention
+      with jax.named_scope("self_attn"):
+        with jax.named_scope("self_attn_norm"):
+          norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
+              hidden_states.dtype
+          )
+        with jax.named_scope("self_attn_attn"):
+          attn_output = self.attn1(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=norm_hidden_states,
+              rotary_emb=rotary_emb,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with jax.named_scope("self_attn_residual"):
+          hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
+
+      # 2. Cross-attention
+      with jax.named_scope("cross_attn"):
+        with jax.named_scope("cross_attn_norm"):
+          norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
+        with jax.named_scope("cross_attn_attn"):
+          attn_output = self.attn2(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=encoder_hidden_states,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with jax.named_scope("cross_attn_residual"):
+          hidden_states = hidden_states + attn_output
+
+      # 3. Feed-forward
+      with jax.named_scope("mlp"):
+        with jax.named_scope("mlp_norm"):
+          norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
+              hidden_states.dtype
+          )
+        with jax.named_scope("mlp_ffn"):
+          ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
+        with jax.named_scope("mlp_residual"):
+          hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
+              hidden_states.dtype
+          )
+      return hidden_states


 class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
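
Because the scopes nest, the trace rows for one block group hierarchically, e.g. transformer_block/self_attn/self_attn_attn or transformer_block/mlp/mlp_ffn, which is what makes per-sub-block time attribution possible in XProf's trace viewer. A self-contained capture sketch with a toy block and a log directory of my choosing (not MaxDiffusion code):

```python
import jax
import jax.numpy as jnp


@jax.jit
def toy_block(x):
  # Mimics the nesting above; the real WAN block calls attention/FFN modules.
  with jax.named_scope("transformer_block"):
    with jax.named_scope("self_attn"):
      with jax.named_scope("self_attn_attn"):
        x = jax.nn.softmax(x @ x.T, axis=-1) @ x
    with jax.named_scope("mlp"):
      with jax.named_scope("mlp_ffn"):
        x = jax.nn.gelu(x)
  return x


x = jnp.ones((128, 128))
toy_block(x).block_until_ready()  # warm up so the trace is not dominated by compilation

# Ops recorded inside this window show up in the XProf trace viewer grouped by
# their named_scope paths.
with jax.profiler.trace("/tmp/xprof_toy_block"):
  toy_block(x).block_until_ready()
```
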
@@ -536,14 +552,15 @@ def __call__(
     post_patch_width = width // p_w

     hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
-    rotary_emb = self.rope(hidden_states)
-    with jax.named_scope("PatchEmbedding"):
+    with jax.named_scope("rotary_embedding"):
+      rotary_emb = self.rope(hidden_states)
+    with jax.named_scope("patch_embedding"):
       hidden_states = self.patch_embedding(hidden_states)
-    hidden_states = jax.lax.collapse(hidden_states, 1, -1)
-
-    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
-        timestep, encoder_hidden_states, encoder_hidden_states_image
-    )
+      hidden_states = jax.lax.collapse(hidden_states, 1, -1)
+    with jax.named_scope("condition_embedder"):
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+          timestep, encoder_hidden_states, encoder_hidden_states_image
+      )
     timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)

     if encoder_hidden_states_image is not None:
@@ -594,4 +611,4 @@ def layer_forward(hidden_states):
     hidden_states = jax.lax.collapse(hidden_states, 6, None)
     hidden_states = jax.lax.collapse(hidden_states, 4, 6)
     hidden_states = jax.lax.collapse(hidden_states, 2, 4)
-    return hidden_states
+    return hidden_states
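
With the model-level scopes in place, the usual workflow is to bracket one already-compiled WAN 2.1 step with the programmatic profiler and open the result in TensorBoard's profile plugin (XProf), where the rows follow the scope names added in this commit. A hedged sketch; the step function below is only a stand-in for a real denoising or training step:

```python
import jax
import jax.numpy as jnp

LOG_DIR = "/tmp/wan_xprof"  # any writable directory


@jax.jit
def step(x):
  # Stand-in for the real jitted WAN 2.1 step; replace with your own call.
  with jax.named_scope("transformer_block"):
    return jnp.tanh(x @ x)


x = jnp.ones((256, 256))
step(x).block_until_ready()  # compile outside the trace window

jax.profiler.start_trace(LOG_DIR)
step(x).block_until_ready()
jax.profiler.stop_trace()

# Inspect with: tensorboard --logdir /tmp/wan_xprof  (Profile tab -> Trace Viewer)
```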
