Skip to content

Commit 71025a8

Browse files
author
eltsai
committed
Fix linter error
1 parent c2e0f15 commit 71025a8

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/maxdiffusion/models/attention_flax.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ def __call__(
850850
dtype = hidden_states.dtype
851851
if encoder_hidden_states is None:
852852
encoder_hidden_states = hidden_states
853-
853+
854854
with jax.named_scope("attn_qkv_proj"):
855855
with jax.named_scope("proj_query"):
856856
query_proj = self.query(hidden_states)
@@ -875,13 +875,13 @@ def __call__(
875875
query_proj = checkpoint_name(query_proj, "query_proj")
876876
key_proj = checkpoint_name(key_proj, "key_proj")
877877
value_proj = checkpoint_name(value_proj, "value_proj")
878-
878+
879879
with jax.named_scope("attn_compute"):
880880
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
881881

882882
attn_output = attn_output.astype(dtype=dtype)
883883
attn_output = checkpoint_name(attn_output, "attn_output")
884-
884+
885885
with jax.named_scope("attn_out_proj"):
886886
hidden_states = self.proj_attn(attn_output)
887887
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def __call__(
538538

539539
hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
540540
rotary_emb = self.rope(hidden_states)
541-
541+
542542
with jax.named_scope("patch_embedding"):
543543
hidden_states = self.patch_embedding(hidden_states)
544544
hidden_states = jax.lax.collapse(hidden_states, 1, -1)

0 commit comments

Comments
 (0)