@@ -236,10 +236,12 @@ def __init__(
         )
 
     def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
-        hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
-        hidden_states = checkpoint_name(hidden_states, "ffn_activation")
-        hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
-        return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
+        with jax.named_scope("mlp_up_proj_and_gelu"):
+            hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
+            hidden_states = checkpoint_name(hidden_states, "ffn_activation")
+            hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+        with jax.named_scope("mlp_down_proj"):
+            return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
 
 
 class WanTransformerBlock(nnx.Module):
@@ -331,41 +333,55 @@ def __call__(
         deterministic: bool = True,
         rngs: nnx.Rngs = None,
     ):
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-            (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
-        )
-        hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
-        encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
-
-        # 1. Self-attention
-        norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
-            hidden_states.dtype
-        )
-        attn_output = self.attn1(
-            hidden_states=norm_hidden_states,
-            encoder_hidden_states=norm_hidden_states,
-            rotary_emb=rotary_emb,
-            deterministic=deterministic,
-            rngs=rngs,
-        )
-        hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
-
-        # 2. Cross-attention
-        norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
-        attn_output = self.attn2(
-            hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
-        )
-        hidden_states = hidden_states + attn_output
-
-        # 3. Feed-forward
-        norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
-            hidden_states.dtype
-        )
-        ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-        hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
-            hidden_states.dtype
-        )
-        return hidden_states
+        with jax.named_scope("transformer_block"):
+            with jax.named_scope("adaln"):
+                shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+                    (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
+                )
+            hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
+            encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
+
+            # 1. Self-attention
+            with jax.named_scope("self_attention"):
+                with jax.named_scope("self_attention_norm"):
+                    norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
+                        hidden_states.dtype
+                    )
+                with jax.named_scope("self_attention_attn"):
+                    attn_output = self.attn1(
+                        hidden_states=norm_hidden_states,
+                        encoder_hidden_states=norm_hidden_states,
+                        rotary_emb=rotary_emb,
+                        deterministic=deterministic,
+                        rngs=rngs,
+                    )
+                with jax.named_scope("self_attention_residual"):
+                    hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
+
+            # 2. Cross-attention
+            with jax.named_scope("cross_attention"):
+                with jax.named_scope("cross_attention_norm"):
+                    norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
+                with jax.named_scope("cross_attention_attn"):
+                    attn_output = self.attn2(
+                        hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
+                    )
+                with jax.named_scope("cross_attention_residual"):
+                    hidden_states = hidden_states + attn_output
+
+            # 3. Feed-forward
+            with jax.named_scope("mlp"):
+                with jax.named_scope("mlp_norm"):
+                    norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
+                        hidden_states.dtype
+                    )
+                with jax.named_scope("mlp_ffn"):
+                    ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
+                with jax.named_scope("mlp_residual"):
+                    hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
+                        hidden_states.dtype
+                    )
+            return hidden_states
 
 
 class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin):
@@ -522,13 +538,15 @@ def __call__(
 
         hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
         rotary_emb = self.rope(hidden_states)
-
-        hidden_states = self.patch_embedding(hidden_states)
-        hidden_states = jax.lax.collapse(hidden_states, 1, -1)
-
-        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
-            timestep, encoder_hidden_states, encoder_hidden_states_image
-        )
+
+        with jax.named_scope("patch_embedding"):
+            hidden_states = self.patch_embedding(hidden_states)
+            hidden_states = jax.lax.collapse(hidden_states, 1, -1)
+
+        with jax.named_scope("condition_embedding"):
+            temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+                timestep, encoder_hidden_states, encoder_hidden_states_image
+            )
         timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
 
         if encoder_hidden_states_image is not None:
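For context, a minimal standalone sketch (not part of this commit; shapes, the trace directory, and function names are illustrative) of what these annotations buy: ops emitted inside a jax.named_scope block are grouped under that label in a captured profiler trace, which is what makes the per-block regions above visible in TensorBoard/Perfetto.

# Hypothetical sketch, assuming a toy MLP rather than the Wan feed-forward layer.
import jax
import jax.numpy as jnp

@jax.jit
def mlp(x, w_up, w_down):
    with jax.named_scope("mlp_up_proj_and_gelu"):
        h = jax.nn.gelu(x @ w_up)  # ops traced here carry this scope name
    with jax.named_scope("mlp_down_proj"):
        return h @ w_down

x = jnp.ones((8, 512))
w_up, w_down = jnp.ones((512, 2048)), jnp.ones((2048, 512))

# Capture a trace; the scope names appear as nested regions in the viewer.
with jax.profiler.trace("/tmp/jax-trace"):
    jax.block_until_ready(mlp(x, w_up, w_down))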