How to use (jax) pytrees inside of nnx modules? #4497
In other words, is this the recommended way?

```python
import jax
from flax import nnx


class SimpleModule(nnx.Module):
    pytree: SimplePytree  # SimplePytree is the pytree class from the MWE

    def __init__(
        self, N, pt, rngs: nnx.Rngs, visible_bias: bool = True, param_dtype=complex
    ):
        self.linear = nnx.Linear(N, 1, param_dtype=param_dtype, rngs=rngs)
        # Wrap every leaf of the pytree in an nnx.Variable
        self.pytree = jax.tree.map(nnx.Variable, pt)

    def __call__(self, x):
        # Unwrap the Variables back into raw arrays before using the pytree
        pt = jax.tree.map(lambda x: x.value, self.pytree)
        return self.linear(pt * x)
```

Or is there a better approach? I'm not sure I love this because it breaks
Hi @PhilipVinc.
Can you clarify what you mean by this?
Well, it breaks any
I see. That is probably a case we don't want to support.
Is there anything wrong with just replacing

On a related note, I find that the object no longer prints if a basic zero-copy cast (`jnp.asarray`) is done in the constructor per JAX directives. The following modification of the code results in an error:

```python
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jax.typing import ArrayLike


# Define a JAX pytree class
@jax.tree_util.register_pytree_node_class
class SimplePytree:
    def __init__(self, value: ArrayLike):
        self.value = jnp.asarray(value)  # cast from ArrayLike to Array here

    def __mul__(self, other):
        return other * self.value

    # Flatten/unflatten hooks used by register_pytree_node_class
    def tree_flatten(self):
        return ([self.value], None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


# Define the Flax NNX module
class SimpleModule(nnx.Module):
    pytree: SimplePytree

    def __init__(
        self, N, pt, rngs: nnx.Rngs, visible_bias: bool = True, param_dtype=complex
    ):
        self.linear = nnx.Linear(N, 1, param_dtype=param_dtype, rngs=rngs)
        self.pytree = nnx.Variable(pt)  # same for the original version

    def __call__(self, x):
        return self.linear(self.pytree * x)


# Instantiate the pytree and module
pytree = SimplePytree(jnp.array([2.0, 2.0]))  # print fails when array is larger
net = SimpleModule(2, pytree, nnx.Rngs(0), visible_bias=True, param_dtype=complex)
print(net)  # printing raises an error
```
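JAX's pytree documentation notes that `tree_unflatten` may be called with leaves that are not arrays (for example `object()` placeholders), so doing array conversion inside `__init__` can fail whenever a pytree is rebuilt that way. Assuming that is what trips the printer here, a minimal sketch keeps the `jnp.asarray` cast in the user-facing constructor but has `tree_unflatten` bypass `__init__` and store the children unchanged. With the original `cls(*children)` version, the `jax.tree.map(lambda _: object(), pt)` line below would fail inside `jnp.asarray`.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


@jax.tree_util.register_pytree_node_class
class SimplePytree:
    def __init__(self, value: ArrayLike):
        # User-facing constructor: cast ArrayLike -> Array once, here.
        self.value = jnp.asarray(value)

    def __mul__(self, other):
        return other * self.value

    def tree_flatten(self):
        return ([self.value], None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild without calling __init__, so leaves that are not arrays
        # (placeholders, wrapper objects, tracers) are stored as-is.
        obj = object.__new__(cls)
        (obj.value,) = children
        return obj


pt = SimplePytree([2.0, 2.0])
# Mapping every leaf to a non-array placeholder no longer hits jnp.asarray.
placeholders = jax.tree.map(lambda _: object(), pt)
# Ordinary JAX transformations keep working on the pytree.
scaled = jax.jit(lambda p, x: p * x)(pt, jnp.ones(2))
```

The cost of this design is that validation in `__init__` only runs for objects you construct yourself, not for objects JAX rebuilds internally.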
I have some objects which are JAX pytrees, and I would like to store them inside an NNX module. In general, I would like a way to easily tag them (or better, the arrays they contain) as `Param`s or non-trainable `Variable`s. However, this does not seem to work out of the box, as I get the error

`ValueError: Arrays leaves are not supported, at 'pytree/0': 2.0`

(see MWE below). Is there a way to support this? Is there an easy way to wrap/unwrap all fields of a pytree into `Param`s or `Variable`s?