Hello, I have two questions about how to clone things in …

But why does this implement a deepcopy? Is it in …

Thanks in advance for your help!
Hey @JINKEHE

```python
import jax

state = jax.tree.map(lambda x: x, state)  # clone
```
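Here is a minimal sketch of what the `jax.tree.map` clone above does, using plain JAX. It assumes a nested dict as a stand-in for an NNX `State` (both are pytrees). The names `state` and `clone` and the `"linear"`/`"kernel"`/`"bias"` keys are illustrative only. Mapping the identity function rebuilds the container nodes, while the leaves stay the same immutable JAX arrays.

```python
import jax
import jax.numpy as jnp

# Assumption: a nested dict stands in for an nnx.State; both are pytrees,
# so jax.tree.map traverses them the same way.
state = {"linear": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))}}

# Mapping the identity function rebuilds every container node of the pytree.
clone = jax.tree.map(lambda x: x, state)

# The containers are new Python objects...
assert clone is not state
assert clone["linear"] is not state["linear"]

# ...but the leaves are the very same (immutable) jax.Array objects.
assert clone["linear"]["kernel"] is state["linear"]["kernel"]

# Replacing a leaf in the clone does not touch the original.
clone["linear"]["bias"] = clone["linear"]["bias"] + 1.0
assert bool(jnp.array_equal(state["linear"]["bias"], jnp.zeros((3,))))
assert bool(jnp.array_equal(clone["linear"]["bias"], jnp.ones((3,))))
```

Because JAX arrays are immutable, sharing leaves between the original and the clone is safe. "Updating" a value in the clone means swapping in a new array, which leaves the original untouched.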
Thank you!
`State` is a Pytree, so you can clone it by mapping the identity function over it with `jax.tree.map`. `split` and `merge` don't rely on `deepcopy`, but they do define a complete traversal of the object graph for all NNX objects and JAX pytrees. Think of `split` as `jax.tree.flatten` and `merge` as `jax.tree.unflatten`. In fact, NNX also has `nnx.graph.flatten` and `nnx.graph.unflatten`, which are used by `split` and `merge` under the hood.
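The flatten/unflatten analogy can be sketched with plain JAX pytrees. This is only an analogy and not NNX's implementation. `ToyLinear` is a hypothetical class registered as a pytree node to play the role of a module, and `jax.tree.flatten` / `jax.tree.unflatten` stand in for `split` / `merge`. The sketch shows why a complete traversal behaves like a copy: unflattening reconstructs every node from the static structure plus the leaves.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


# Hypothetical stand-in for an NNX module: a tiny layer registered as a
# pytree node so that jax.tree.flatten/unflatten can traverse it.
@register_pytree_node_class
class ToyLinear:
    def __init__(self, kernel, bias):
        self.kernel = kernel
        self.bias = bias

    def tree_flatten(self):
        # (children, aux_data): the arrays are children, no static metadata.
        return (self.kernel, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Called during unflattening: builds a brand-new instance.
        return cls(*children)


model = {"layer": ToyLinear(jnp.ones((2, 3)), jnp.zeros((3,)))}

# "split"-like step: separate the static structure (treedef) from the leaves.
leaves, treedef = jax.tree.flatten(model)

# "merge"-like step: rebuild the whole object graph from structure + leaves.
rebuilt = jax.tree.unflatten(treedef, leaves)

# Every node is reconstructed, so the result is a fresh object graph...
assert rebuilt is not model
assert rebuilt["layer"] is not model["layer"]
# ...that shares the same immutable array leaves.
assert rebuilt["layer"].kernel is model["layer"].kernel
```

NNX's `nnx.graph.flatten` / `nnx.graph.unflatten` operate on NNX's object graph rather than on registered pytree classes like this one, so treat the sketch as an illustration of the idea, not of NNX internals.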