@@ -237,10 +237,12 @@ def __init__(
         )
 
     def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
-        hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
-        hidden_states = checkpoint_name(hidden_states, "ffn_activation")
-        hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
-        return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
+        with jax.named_scope("mlp_up_proj_and_gelu"):
+            hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
+            hidden_states = checkpoint_name(hidden_states, "ffn_activation")
+            hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+        with jax.named_scope("mlp_down_proj"):
+            return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
 
 
 class WanTransformerBlock(nnx.Module):
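
Note on the annotation this diff adds throughout: jax.named_scope is a context manager that attaches a readable label to every op traced inside it, so each region shows up by name in profiler traces. A minimal standalone sketch, not part of this diff, with illustrative shapes and weights:

import jax
import jax.numpy as jnp

def mlp(x, w_up, w_down):
    # Ops traced inside each scope are grouped under that label in the profile.
    with jax.named_scope("mlp_up_proj_and_gelu"):
        h = jax.nn.gelu(x @ w_up)
    with jax.named_scope("mlp_down_proj"):
        return h @ w_down

x = jnp.ones((2, 16))
out = jax.jit(mlp)(x, jnp.ones((16, 64)), jnp.ones((64, 16)))  # shape (2, 16)
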
@@ -339,45 +341,59 @@ def __call__(
         deterministic: bool = True,
         rngs: nnx.Rngs = None,
     ):
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-            (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
-        )
-        hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
-        hidden_states = checkpoint_name(hidden_states, "hidden_states")
-        encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
-
-        # 1. Self-attention
-        norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
-            hidden_states.dtype
-        )
-        attn_output = self.attn1(
-            hidden_states=norm_hidden_states,
-            encoder_hidden_states=norm_hidden_states,
-            rotary_emb=rotary_emb,
-            deterministic=deterministic,
-            rngs=rngs,
-        )
-        hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
-
-        # 2. Cross-attention
-        norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
-        attn_output = self.attn2(
-            hidden_states=norm_hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
-            deterministic=deterministic,
-            rngs=rngs,
-        )
-        hidden_states = hidden_states + attn_output
-
-        # 3. Feed-forward
-        norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
-            hidden_states.dtype
-        )
-        ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-        hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
-            hidden_states.dtype
-        )
-        return hidden_states
+        with jax.named_scope("transformer_block"):
+            with jax.named_scope("adaln"):
+                shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+                    (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
+                )
+            hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+            hidden_states = checkpoint_name(hidden_states, "hidden_states")
+            encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
+
+            # 1. Self-attention
+            with jax.named_scope("self_attn"):
+                with jax.named_scope("self_attn_norm"):
+                    norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
+                        hidden_states.dtype
+                    )
+                with jax.named_scope("self_attn_attn"):
+                    attn_output = self.attn1(
+                        hidden_states=norm_hidden_states,
+                        encoder_hidden_states=norm_hidden_states,
+                        rotary_emb=rotary_emb,
+                        deterministic=deterministic,
+                        rngs=rngs,
+                    )
+                with jax.named_scope("self_attn_residual"):
+                    hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
+
+            # 2. Cross-attention
+            with jax.named_scope("cross_attn"):
+                with jax.named_scope("cross_attn_norm"):
+                    norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
+                with jax.named_scope("cross_attn_attn"):
+                    attn_output = self.attn2(
+                        hidden_states=norm_hidden_states,
+                        encoder_hidden_states=encoder_hidden_states,
+                        deterministic=deterministic,
+                        rngs=rngs,
+                    )
+                with jax.named_scope("cross_attn_residual"):
+                    hidden_states = hidden_states + attn_output
+
+            # 3. Feed-forward
+            with jax.named_scope("mlp"):
+                with jax.named_scope("mlp_norm"):
+                    norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
+                        hidden_states.dtype
+                    )
+                with jax.named_scope("mlp_ffn"):
+                    ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
+                with jax.named_scope("mlp_residual"):
+                    hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
+                        hidden_states.dtype
+                    )
+            return hidden_states
 
 
 class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
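
The "adaln" scope above wraps a standard AdaLN-style modulation: a learned (1, 6, dim) table is added to the time embedding and split into shift/scale/gate triples for the attention and FFN branches. A minimal sketch of the arithmetic, with made-up shapes (batch 2, dim 8) rather than the model's real ones:

import jax.numpy as jnp

dim = 8
table = jnp.zeros((1, 6, dim))                     # learned scale/shift table
temb = jnp.ones((2, 6, dim))                       # time embedding, batch of 2
shift, scale, gate, c_shift, c_scale, c_gate = jnp.split(table + temb, 6, axis=1)

x = jnp.ones((2, 16, dim))                         # (batch, tokens, dim)
modulated = x * (1 + scale) + shift                # (2, 1, dim) broadcasts over tokens
gated = x + modulated * gate                       # gated residual, as in the block above
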
@@ -536,14 +552,15 @@ def __call__(
         post_patch_width = width // p_w
 
         hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
-        rotary_emb = self.rope(hidden_states)
-        with jax.named_scope("PatchEmbedding"):
+        with jax.named_scope("rotary_embedding"):
+            rotary_emb = self.rope(hidden_states)
+        with jax.named_scope("patch_embedding"):
             hidden_states = self.patch_embedding(hidden_states)
-        hidden_states = jax.lax.collapse(hidden_states, 1, -1)
-
-        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
-            timestep, encoder_hidden_states, encoder_hidden_states_image
-        )
+            hidden_states = jax.lax.collapse(hidden_states, 1, -1)
+        with jax.named_scope("condition_embedder"):
+            temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+                timestep, encoder_hidden_states, encoder_hidden_states_image
+            )
         timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
 
         if encoder_hidden_states_image is not None:
@@ -594,4 +611,4 @@ def layer_forward(hidden_states):
         hidden_states = jax.lax.collapse(hidden_states, 6, None)
         hidden_states = jax.lax.collapse(hidden_states, 4, 6)
         hidden_states = jax.lax.collapse(hidden_states, 2, 4)
-        return hidden_states
+        return hidden_states
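
jax.lax.collapse, used in this last hunk to fold the unpatchify axes back together, flattens the contiguous axis range [start, stop) into a single axis; stop=None collapses through the last axis. A small sketch with illustrative shapes, not taken from this file:

import jax.numpy as jnp
from jax import lax

x = jnp.zeros((2, 3, 4, 5))
assert lax.collapse(x, 1, 3).shape == (2, 12, 5)     # merges axes 1 and 2
assert lax.collapse(x, 2, None).shape == (2, 3, 20)  # merges axes 2 and 3
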