1515"""
1616
1717from typing import Tuple , Optional , Dict , Union , Any
18+ import contextlib
1819import math
1920import jax
2021import jax .numpy as jnp
@@ -205,11 +206,13 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      enable_jax_named_scopes: bool = False,
   ):
     if inner_dim is None:
       inner_dim = int(dim * mult)
     dim_out = dim_out if dim_out is not None else dim
 
+    self.enable_jax_named_scopes = enable_jax_named_scopes
     self.act_fn = nnx.data(None)
     if activation_fn == "gelu-approximate":
       self.act_fn = ApproximateGELU(
@@ -236,12 +239,16 @@ def __init__(
         ),
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
-    with jax.named_scope("mlp_up_proj_and_gelu"):
+    with self.conditional_named_scope("mlp_up_proj_and_gelu"):
       hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
       hidden_states = checkpoint_name(hidden_states, "ffn_activation")
       hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
-    with jax.named_scope("mlp_down_proj"):
+    with self.conditional_named_scope("mlp_down_proj"):
       return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
 
 
@@ -267,8 +274,11 @@ def __init__(
       attention: str = "dot_product",
       dropout: float = 0.0,
       mask_padding_tokens: bool = True,
+      enable_jax_named_scopes: bool = False,
   ):
 
+    self.enable_jax_named_scopes = enable_jax_named_scopes
+
     # 1. Self-attention
     self.norm1 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
     self.attn1 = FlaxWanAttention(
@@ -289,6 +299,7 @@ def __init__(
         is_self_attention=True,
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="self_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
 
     # 2. Cross-attention
@@ -310,6 +321,7 @@ def __init__(
         is_self_attention=False,
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="cross_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -324,6 +336,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
         dropout=dropout,
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
     self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
 
@@ -332,6 +345,10 @@ def __init__(
         jax.random.normal(key, (1, 6, dim)) / dim**0.5,
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       hidden_states: jax.Array,
@@ -341,8 +358,8 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ):
-    with jax.named_scope("transformer_block"):
-      with jax.named_scope("adaln"):
+    with self.conditional_named_scope("transformer_block"):
+      with self.conditional_named_scope("adaln"):
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
             (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
         )
@@ -351,45 +368,45 @@ def __call__(
       encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
 
       # 1. Self-attention
-      with jax.named_scope("self_attn"):
-        with jax.named_scope("self_attn_norm"):
+      with self.conditional_named_scope("self_attn"):
+        with self.conditional_named_scope("self_attn_norm"):
           norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
               hidden_states.dtype
           )
-        with jax.named_scope("self_attn_attn"):
+        with self.conditional_named_scope("self_attn_attn"):
           attn_output = self.attn1(
               hidden_states=norm_hidden_states,
               encoder_hidden_states=norm_hidden_states,
               rotary_emb=rotary_emb,
               deterministic=deterministic,
               rngs=rngs,
           )
-        with jax.named_scope("self_attn_residual"):
+        with self.conditional_named_scope("self_attn_residual"):
           hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
 
       # 2. Cross-attention
-      with jax.named_scope("cross_attn"):
-        with jax.named_scope("cross_attn_norm"):
+      with self.conditional_named_scope("cross_attn"):
+        with self.conditional_named_scope("cross_attn_norm"):
           norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
-        with jax.named_scope("cross_attn_attn"):
+        with self.conditional_named_scope("cross_attn_attn"):
           attn_output = self.attn2(
               hidden_states=norm_hidden_states,
               encoder_hidden_states=encoder_hidden_states,
               deterministic=deterministic,
               rngs=rngs,
           )
-        with jax.named_scope("cross_attn_residual"):
+        with self.conditional_named_scope("cross_attn_residual"):
           hidden_states = hidden_states + attn_output
 
       # 3. Feed-forward
-      with jax.named_scope("mlp"):
-        with jax.named_scope("mlp_norm"):
+      with self.conditional_named_scope("mlp"):
+        with self.conditional_named_scope("mlp_norm"):
           norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
               hidden_states.dtype
           )
-        with jax.named_scope("mlp_ffn"):
+        with self.conditional_named_scope("mlp_ffn"):
           ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-        with jax.named_scope("mlp_residual"):
+        with self.conditional_named_scope("mlp_residual"):
           hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
               hidden_states.dtype
           )
@@ -432,11 +449,13 @@ def __init__(
       names_which_can_be_offloaded: list = [],
       mask_padding_tokens: bool = True,
       scan_layers: bool = True,
+      enable_jax_named_scopes: bool = False,
   ):
     inner_dim = num_attention_heads * attention_head_dim
     out_channels = out_channels or in_channels
     self.num_layers = num_layers
     self.scan_layers = scan_layers
+    self.enable_jax_named_scopes = enable_jax_named_scopes
 
     # 1. Patch & position embedding
     self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -488,6 +507,7 @@ def init_block(rngs):
           attention=attention,
           dropout=dropout,
           mask_padding_tokens=mask_padding_tokens,
+          enable_jax_named_scopes=enable_jax_named_scopes,
       )
 
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -513,6 +533,7 @@ def init_block(rngs):
             weights_dtype=weights_dtype,
             precision=precision,
             attention=attention,
+            enable_jax_named_scopes=enable_jax_named_scopes,
         )
         blocks.append(block)
       self.blocks = blocks
@@ -533,6 +554,10 @@ def init_block(rngs):
         kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       hidden_states: jax.Array,
@@ -552,12 +577,12 @@ def __call__(
     post_patch_width = width // p_w
 
     hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
-    with jax.named_scope("rotary_embedding"):
+    with self.conditional_named_scope("rotary_embedding"):
       rotary_emb = self.rope(hidden_states)
-    with jax.named_scope("patch_embedding"):
+    with self.conditional_named_scope("patch_embedding"):
       hidden_states = self.patch_embedding(hidden_states)
       hidden_states = jax.lax.collapse(hidden_states, 1, -1)
-    with jax.named_scope("condition_embedder"):
+    with self.conditional_named_scope("condition_embedder"):
       temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
           timestep, encoder_hidden_states, encoder_hidden_states_image
       )
@@ -611,4 +636,4 @@ def layer_forward(hidden_states):
     hidden_states = jax.lax.collapse(hidden_states, 6, None)
     hidden_states = jax.lax.collapse(hidden_states, 4, 6)
     hidden_states = jax.lax.collapse(hidden_states, 2, 4)
-    return hidden_states
+    return hidden_states
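
For reference, the conditional named-scope pattern these hunks introduce can be sketched in isolation as below. This is a minimal sketch, assuming a toy module: ToyBlock, its scope names, and the example input are hypothetical; only jax.named_scope, contextlib.nullcontext, and the enable_jax_named_scopes flag mirror the actual change. When the flag is off, each scope collapses to a no-op context manager, so the annotations add no overhead.

# Minimal sketch of the conditional named-scope pattern; ToyBlock and its
# scope names are illustrative, not part of the diff above.
import contextlib

import jax
import jax.numpy as jnp


class ToyBlock:

  def __init__(self, enable_jax_named_scopes: bool = False):
    self.enable_jax_named_scopes = enable_jax_named_scopes

  def conditional_named_scope(self, name: str):
    # Same pattern as the diff: annotate traced ops for profiling when
    # enabled, otherwise return a no-op context manager.
    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

  def __call__(self, x: jax.Array) -> jax.Array:
    with self.conditional_named_scope("toy_block"):
      with self.conditional_named_scope("square"):
        y = x * x
      with self.conditional_named_scope("reduce"):
        return jnp.sum(y)


block = ToyBlock(enable_jax_named_scopes=True)
print(jax.jit(block.__call__)(jnp.arange(4.0)))  # scope names show up in HLO metadata / profiler traces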