RNN easy to get wrong? (with MWE) #4517
Unanswered · JoaoAparicio asked this question in General
Problem: if I pass an `LSTMCell` an input with 2 dimensions, it trivially broadcasts over the outer dimension (so both `c` and `h` have 2 dimensions). I expected `RNN` to do the same, and it can, but it was really easy to get it wrong without noticing.

Example: the input has 2 dimensions. The inner (feature) dimension has size 4 and the outer (broadcast) dimension has size 2. The hidden dimension has size 7. In a scenario where the LSTM broadcasts trivially, I would expect the hidden state to have shape `(2, 7)`, and that's the case.

So far everything is as expected. Importantly, note that the second position of the broadcast dimension is all zeros; that's a sign the broadcast is happening correctly.
Now let's see `RNN`. Take the same input and repeat it 3 times along a new time dimension, so the shape is `(time_steps, lstm_broadcast, features)`, and run it through `RNN`.
Now the presence of the broadcast dimension mixes the features. Furthermore, the shape of `c` is wrong: it's `(3, 7)` but it should be `(2, 7)`. Something internally isn't aware that we would like the second dimension to be an LSTM broadcast.

Ah, but `RNN` has a `time_major=True` option, so let's try that. Now the LSTM dimensions aren't mixed and the internal state has the right shape. Great.
Now, what if we want to make this work with batch dimensions, i.e. `(batch, time, lstm_broadcast, features)`? The problem is that time is no longer the outer dimension. I did manage to make this work by stating that time is the outer dimension and then wrapping in `vmap` for the batch dimension. However, it would be much nicer to have a solution like #4507 (comment) without having to manually `vmap`. Such a solution does "just work" in the situation where the LSTM takes only 1 dimension.
One possibility could be that instead of passing `time_major: bool` we could just pass `time_axis: int`. Currently this is inferred internally: https://github.com/google/flax/blob/881685c75ebbfb44bf15c0161b11b4eedcfc455d/flax/nnx/nn/recurrent.py#L649C5-L649C14