Question about RNG state in NNX checkpointing with Orbax #4514
conorhassan asked this question in Q&A
-
Hey, this is a known Orbax issue; see google/orbax#1105.
-
When saving NNX models with Orbax checkpointing, certain models containing RNG state (from `Dropout`/`MultiHeadAttention`) failed with:
A simple MLP with dropout works fine with standard checkpointing:
However, more complex models like this transformer implementation fail:
The full implementation of the failing model is available here: https://github.com/conorhassan/tnp/blob/main/tnp/models/nnx_models/layers.py
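To see which parts of a model's state are RNG keys, here is a small sketch using only public JAX APIs. It walks a pytree and reports the leaves whose dtype is a typed PRNG key. The `state` dict is a hypothetical stand-in for what `nnx.split(model)` might return for a model with dropout, not the model linked above. It also assumes NNX stores its RNG streams as typed keys created with `jax.random.key`, alongside a `uint32` counter.

```python
import jax
import jax.numpy as jnp


def find_prng_key_leaves(tree):
    """Return the key paths of all leaves whose dtype is a JAX typed PRNG key."""
    paths = []
    for path, leaf in jax.tree_util.tree_leaves_with_path(tree):
        dtype = getattr(leaf, "dtype", None)
        if dtype is not None and jax.dtypes.issubdtype(dtype, jax.dtypes.prng_key):
            paths.append(jax.tree_util.keystr(path))
    return paths


# Hypothetical stand-in for the state of a model with dropout: ordinary float
# parameters plus an RNG stream holding a typed key and a uint32 counter.
state = {
    "linear": {"kernel": jnp.zeros((4, 4)), "bias": jnp.zeros((4,))},
    "dropout": {
        "rngs": {
            "key": jax.random.key(0),
            "count": jnp.array(0, dtype=jnp.uint32),
        }
    },
}

print(find_prng_key_leaves(state))
```

Only the typed key is flagged. The counter is an ordinary integer array, and the float parameters are the kind of leaf that checkpointed fine in the simple MLP case.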
The solution that I found was
Example of what the `State` that fails looks like:

Questions
Thanks for any help or thoughts!