How to use (jax) pytrees inside of nnx modules? #4497

Open
PhilipVinc opened this issue Jan 22, 2025 · 5 comments

Comments

@PhilipVinc
Contributor

PhilipVinc commented Jan 22, 2025

I have some objects that are JAX pytrees and would like to store them inside an nnx module. More generally, I would like an easy way to tag them (or better, the arrays they contain) as Params or non-trainable Variables.

However, this does not work out of the box: nnx.split raises ValueError: Arrays leaves are not supported, at 'pytree/0': 2.0 (see the MWE below).

Is there a way to support this? Is there an easy way to wrap/unwrap all fields of a pytree into Params or Variables?

import jax
import jax.numpy as jnp
import flax.nnx as nnx

# Define a JAX Pytree class
@jax.tree_util.register_pytree_node_class
class SimplePytree:
    def __init__(self, value: float):
        self.value = value

    def __mul__(self, other):
        return other * self.value

    # Register this class as a pytree
    def tree_flatten(self):
        return ([self.value], None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

# Define the Flax nnx module
class SimpleModule(nnx.Module):
    pytree: SimplePytree

    def __init__(
        self, N, pt, rngs: nnx.Rngs, visible_bias: bool = True, param_dtype=complex
    ):
        self.linear = nnx.Linear(N, 1, param_dtype=param_dtype, rngs=rngs)
        self.pytree = pt

    def __call__(self, x):
        return self.linear(self.pytree * x)

# Instantiate the pytree and module
pytree = SimplePytree(jax.numpy.array(2.0))

net = SimpleModule(2, pytree, nnx.Rngs(0), visible_bias=True, param_dtype=complex)

x = jnp.ones((10, 2))

nnx.split(net)

raises

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/flax/nnx/graph.py:1290, in split(node, *filters)
   1219 def split(
   1220   node: A, *filters: filterlib.Filter
   1221 ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
   1222   """Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is
   1223   a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
   1224   contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
   (...)
   1288     filters are passed, a single ``State`` is returned.
   1289   """
-> 1290   graphdef, state = flatten(node)
   1291   states = _split_state(state, filters)
   1292   return graphdef, *states

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/flax/nnx/graph.py:404, in flatten(node, ref_index)
    402   ref_index = RefMap()
    403 flat_state: dict[PathParts, StateLeaf] = {}
--> 404 graphdef = _graph_flatten((), ref_index, flat_state, node)
    405 return graphdef, GraphState.from_flat_path(flat_state)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/flax/nnx/graph.py:436, in _graph_flatten(path, ref_index, flat_state, node)
    434 for key, value in values:
    435   if is_node(value):
--> 436     nodedef = _graph_flatten((*path, key), ref_index, flat_state, value)
    437     subgraphs.append((key, nodedef))
    438   elif isinstance(value, Variable):

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/flax/nnx/graph.py:451, in _graph_flatten(path, ref_index, flat_state, node)
    449     if isinstance(value, (jax.Array, np.ndarray)):
    450       path_str = '/'.join(map(str, (*path, key)))
--> 451       raise ValueError(
    452           f'Arrays leaves are not supported, at {path_str!r}: {value}'
    453       )
    454     static_fields.append((key, value))
    456 nodedef = NodeDef.create(
    457   type=node_impl.type,
    458   index=index,
   (...)
    464   index_mapping=None,
    465 )

ValueError: Arrays leaves are not supported, at 'pytree/0': 2.0
@PhilipVinc
Contributor Author

PhilipVinc commented Jan 22, 2025

In other words, is this the recommended way?

class SimpleModule(nnx.Module):
    pytree: SimplePytree

    def __init__(
        self, N, pt, rngs: nnx.Rngs, visible_bias: bool = True, param_dtype=complex
    ):
        self.linear = nnx.Linear(N, 1, param_dtype=param_dtype, rngs=rngs)
        self.pytree = jax.tree.map(nnx.Variable, pt)

    def __call__(self, x):
        # Unwrap every Variable leaf back into the array it holds
        pt = jax.tree.map(lambda v: v.value, self.pytree)
        return self.linear(pt * x)

or is there some better approach?

I'm not sure I love this because it breaks SimpleModule.pytree which won't work anymore by default now.

@cgarciae
Collaborator

Hi @PhilipVinc.

> I'm not sure I love this because it breaks SimpleModule.pytree which won't work anymore by default now.

Can you clarify what you mean by this? Variable overloads all operators and implements __jax_array__ so wrapping Arrays tends to work. Can we do something to make your use case work?

@PhilipVinc
Contributor Author

Well, it breaks any isinstance check for example?

@cgarciae
Collaborator

I see. That is probably a case we don't want to support.

@mishavanbeek

Is there anything wrong with just replacing self.pytree = jax.tree.map(nnx.Variable, pt) by self.pytree = nnx.Variable(pt)?

On a related note, I find that the module can no longer be printed once the pytree's __init__ casts its input with jnp.asarray (basic zero-copy casting from ArrayLike to Array). The following modification of the code results in an error.

import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jax.typing import ArrayLike  # needed for the annotation in __init__

# Define a JAX Pytree class
@jax.tree_util.register_pytree_node_class
class SimplePytree:
    def __init__(self, value: ArrayLike):
        self.value = jnp.asarray(value)  # we do casting here from ArrayLike to Array

    def __mul__(self, other):
        return other * self.value

    # Register this class as a pytree
    def tree_flatten(self):
        return ([self.value], None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

# Define the Flax nnx module
class SimpleModule(nnx.Module):
    pytree: SimplePytree

    def __init__(
        self, N, pt, rngs: nnx.Rngs, visible_bias: bool = True, param_dtype=complex
    ):
        self.linear = nnx.Linear(N, 1, param_dtype=param_dtype, rngs=rngs)
        self.pytree = nnx.Variable(pt)  # same for the original version

    def __call__(self, x):
        return self.linear(self.pytree * x)

# Instantiate the pytree and module
pytree = SimplePytree(jax.numpy.array([2.0, 2.0]))  # print fails when array is larger

net = SimpleModule(2, pytree, nnx.Rngs(0), visible_bias=True, param_dtype=complex)
net  # printing results in an error

raises

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
~/Work/projects/bayesline/code/lib/models/.venv/lib/python3.11/site-packages/flax/nnx/reprlib.py in ?(self)
    164     REPR_CONTEXT.current_color = NO_COLOR
    165     try:
    166       return get_repr(self)
    167     finally:
--> 168       REPR_CONTEXT.current_color = current_color

~/Work/projects/bayesline/code/lib/models/.venv/lib/python3.11/site-packages/flax/nnx/reprlib.py in ?(obj)
    194     indent = '' if config.same_line else config.indent
    195 
    196     return f'{indent}{elem.start}{key}{kv_sep}{value_repr}{elem.end}'
    197 
--> 198   elems = config.elem_sep.join(map(_repr_elem, iterator))
    199 
    200   if elems:
    201     if config.same_line:

~/Work/projects/bayesline/code/lib/models/.venv/lib/python3.11/site-packages/flax/nnx/object.py in ?(self)
    245     finally:
    246       if clear_seen:
    247         OBJECT_CONTEXT.seen_modules_repr = None
    248       if clear_node_stats:
--> 249         OBJECT_CONTEXT.node_stats = None

~/Work/projects/bayesline/code/lib/models/.venv/lib/python3.11/site-packages/jax/_src/tree.py in ?(f, tree, is_leaf, *rest)
    151   See Also:
    152     - :func:`jax.tree.leaves`
    153     - :func:`jax.tree.reduce`
    154   """
--> 155   return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)

~/Work/projects/bayesline/code/lib/models/.venv/lib/python3.11/site-packages/jax/_src/tree_util.py in ?(f, tree, is_leaf, *rest)
    354              is_leaf: Callable[[Any], bool] | None = None) -> Any:
    355   """Alias of :func:`jax.tree.map`."""
    356   leaves, treedef = tree_flatten(tree, is_leaf)
    357   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 358   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

~/Work/projects/bayesline/code/lib/models/.venv/lib/python3.11/site-packages/jax/_src/tree_util.py in ?(.0)
--> 358 def tree_map(f: Callable[..., Any],
    359              tree: Any,
    360              *rest: Any,
    361              is_leaf: Callable[[Any], bool] | None = None) -> Any:

~/Work/projects/bayesline/code/lib/models/.venv/lib/python3.11/site-packages/flax/nnx/object.py in ?(value)
    231         def to_shape_dtype(value):
    232           if isinstance(value, Variable):
    233             return value.replace(
--> 234               raw_value=jax.tree.map(to_shape_dtype, value.raw_value)
    235             )
    236           elif (
    237             isinstance(value, (np.ndarray, jax.Array))

~/Work/projects/bayesline/code/lib/models/.venv/lib/python3.11/site-packages/jax/_src/tree.py in ?(f, tree, is_leaf, *rest)
    151   See Also:
    152     - :func:`jax.tree.leaves`
    153     - :func:`jax.tree.reduce`
    154   """
--> 155   return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)

~/Work/projects/bayesline/code/lib/models/.venv/lib/python3.11/site-packages/jax/_src/tree_util.py in ?(f, tree, is_leaf, *rest)
    354              is_leaf: Callable[[Any], bool] | None = None) -> Any:
    355   """Alias of :func:`jax.tree.map`."""
    356   leaves, treedef = tree_flatten(tree, is_leaf)
    357   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 358   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

/var/folders/wv/kzkrkcld6ddd550zpw90hm_40000gn/T/ipykernel_31800/3505462009.py in ?(cls, aux_data, children)
     10     def tree_unflatten(cls, aux_data, children):
     11         # self = cls.__new__(cls)
     12         # self.x, = children  # bypass casting since it breaks flax module printing
     13         # return self
---> 14         return cls(*children)

/var/folders/wv/kzkrkcld6ddd550zpw90hm_40000gn/T/ipykernel_31800/3505462009.py in ?(self, x)
      3     def __init__(self, x: ArrayLike):
----> 4         self.x = jnp.asarray(x)

~/Work/projects/bayesline/code/lib/models/.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py in ?(a, dtype, order, copy, device)
   5728                       "Consider using copy=None or copy=True instead.")
   5729   dtypes.check_user_dtype_supported(dtype, "asarray")
   5730   if dtype is not None:
   5731     dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]
-> 5732   return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)

~/Work/projects/bayesline/code/lib/models/.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py in ?(object, dtype, copy, order, ndmin, device)
   5547     # TODO(jakevdp): falling back to numpy here fails to overflow for lists
   5548     # containing large integers; see discussion in
   5549     # https://github.com/jax-ml/jax/pull/6047. More correct would be to call
   5550     # coerce_to_array on each leaf, but this may have performance implications.
-> 5551     out = np.asarray(object, dtype=dtype)
   5552   elif isinstance(object, Array):
   5553     assert object.aval is not None
   5554     out = _array_copy(object) if copy else object

TypeError: float() argument must be a string or a real number, not 'Array'
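The traceback suggests that the module repr maps a function over the Variable's raw value with jax.tree.map, which rebuilds SimplePytree through tree_unflatten and so re-runs jnp.asarray on leaves that are no longer plain arrays. The commented-out lines in the traceback hint at a workaround: have tree_unflatten bypass __init__ so that no casting or validation happens during unflattening. The sketch below uses only JAX and stands in for the repr step with a jax.ShapeDtypeStruct mapping, which is an assumption about what nnx does internally, not a description of it.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


@jax.tree_util.register_pytree_node_class
class SimplePytree:
    def __init__(self, value: ArrayLike):
        # Casting stays in __init__, which is what user code calls.
        self.value = jnp.asarray(value)

    def __mul__(self, other):
        return other * self.value

    def tree_flatten(self):
        return ([self.value], None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__: leaves may be tracers, placeholders or other
        # non-array objects, so they are stored without any casting.
        obj = object.__new__(cls)
        (obj.value,) = children
        return obj


pt = SimplePytree([2.0, 2.0])

# Map array leaves to non-array placeholders; with `cls(*children)` in
# tree_unflatten this step would call jnp.asarray on the placeholders.
shapes = jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), pt)
```

With this variant, user-facing construction still normalises its input, while JAX transformations and tree utilities can rebuild the object around arbitrary leaves.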
