How to use ravel_pytree in Flax/JAX? #1214
marcvanzee asked this question in Q&A
Original question by @dulacarnold: I am trying a simple ravel and unravel, and I'm losing the contents:

    import jax.flatten_util
    import jax.numpy as jnp

    pt = {'foo': jnp.array([1,2,3]), 'bar': {'baz': jnp.array([5,6]), 'bork': jnp.array([7,8])}}
    flat, unflatten = jax.flatten_util.ravel_pytree(pt)
    unflatten(flat)

and I get...
Answered by marcvanzee on Apr 8, 2021
You have to use this function with floating-point arrays. With integer arrays, the round trip loses the contents, as in the question.
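As a minimal sketch of that advice, the snippet below takes the integer pytree from the question, casts every leaf to a float dtype, and then does the ravel/unravel round trip. The cast through `jax.tree_util.tree_map` and the choice of `float32` are assumptions made for illustration; they are not part of the original answer.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# The integer pytree from the question.
pt = {'foo': jnp.array([1, 2, 3]),
      'bar': {'baz': jnp.array([5, 6]), 'bork': jnp.array([7, 8])}}

# Cast every leaf to float before raveling (float32 is an assumed choice).
pt_float = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), pt)

# flat is a single 1-D array holding all leaves; unflatten maps it back
# to a pytree with the same structure as pt_float.
flat, unflatten = ravel_pytree(pt_float)
restored = unflatten(flat)

print(flat)
print(restored)
```

If the integer dtype matters downstream, the restored leaves can be cast back with another `tree_map` after unraveling.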
Answer selected by marcvanzee