RNN easy to get wrong? (with MWE) #4517
Unanswered
JoaoAparicio asked this question in General
Problem: if I pass an `LSTMCell` an input with 2 dimensions, it trivially broadcasts over the outer dimension (and so both `c` and `h` have 2 dimensions). I expected `RNN` to do the same, and it can, but it was really easy to get it wrong without noticing.

Example: the input has 2 dimensions. The inner (feature) dimension is size 4 and the outer (broadcast) dimension is size 2. The hidden dimension is size 7. In a scenario where the LSTM broadcasts trivially, I would expect the hidden state to have shape `(2, 7)`, and that's the case. So far everything is as expected. Importantly, note that the second position of the broadcast dimension is all zeros; that's a sign the broadcast is happening correctly.
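Something along these lines (a minimal sketch rather than the original MWE; it assumes the `flax.nnx` `LSTMCell` / `initialize_carry` signatures and makes the second row of the input all zeros so the broadcast check applies):

```python
import jax.numpy as jnp
from flax import nnx

cell = nnx.LSTMCell(in_features=4, hidden_features=7, rngs=nnx.Rngs(0))

# Input: outer (broadcast) dimension of size 2, feature dimension of size 4.
# The second row is all zeros on purpose.
x = jnp.stack([jnp.arange(4.0), jnp.zeros(4)])        # shape (2, 4)

carry = cell.initialize_carry(x.shape, nnx.Rngs(0))   # zero carry, shapes ((2, 7), (2, 7))
(c, h), y = cell(carry, x)

print(c.shape, h.shape)  # (2, 7) (2, 7)
print(c[1], h[1])        # expected to be all zeros: the zero input row stays zero,
                         # i.e. nothing leaks across the broadcast dimension
```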
Now let's look at `RNN`. Take the same input and repeat it 3 times along a time dimension, so the shape is `(time_steps, lstm_broadcast, features)`, and run it (a sketch follows below). Now the presence of the broadcast dimension mixes the features, and furthermore the shape of `c` is wrong: it's `(3, 7)` but should be `(2, 7)`. Something internally isn't aware that we would like the second dimension to be an LSTM broadcast.
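Continuing from the snippet above, roughly (assuming `nnx.RNN`'s constructor accepts `return_carry` the way Linen's `RNN` does):

```python
# Same features, now with a leading time dimension:
# (time_steps, lstm_broadcast, features) = (3, 2, 4).
rnn = nnx.RNN(
    nnx.LSTMCell(in_features=4, hidden_features=7, rngs=nnx.Rngs(0)),
    return_carry=True,
)

xt = jnp.repeat(x[None, ...], 3, axis=0)   # (3, 2, 4)
(c, h), y = rnn(xt)

print(c.shape)  # (3, 7) here, but I would expect (2, 7)
print(h)        # the two broadcast slices get mixed by the scan
```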
Ah, but `RNN` has a `time_major=True` option, so let's try that (sketch below): now the LSTM dimensions aren't mixed and the internal state has the right shape. Great.
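Again continuing from the previous snippet:

```python
# With time_major=True, time is axis 0 and the remaining (2, 4) axes are
# treated as broadcast/features, just like with the bare LSTMCell.
rnn_tm = nnx.RNN(
    nnx.LSTMCell(in_features=4, hidden_features=7, rngs=nnx.Rngs(0)),
    time_major=True,
    return_carry=True,
)

(c, h), y = rnn_tm(xt)

print(c.shape)  # (2, 7), as expected
print(h[1])     # all zeros again: no mixing across the broadcast dimension
```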
Now what if we want to make this work with batch dimensions, i.e. `(batch, time, lstm_broadcast, features)`? The problem now is that time is no longer the outer dimension. I did manage to make this work by stating that time is the outer dimension, but then wrapping in `vmap` for the batch dimension (rough sketch below). However, it would be much nicer to have a solution like #4507 (comment) without having to manually vmap. Such a solution does "just work" in the situation where the LSTM takes only 1 dimension.
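The workaround I mean is roughly the following (not the code I actually ran; it uses the `nnx.split` / `nnx.merge` functional API so the module can go through `jax.vmap`):

```python
import jax

# (batch, time, lstm_broadcast, features) = (5, 3, 2, 4)
xb = jnp.repeat(xt[None, ...], 5, axis=0)

graphdef, state = nnx.split(rnn_tm)

def run_one(state, x_single):
    # x_single: (time, lstm_broadcast, features); time is the outer axis here,
    # so the time_major=True RNN handles it exactly as before.
    model = nnx.merge(graphdef, state)
    return model(x_single)

# vmap only over the batch axis; the model state is broadcast.
(c, h), y = jax.vmap(run_one, in_axes=(None, 0))(state, xb)
```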
One possibility could be that instead of passing `time_major: bool` we could just pass `time_axis: int`. Currently this is inferred internally: https://github.com/google/flax/blob/881685c75ebbfb44bf15c0161b11b4eedcfc455d/flax/nnx/nn/recurrent.py#L649C5-L649C14
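For illustration only, a hypothetical `time_axis` argument (not an existing `nnx.RNN` parameter) would let the batched layout be expressed directly:

```python
# Hypothetical API sketch -- time_axis does NOT exist today.
rnn = nnx.RNN(
    nnx.LSTMCell(in_features=4, hidden_features=7, rngs=nnx.Rngs(0)),
    time_axis=1,   # input laid out as (batch, time, lstm_broadcast, features)
)
y = rnn(xb)
```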