@eltsai eltsai commented Nov 20, 2025

Added Named Scopes for WAN 2.1 XProf TraceView

Added jax.named_scope for the following components in the WAN 2.1 DiT when the enable_jax_named_scopes flag is set to True (default: False). Run script:

export RUN_NAME=${USR_NAME}-wan21-8tpu-namedscope-enabled
export LIBTPU_VERSION=libtpu-0.0.25.dev20251013+tpu7x-cp312-cp312-manylinux_2_31_x86_64.whl
export YOUR_GCS_BUCKET=gs://${USR_NAME}-wan-maxdiffusion

export OUTPUT_DIR=${YOUR_GCS_BUCKET}/wan/${RUN_NAME}
export DATASET_DIR=${YOUR_GCS_BUCKET}/wan_tfr_dataset_pusa_v1/train/
export EVAL_DATA_DIR=${YOUR_GCS_BUCKET}/wan_tfr_dataset_pusa_v1/eval_timesteps/
export SAVE_DATASET_DIR=${YOUR_GCS_BUCKET}/wan_tfr_dataset_pusa_v1/save/

export RANDOM=123456789
export IMAGE_DIR=gcr.io/tpu-prod-env-multipod/maxdiffusion_jax_stable_stack_nightly:2025-10-27

export HUGGINGFACE_HUB_CACHE=/dev/shm

echo 'Starting WAN inference ...' && \
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  enable_jax_named_scopes=True \
  attention='flash' \
  weights_dtype=bfloat16 \
  activations_dtype=bfloat16 \
  guidance_scale=5.0 \
  flow_shift=3.0 \
  fps=24 \
  skip_jax_distributed_system=True \
  run_name='test-wan-training-new' \
  output_dir=${OUTPUT_DIR} \
  load_tfrecord_cached=True \
  height=720 \
  width=1280 \
  num_frames=81 \
  num_inference_steps=50 \
  prompt='a japanese pop star young woman with black hair is singing with a smile. She is inside a studio with dim lighting and musical instruments.' \
  negative_prompt='low quality, over exposure.' \
  jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
  max_train_steps=20000 \
  enable_profiler=True \
  dataset_save_location=${SAVE_DATASET_DIR} \
  remat_policy='FULL' \
  flash_min_seq_length=0 \
  seed=$RANDOM \
  skip_first_n_steps_for_profiler=3 \
  profiler_steps=3 \
  per_device_batch_size=0.125 \
  allow_split_physical_axes=True \
  ici_data_parallelism=2 \
  ici_fsdp_parallelism=2 \
  ici_tensor_parallelism=2

  • Log with enable_jax_named_scopes=True (XProf)
  • Log with enable_jax_named_scopes=False (XProf)

Below are the added named scopes:

  1. src/maxdiffusion/models/attention_flax.py:
attention_flax
├── attn_qkv_proj
│   ├── proj_query
│   ├── proj_key
│   └── proj_value
├── attn_q_norm 
├── attn_k_norm 
├── attn_rope 
├── attn_compute
└── attn_out_proj
  2. src/maxdiffusion/models/wan/transformers/transformer_wan.py
mlp_block
├── mlp_up_proj_and_gelu
└── mlp_block

transformer_block
├── adaln 
├── self_attn
│   ├── self_attn_norm 
│   ├── self_attn_attn
│   └── self_attn_residual 
├── cross_attn
│   ├── cross_attn_norm
│   ├── cross_attn_attn
│   └── cross_attn_residual
└── mlp
    ├── mlp_norm 
    ├── mlp_ffn
    └── mlp_residual 
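The flag-gated scoping above can be sketched as follows. This is a minimal illustration, not the actual MaxDiffusion code: `maybe_named_scope` and `toy_mlp` are hypothetical helpers, and the scope names are borrowed from the mlp sub-block listed above.

```python
import jax
import jax.numpy as jnp
from contextlib import nullcontext

# Mirrors the enable_jax_named_scopes config flag (default False in this PR).
ENABLE_JAX_NAMED_SCOPES = True


def maybe_named_scope(name):
    # jax.named_scope attaches `name` to every op traced inside the context,
    # so XProf TraceView groups those ops under that label. When the flag is
    # off, return a no-op context so the traced program is unchanged.
    return jax.named_scope(name) if ENABLE_JAX_NAMED_SCOPES else nullcontext()


@jax.jit
def toy_mlp(x):
    # Hypothetical stand-in for one mlp sub-block of the WAN DiT.
    with maybe_named_scope("mlp"):
        with maybe_named_scope("mlp_norm"):
            h = x / jnp.linalg.norm(x)
        with maybe_named_scope("mlp_ffn"):
            h = jnp.tanh(h)
        with maybe_named_scope("mlp_residual"):
            h = h + x
    return h


print(toy_mlp(jnp.ones((4,))).shape)  # (4,)
```

Because the disabled path is a plain `nullcontext()`, the compiled program is byte-for-byte the same as before the change when the flag is off, which is why the flag can default to False.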

This gives us better observability in the XProf TraceView (example XProf):

[screenshot: trace with named scopes]

Compared to the old one (example XProf):

[screenshot: trace without named scopes]

Testing

Function

  • Tested with multiple configurations (on commit hash f5f212f2d15cf55de5b4007b180da00f91a746d2): rows 176 to 184 in the benchmark doc.
  • Tested on the head commit. XProf
  • Fusion integrity: checked that two traces with (DP, FSDP, TP) = (2, 8, 1) produce the same fusions.
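The fusion-integrity check can also be approximated programmatically: since jax.named_scope only attaches metadata, the compiled HLO should contain the same fusion instructions with and without scopes. A minimal sketch under that assumption (`block` and `fusion_count` are hypothetical, and this compares a toy function rather than the WAN model):

```python
import jax
import jax.numpy as jnp
from contextlib import nullcontext


def block(x, *, use_scopes):
    # The scope only changes op metadata, never the computation itself.
    scope = jax.named_scope("mlp_ffn") if use_scopes else nullcontext()
    with scope:
        return jnp.tanh(x * 2.0) + x


def fusion_count(use_scopes):
    # Lower and compile the function, then count fusion instructions
    # in the HLO text dump.
    fn = lambda x: block(x, use_scopes=use_scopes)
    hlo = jax.jit(fn).lower(jnp.ones((8,))).compile().as_text()
    return hlo.count("fusion")


# Named scopes should not change fusion decisions.
print(fusion_count(True) == fusion_count(False))  # True
```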

@eltsai eltsai self-assigned this Nov 20, 2025
@eltsai eltsai added the enhancement New feature or request label Nov 20, 2025
@coolkp previously approved these changes Nov 20, 2025