I was trying to gain some speed-ups with JIT for autoregressive decoding:

```python
import jax
from flax.nnx import MultiHeadAttention, Rngs, jit

rngs = Rngs(0)
attn = MultiHeadAttention(8, 512, rngs=rngs)
jited_attn = jit(attn)
attn.init_cache((1, 100, 512))

for i in range(1, 100):
    test_input = jax.random.uniform(jax.random.PRNGKey(0), (1, 1, 512))
    resp = jited_attn(test_input, decode=True)
    print("Itering {i}".format(i=i))
```

This fails with an error:
Autoregressive decoding without JIT works well:

```python
import jax
from flax.nnx import MultiHeadAttention, Rngs

rngs = Rngs(0)
attn = MultiHeadAttention(8, 512, rngs=rngs)
attn.init_cache((1, 100, 512))

for i in range(1, 100):
    test_input = jax.random.uniform(jax.random.PRNGKey(0), (1, 1, 512))
    resp = attn(test_input, decode=True)
    print("Itering {i}".format(i=i))
```

In my understanding, a JAX-jitted function should be a pure function; however, keeping a KV-cache inside the module makes it stateful and therefore impure. I wonder if this architecture was intentional or if I'm missing something.
Hey @Neulus, because of how instance methods work in Python you should not transform them directly, since `self` will be passed as a capture. Instead, transform a regular function that takes the Module as an explicit input. Also, you can use `set_attributes` to recursively set Module properties like `decode`. Here is a working example:

```python
import jax
from flax import nnx

rngs = nnx.Rngs(0)
attn = nnx.MultiHeadAttention(8, 512, rngs=rngs)
attn.init_cache((1, 100, 512))
attn.set_attributes(decode=True)

@nnx.jit
def forward(attn, inputs):
    return attn(inputs)

for i in range(4):
    test_input = jax.random.uniform(jax.random.key(i), (1, 1, 512))
    resp = forward(attn, test_input)
    print(attn.cached_key.value[0, :4, 0, :5])  # watch the cache grow
```
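For intuition on why this works: `nnx.jit` manages the Module's state for you, so the cache update made inside `forward` is propagated back to `attn` even though the compiled function itself stays pure. Conceptually it is close to splitting the Module into a static graph definition and a state pytree and threading that state explicitly. A rough manual sketch of that idea using the functional `nnx.split` / `nnx.merge` API (not the actual `nnx.jit` internals):

```python
import jax
from flax import nnx

rngs = nnx.Rngs(0)
attn = nnx.MultiHeadAttention(8, 512, rngs=rngs)
attn.init_cache((1, 100, 512))
attn.set_attributes(decode=True)

# Split the Module into a static graph definition and a pytree of state
# (parameters + cache), so a plain jax.jit function can remain pure.
graphdef, state = nnx.split(attn)

@jax.jit
def forward(state, inputs):
    module = nnx.merge(graphdef, state)  # rebuild the Module inside the trace
    out = module(inputs)                 # this call updates the cache on `module`
    _, new_state = nnx.split(module)     # pull the updated cache back out explicitly
    return out, new_state

x = jax.random.uniform(jax.random.key(0), (1, 1, 512))
out, state = forward(state, x)  # the grown cache comes back in `state`
```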