is MultiHeadAttention in flax.nnx JIT-compilable with the decode=True parameter? #4523

Answered by cgarciae
Neulus asked this question in Q&A
Hey @Neulus, because of how instance methods work in Python, you should not transform them directly, as self will be passed as a capture. Instead, transform a regular function that takes the Module as an explicit input. Also, you can use set_attributes to recursively set Module properties like decode. Here is a working example:

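import jax
from flax import nnx

rngs = nnx.Rngs(0)
attn = nnx.MultiHeadAttention(8, 512, rngs=rngs)

# Allocate the key/value cache for up to 100 positions, then switch to decode mode.
attn.init_cache((1, 100, 512))
attn.set_attributes(decode=True)


# Transform a plain function and pass the Module in as an explicit argument.
@nnx.jit
def forward(attn, inputs):
  return attn(inputs)


# Feed one token per step; the cache is updated on every call.
for i in range(4):
  test_input = jax.random.uniform(jax.random.key(i), (1, 1, 512))
  resp = forward(attn, test_input)
  print(attn.cached_key.value[0, :4, 0])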
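
To make the "self will be passed as a capture" point concrete, here is a minimal sketch in plain JAX rather than nnx. The `Scaler` class and the names `captured` and `explicit` are hypothetical and only for illustration. It shows that state reached through a captured `self` is baked in at trace time, while state passed as an explicit argument is picked up on every call. nnx.jit plays the role of the explicit version for Modules, as in the example above.

```python
import jax
import jax.numpy as jnp


class Scaler:
    """Hypothetical toy class for illustration; not part of Flax."""

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x):
        return x * self.scale


scaler = Scaler(1.0)

# Jitting the bound method: `self` is closed over, so `self.scale`
# becomes a constant in the compiled function at trace time.
captured = jax.jit(scaler.__call__)
before = float(captured(jnp.float32(2.0)))

scaler.scale = 10.0  # mutate the state after the first trace
after = float(captured(jnp.float32(2.0)))  # cache hit, still uses scale = 1.0


# Passing the state as an explicit argument makes it a traced input,
# so the current value is used on each call.
@jax.jit
def explicit(scale, x):
    return x * scale


fresh = float(explicit(scaler.scale, jnp.float32(2.0)))

print(before, after, fresh)  # prints: 2.0 2.0 20.0
```

The stale result in `after` is the same kind of problem a decode cache would run into if the Module were only reachable through a captured self: updates made outside (or inside) the traced function are not seen as inputs or outputs of the compiled call.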

Answer selected by Neulus