Commit c31048f

Author: eltsai

Added flag to enable named_scope

1 parent cda888d · commit c31048f

File tree: 4 files changed (+66 -29 lines changed)

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 4 additions & 0 deletions
@@ -284,6 +284,10 @@ enable_profiler: False
 skip_first_n_steps_for_profiler: 5
 profiler_steps: 10

+# Enable JAX named scopes for detailed profiling and debugging
+# When enabled, adds named scopes around key operations in transformer and attention layers
+enable_jax_named_scopes: False
+
 # Generation parameters
 prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
 prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
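Note: the new key only gates instrumentation; it does not change the computation. As a rough sketch of the pattern the flag controls (not part of this commit; the helper and variable names below are made up), jax.named_scope pushes a name onto JAX's name stack so the wrapped ops are labeled in lowered HLO and in profiler traces, while contextlib.nullcontext() makes the with-block a no-op when the flag is off:

    import contextlib

    import jax
    import jax.numpy as jnp

    ENABLE_JAX_NAMED_SCOPES = True  # stands in for the new yml key


    def maybe_named_scope(name: str):
      # Same pattern the commit adds as conditional_named_scope on each module.
      return jax.named_scope(name) if ENABLE_JAX_NAMED_SCOPES else contextlib.nullcontext()


    @jax.jit
    def toy_attention(x):
      with maybe_named_scope("attn_compute"):
        return jnp.einsum("bld,bmd->blm", x, x)


    x = jnp.ones((2, 8, 16))
    # When the flag is on, "attn_compute" should show up in the op metadata of the
    # lowered computation (and in profiler traces).
    print(toy_attention.lower(x).as_text())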

src/maxdiffusion/models/attention_flax.py

Lines changed: 16 additions & 9 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import functools
 import math
 from typing import Optional, Callable, Tuple
@@ -805,6 +806,7 @@ def __init__(
       is_self_attention: bool = True,
       mask_padding_tokens: bool = True,
       residual_checkpoint_name: str | None = None,
+      enable_jax_named_scopes: bool = False,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -820,6 +822,7 @@ def __init__(
     self.key_axis_names = key_axis_names
     self.value_axis_names = value_axis_names
     self.out_axis_names = out_axis_names
+    self.enable_jax_named_scopes = enable_jax_named_scopes

     if is_self_attention:
       axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV)
@@ -952,6 +955,10 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup

     return xq_out, xk_out

+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       hidden_states: jax.Array,
@@ -965,7 +972,7 @@ def __call__(
     dtype = hidden_states.dtype
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
-
+
     with jax.named_scope("attn_qkv_proj"):
       with jax.named_scope("proj_query"):
         query_proj = self.query(hidden_states)
@@ -975,13 +982,13 @@
         value_proj = self.value(encoder_hidden_states)

     if self.qk_norm:
-      with jax.named_scope("attn_q_norm"):
+      with self.conditional_named_scope("attn_q_norm"):
         query_proj = self.norm_q(query_proj)
-      with jax.named_scope("attn_k_norm"):
+      with self.conditional_named_scope("attn_k_norm"):
         key_proj = self.norm_k(key_proj)
-
+
     if rotary_emb is not None:
-      with jax.named_scope("attn_rope"):
+      with self.conditional_named_scope("attn_rope"):
         query_proj = _unflatten_heads(query_proj, self.heads)
         key_proj = _unflatten_heads(key_proj, self.heads)
         value_proj = _unflatten_heads(value_proj, self.heads)
@@ -991,14 +998,14 @@
       query_proj = checkpoint_name(query_proj, "query_proj")
       key_proj = checkpoint_name(key_proj, "key_proj")
       value_proj = checkpoint_name(value_proj, "value_proj")
-
-    with jax.named_scope("attn_compute"):
+
+    with self.conditional_named_scope("attn_compute"):
       attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")
-
-    with jax.named_scope("attn_out_proj"):
+
+    with self.conditional_named_scope("attn_out_proj"):
       hidden_states = self.proj_attn(attn_output)
       hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
     return hidden_states
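Because jax.named_scope only annotates operation metadata, flipping the flag should leave numerics untouched. A quick equivalence check in that spirit (hypothetical, not from the repo; the toy function and shapes are invented):

    import contextlib

    import jax
    import jax.numpy as jnp


    def scoped_matmul(x, w, enabled):
      # Scope is chosen at trace time, exactly like conditional_named_scope above.
      scope = jax.named_scope("attn_compute") if enabled else contextlib.nullcontext()
      with scope:
        return x @ w


    x = jnp.ones((4, 8))
    w = jnp.ones((8, 16))
    out_on = jax.jit(scoped_matmul, static_argnums=2)(x, w, True)
    out_off = jax.jit(scoped_matmul, static_argnums=2)(x, w, False)
    assert jnp.allclose(out_on, out_off)  # scopes label ops; they do not alter results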

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 45 additions & 20 deletions
@@ -15,6 +15,7 @@
 """

 from typing import Tuple, Optional, Dict, Union, Any
+import contextlib
 import math
 import jax
 import jax.numpy as jnp
@@ -205,11 +206,13 @@ def __init__(
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
       precision: jax.lax.Precision = None,
+      enable_jax_named_scopes: bool = False,
   ):
     if inner_dim is None:
       inner_dim = int(dim * mult)
     dim_out = dim_out if dim_out is not None else dim

+    self.enable_jax_named_scopes = enable_jax_named_scopes
     self.act_fn = nnx.data(None)
     if activation_fn == "gelu-approximate":
       self.act_fn = ApproximateGELU(
@@ -236,12 +239,16 @@
         ),
     )

+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
-    with jax.named_scope("mlp_up_proj_and_gelu"):
+    with self.conditional_named_scope("mlp_up_proj_and_gelu"):
       hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
       hidden_states = checkpoint_name(hidden_states, "ffn_activation")
       hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
-    with jax.named_scope("mlp_down_proj"):
+    with self.conditional_named_scope("mlp_down_proj"):
       return self.proj_out(hidden_states)  # output is (4, 75600, 5120)


@@ -267,8 +274,11 @@ def __init__(
       attention: str = "dot_product",
       dropout: float = 0.0,
       mask_padding_tokens: bool = True,
+      enable_jax_named_scopes: bool = False,
   ):

+    self.enable_jax_named_scopes = enable_jax_named_scopes
+
     # 1. Self-attention
     self.norm1 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
     self.attn1 = FlaxWanAttention(
@@ -289,6 +299,7 @@ def __init__(
         is_self_attention=True,
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="self_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )

     # 1. Cross-attention
@@ -310,6 +321,7 @@ def __init__(
         is_self_attention=False,
         mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="cross_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -324,6 +336,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
         dropout=dropout,
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
     self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)

@@ -332,6 +345,10 @@ def __init__(
         jax.random.normal(key, (1, 6, dim)) / dim**0.5,
     )

+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       hidden_states: jax.Array,
@@ -341,8 +358,8 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
   ):
-    with jax.named_scope("transformer_block"):
-      with jax.named_scope("adaln"):
+    with self.conditional_named_scope("transformer_block"):
+      with self.conditional_named_scope("adaln"):
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
             (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
         )
@@ -351,45 +368,45 @@
      encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))

      # 1. Self-attention
-      with jax.named_scope("self_attn"):
-        with jax.named_scope("self_attn_norm"):
+      with self.conditional_named_scope("self_attn"):
+        with self.conditional_named_scope("self_attn_norm"):
          norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
              hidden_states.dtype
          )
-        with jax.named_scope("self_attn_attn"):
+        with self.conditional_named_scope("self_attn_attn"):
          attn_output = self.attn1(
              hidden_states=norm_hidden_states,
              encoder_hidden_states=norm_hidden_states,
              rotary_emb=rotary_emb,
              deterministic=deterministic,
              rngs=rngs,
          )
-        with jax.named_scope("self_attn_residual"):
+        with self.conditional_named_scope("self_attn_residual"):
          hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)

      # 2. Cross-attention
-      with jax.named_scope("cross_attn"):
-        with jax.named_scope("cross_attn_norm"):
+      with self.conditional_named_scope("cross_attn"):
+        with self.conditional_named_scope("cross_attn_norm"):
          norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
-        with jax.named_scope("cross_attn_attn"):
+        with self.conditional_named_scope("cross_attn_attn"):
          attn_output = self.attn2(
              hidden_states=norm_hidden_states,
              encoder_hidden_states=encoder_hidden_states,
              deterministic=deterministic,
              rngs=rngs,
          )
-        with jax.named_scope("cross_attn_residual"):
+        with self.conditional_named_scope("cross_attn_residual"):
          hidden_states = hidden_states + attn_output

      # 3. Feed-forward
-      with jax.named_scope("mlp"):
-        with jax.named_scope("mlp_norm"):
+      with self.conditional_named_scope("mlp"):
+        with self.conditional_named_scope("mlp_norm"):
          norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
              hidden_states.dtype
          )
-        with jax.named_scope("mlp_ffn"):
+        with self.conditional_named_scope("mlp_ffn"):
          ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-        with jax.named_scope("mlp_residual"):
+        with self.conditional_named_scope("mlp_residual"):
          hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
              hidden_states.dtype
          )
@@ -432,11 +449,13 @@ def __init__(
       names_which_can_be_offloaded: list = [],
       mask_padding_tokens: bool = True,
       scan_layers: bool = True,
+      enable_jax_named_scopes: bool = False,
   ):
     inner_dim = num_attention_heads * attention_head_dim
     out_channels = out_channels or in_channels
     self.num_layers = num_layers
     self.scan_layers = scan_layers
+    self.enable_jax_named_scopes = enable_jax_named_scopes

     # 1. Patch & position embedding
     self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -488,6 +507,7 @@ def init_block(rngs):
         attention=attention,
         dropout=dropout,
         mask_padding_tokens=mask_padding_tokens,
+        enable_jax_named_scopes=enable_jax_named_scopes,
       )

     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -513,6 +533,7 @@ def init_block(rngs):
           weights_dtype=weights_dtype,
           precision=precision,
           attention=attention,
+          enable_jax_named_scopes=enable_jax_named_scopes,
         )
         blocks.append(block)
       self.blocks = blocks
@@ -533,6 +554,10 @@ def init_block(rngs):
         kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
     )

+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       hidden_states: jax.Array,
@@ -552,12 +577,12 @@ def __call__(
     post_patch_width = width // p_w

     hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
-    with jax.named_scope("rotary_embedding"):
+    with self.conditional_named_scope("rotary_embedding"):
       rotary_emb = self.rope(hidden_states)
-    with jax.named_scope("patch_embedding"):
+    with self.conditional_named_scope("patch_embedding"):
       hidden_states = self.patch_embedding(hidden_states)
       hidden_states = jax.lax.collapse(hidden_states, 1, -1)
-    with jax.named_scope("condition_embedder"):
+    with self.conditional_named_scope("condition_embedder"):
       temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
           timestep, encoder_hidden_states, encoder_hidden_states_image
       )
@@ -611,4 +636,4 @@ def layer_forward(hidden_states):
     hidden_states = jax.lax.collapse(hidden_states, 6, None)
     hidden_states = jax.lax.collapse(hidden_states, 4, 6)
     hidden_states = jax.lax.collapse(hidden_states, 2, 4)
-    return hidden_states
+    return hidden_states
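The scopes added above are meant to be read back from a profiler trace. A sketch of how they would typically be consumed (the step function and trace directory are placeholders; only the enable_jax_named_scopes flag comes from this commit):

    import jax
    import jax.numpy as jnp


    @jax.jit
    def transformer_like_step(x):
      # Mirrors the nesting added in the transformer block's __call__.
      with jax.named_scope("transformer_block"):
        with jax.named_scope("self_attn"):
          x = x @ x.T
        with jax.named_scope("mlp"):
          x = jax.nn.gelu(x)
      return x


    x = jnp.ones((128, 128))
    with jax.profiler.trace("/tmp/wan-named-scope-trace"):
      transformer_like_step(x).block_until_ready()
    # With enable_jax_named_scopes: True, op names in the trace viewer
    # (TensorBoard / Perfetto) should appear nested under transformer_block/self_attn, etc.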

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 0 deletions
@@ -114,6 +114,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
     wan_config["dropout"] = config.dropout
     wan_config["mask_padding_tokens"] = config.mask_padding_tokens
     wan_config["scan_layers"] = config.scan_layers
+    wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes

     # 2. eval_shape - will not use flops or create weights on device
     # thus not using HBM memory.
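The pipeline change is a plain constructor pass-through: the yml key lands in wan_config and is then forwarded module by module, roughly as in this simplified sketch (class names shortened; not the actual MaxDiffusion classes):

    class Attention:
      def __init__(self, enable_jax_named_scopes: bool = False):
        self.enable_jax_named_scopes = enable_jax_named_scopes


    class Block:
      def __init__(self, enable_jax_named_scopes: bool = False):
        self.attn = Attention(enable_jax_named_scopes=enable_jax_named_scopes)


    class Model:
      def __init__(self, enable_jax_named_scopes: bool = False):
        self.block = Block(enable_jax_named_scopes=enable_jax_named_scopes)


    wan_config = {"enable_jax_named_scopes": True}
    model = Model(**wan_config)
    assert model.block.attn.enable_jax_named_scopes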

0 commit comments
