
How to clone a state in flax.nnx? #4467

Closed. Answered by cgarciae
JINKEHE asked this question in Q&A

Hey @JINKEHE

  1. Since `State` is a pytree, you can clone it with `jax.tree.map`. The identity map builds a new `State` with new containers. The leaf arrays are shared, which is safe because JAX arrays are immutable:
     ```python
     import jax

     state = jax.tree.map(lambda x: x, state)  # clone
     ```
  2. `split` and `merge` don't use Python's `deepcopy`. Instead, they define a complete traversal of the object graph for all NNX objects and JAX pytrees. Think of `split` as `jax.tree.flatten` and `merge` as `jax.tree.unflatten`. In fact, NNX also provides `nnx.graph.flatten` and `nnx.graph.unflatten`, which `split` and `merge` use under the hood.

Answer selected by JINKEHE