Rotary embeddings not working because of dynamic shapes #3294
Unanswered
faresobeid asked this question in Q&A
-
Hey @faresobeid, take a look at this implementation of rotary embeddings: https://github.com/google/flax/blob/nnx/flax/experimental/nnx/examples/07_transformer.py#L131-L157
-
I've tried to implement the RetNet model proposed in this paper: https://arxiv.org/pdf/2307.08621v4.pdf. Everything has worked well so far except for the XPOS part (the rotary embedding). In my code, as in other implementations, the recurrent form of the model used for generating samples takes the current timestep of the character as an input, and the positional embeddings are computed from an array built from that timestep.
Where i is the current timestep.
And here is the sampling function:
This works without any errors when I sample with a simple for loop, but it is extremely slow. When I switch over to `jax.lax.scan()` or `jax.lax.fori_loop()`, I get this error:

I tried to work around this by restructuring the scan function, but nothing has worked. Is there any way to do this without harming performance drastically? My guess is to leave the XPOS computation out of the jit because it involves dynamic shapes. Surely there's a way to just include a `jnp.arange(0, seq_len)` in my computation where `seq_len` changes.
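
For illustration, here is a minimal sketch of two common ways to keep shapes static under `jax.jit`. It is not the code from this post or from the RetNet paper. It shows plain rotary rotation only (no XPOS scaling), and `HEAD_DIM`, `MAX_LEN`, `rotary_angles`, `apply_rotary`, `masked_positions`, `step` and `run` are all hypothetical names chosen for the example. The idea is that in the recurrent form the angles for one timestep depend only on the scalar position, so no position-sized array is needed, and where a full range is required it can be padded to a fixed upper bound and masked.

```python
import jax
import jax.numpy as jnp

HEAD_DIM = 8   # hypothetical per-head feature size (must be even)
MAX_LEN = 16   # hypothetical static upper bound on the sequence length


def rotary_angles(pos, head_dim=HEAD_DIM, base=10000.0):
    """Angles for ONE position. `pos` may be a traced scalar; the output shape is static."""
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    return pos * inv_freq  # shape (head_dim // 2,), whatever the value of pos


def apply_rotary(x, pos):
    """Rotate consecutive feature pairs of x, shape (head_dim,), by the angles for pos."""
    theta = rotary_angles(pos, x.shape[-1])
    x1, x2 = x[0::2], x[1::2]
    cos, sin = jnp.cos(theta), jnp.sin(theta)
    rotated = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return rotated.reshape(x.shape)


def masked_positions(seq_len):
    """Static-shape stand-in for jnp.arange(0, seq_len): fixed length plus a validity mask."""
    idx = jnp.arange(MAX_LEN)
    return idx, idx < seq_len


def step(carry, pos):
    # Stand-in for one recurrent decoding step; a real model would update its state here.
    x = carry
    return x, apply_rotary(x, pos)


@jax.jit
def run(x0):
    positions = jnp.arange(MAX_LEN)  # length is a Python constant, so tracing succeeds
    _, ys = jax.lax.scan(step, x0, positions)
    return ys
```

Calling `run(jnp.ones(HEAD_DIM))` returns an array of shape `(MAX_LEN, HEAD_DIM)` with one rotated vector per timestep. The design choice is that only values, never shapes, depend on the timestep, which is what `jax.lax.scan` and `jax.lax.fori_loop` require of their loop bodies.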