We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c31048f — commit f28ef83 (full SHA: f28ef83)
src/maxdiffusion/models/attention_flax.py
@@ -973,12 +973,12 @@ def __call__(
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states

-        with jax.named_scope("attn_qkv_proj"):
-            with jax.named_scope("proj_query"):
+        with self.conditional_named_scope("attn_qkv_proj"):
+            with self.conditional_named_scope("proj_query"):
                 query_proj = self.query(hidden_states)
-            with jax.named_scope("proj_key"):
+            with self.conditional_named_scope("proj_key"):
                 key_proj = self.key(encoder_hidden_states)
-            with jax.named_scope("proj_value"):
+            with self.conditional_named_scope("proj_value"):
                 value_proj = self.value(encoder_hidden_states)

         if self.qk_norm:
0 commit comments