
Commit f28ef83

Author: eltsai (committed)

Make named_scope in flash_attention respect enable_jax_named_scopes flag

1 parent: c31048f

File tree: 1 file changed (+4, −4)


src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 4 deletions
@@ -973,12 +973,12 @@ def __call__(
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
 
-    with jax.named_scope("attn_qkv_proj"):
-      with jax.named_scope("proj_query"):
+    with self.conditional_named_scope("attn_qkv_proj"):
+      with self.conditional_named_scope("proj_query"):
         query_proj = self.query(hidden_states)
-      with jax.named_scope("proj_key"):
+      with self.conditional_named_scope("proj_key"):
         key_proj = self.key(encoder_hidden_states)
-      with jax.named_scope("proj_value"):
+      with self.conditional_named_scope("proj_value"):
         value_proj = self.value(encoder_hidden_states)
 
     if self.qk_norm:
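
Note: the hunk only shows the call sites; the conditional_named_scope helper itself is not part of this diff. A minimal sketch of how such a helper could gate jax.named_scope on the enable_jax_named_scopes flag follows (the class name, constructor, and flag wiring here are assumptions for illustration, not the module's actual definition):

import contextlib

import jax


class AttentionScopesSketch:
  """Illustrative stand-in for the attention module in attention_flax.py."""

  def __init__(self, enable_jax_named_scopes: bool = False):
    # Assumed to be plumbed in from the config's enable_jax_named_scopes flag.
    self.enable_jax_named_scopes = enable_jax_named_scopes

  def conditional_named_scope(self, name: str):
    # When the flag is on, return jax.named_scope(name) so the label shows up
    # in profiler traces and lowered HLO; otherwise return a no-op context
    # manager so the traced computation is untouched.
    if self.enable_jax_named_scopes:
      return jax.named_scope(name)
    return contextlib.nullcontext()

With a helper like this, call sites such as the ones in the hunk keep the same `with self.conditional_named_scope("proj_query"):` shape, and the named scopes disappear entirely when the flag is off.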
