Skip to content

Commit 496875d

Browse files
author
eltsai
committed
Added named scope for WanModel output
1 parent f28ef83 commit 496875d

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -625,9 +625,10 @@ def layer_forward(hidden_states):
625625
hidden_states = rematted_layer_forward(hidden_states)
626626

627627
shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
628-
629-
hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
630-
hidden_states = self.proj_out(hidden_states)
628+
with self.conditional_named_scope("output_norm"):
629+
hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
630+
with self.conditional_named_scope("output_proj"):
631+
hidden_states = self.proj_out(hidden_states)
631632

632633
hidden_states = hidden_states.reshape(
633634
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1

0 commit comments

Comments
 (0)