From 010022369ce0086eb30d34eb6862bb3e174412dd Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 7 Mar 2025 17:46:28 -0800 Subject: [PATCH] [nnx] add support for standalone Variables --- flax/nnx/bridge/wrappers.py | 2 +- flax/nnx/extract.py | 27 ++- flax/nnx/graph.py | 287 +++++++++++++++++++------------ flax/nnx/module.py | 4 +- flax/nnx/statelib.py | 4 + flax/nnx/transforms/autodiff.py | 14 +- flax/nnx/transforms/iteration.py | 6 +- tests/nnx/graph_utils_test.py | 125 ++++++++++++++ tests/nnx/transforms_test.py | 29 ++++ 9 files changed, 369 insertions(+), 129 deletions(-) diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py index 7fd651062..2e074b699 100644 --- a/flax/nnx/bridge/wrappers.py +++ b/flax/nnx/bridge/wrappers.py @@ -49,7 +49,7 @@ def init(self, *, rngs: tp.Optional[Rngs] = None) -> State: graphdef, state = nnx.split(module) assert type(graphdef) is graph.NodeDef self.graphdef = graphdef - return state + return state # type: ignore def apply(self, *states: tp.Any): assert self.graphdef is not None diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index 93aba960f..26e20077d 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -127,8 +127,8 @@ def map_prefix( ) -> tp.Any: ... def check_consistent_aliasing( - node: tuple[tp.Any, ...], - prefix: tuple[tp.Any, ...], + node: tp.Any, + prefix: tp.Any, /, *, node_prefixes: dict[tp.Any, list[tuple[PathParts, tp.Any]]] | None = None, @@ -279,7 +279,9 @@ def to_tree( with graph.split_context(ctxtag) as split_ctx: return jax.tree.map( lambda x: split_fn(split_ctx, (), prefix, x) - if map_non_graph_nodes or graph.is_graph_node(x) + if map_non_graph_nodes + or graph.is_graph_node(x) + or isinstance(x, variablelib.Variable) else x, tree, ) @@ -296,7 +298,7 @@ def to_tree( with graph.split_context(ctxtag) as split_ctx: for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes): - if graph.is_graph_node(leaf): + if graph.is_graph_node(leaf) or isinstance(leaf, variablelib.Variable): if check_aliasing: check_consistent_aliasing( leaf, leaf_prefix, node_prefixes=node_prefixes @@ -343,7 +345,9 @@ def from_tree( with graph.merge_context(is_inner, ctxtag) as merge_ctx: return jax.tree.map( lambda x: merge_fn(merge_ctx, (), prefix, x) - if map_non_graph_nodes or is_node_leaf(x) + if map_non_graph_nodes + or is_node_leaf(x) + or isinstance(x, variablelib.Variable) else x, tree, is_leaf=is_leaf, @@ -362,7 +366,11 @@ def from_tree( with graph.merge_context(is_inner, ctxtag) as merge_ctx: for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes): - if map_non_graph_nodes or is_node_leaf(leaf): + if ( + map_non_graph_nodes + or is_node_leaf(leaf) + or isinstance(leaf, variablelib.Variable) + ): leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf) leaves_out.append(leaf) @@ -370,4 +378,9 @@ def from_tree( return pytree_out def clear_non_graph_nodes(tree): - return jax.tree.map(lambda x: x if graph.is_graph_node(x) else None, tree) \ No newline at end of file + return jax.tree.map( + lambda x: x + if graph.is_graph_node(x) or isinstance(x, variablelib.Variable) + else None, + tree, + ) \ No newline at end of file diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 6fb73ed4b..c2ce79695 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -21,14 +21,14 @@ import threading import typing as tp -from flax.nnx import filterlib, reprlib, variablelib +from flax.nnx import filterlib, reprlib, traversals, variablelib from flax.nnx import statelib from flax.nnx.proxy_caller import ( ApplyCaller, 
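A note on the extract.py changes above: treating bare `Variable`s the same as graph nodes in `to_tree`, `from_tree`, and `clear_non_graph_nodes` is what lets the transform machinery see standalone Variables as shared mutable state. A minimal sketch of the resulting behavior, mirroring `test_jit_variable` and `test_jit_pytree_of_variables` added later in this patch:

```python
from flax import nnx

v = nnx.Param(1)

@nnx.jit
def f(vs):
  # both list entries alias the same Variable; the in-place
  # update is propagated back to the caller's reference
  vs[0] += 1

f([v, v])
assert v.value == 2
```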
CallableProxy, DelayedAccessor, ) -from flax.nnx.statelib import FlatState, State +from flax.nnx.statelib import EmptyState, FlatState, State from flax.nnx.variablelib import Variable, VariableState from flax.typing import Key, PathParts, is_key_like import jax @@ -266,8 +266,8 @@ def __treescope_repr__(self, path, subtree_renderer): @dataclasses.dataclass(frozen=True, repr=False) -class VariableDef(reprlib.Representable): - type: type[Variable] +class VariableDef(reprlib.Representable, tp.Generic[Node]): + type: type[Node] index: int outer_index: int | None metadata: HashableMapping[str, tp.Any] @@ -319,7 +319,8 @@ class NodeDef(tp.Generic[Node], reprlib.Representable): outer_index: int | None attributes: tuple[ tuple[ - Key, NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any] | Static[tp.Any] + Key, + NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any] | Static[tp.Any], ], ..., ] @@ -406,7 +407,7 @@ def _apply( jax.tree_util.register_static(NodeDef) -GraphDef = tp.Union[NodeDef[Node], NodeRef[Node]] +GraphDef = tp.Union[NodeDef[Node], NodeRef[Node], VariableDef[Node]] PureState = tuple[GraphDef[Node], GraphState] @@ -497,7 +498,7 @@ def flatten( path: list[Key] | None = [] if with_paths else None paths: list[PathParts] | None = [] if with_paths else None node_impl = get_node_impl(node) - if node_impl is None: + if node_impl is None and not isinstance(node, Variable): raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') graphdef = _graph_flatten( node, @@ -518,27 +519,51 @@ def flatten( def _graph_flatten( node: Node, - node_impl: NodeImpl[Node, Leaf, AuxData], + node_impl: NodeImpl[Node, Leaf, AuxData] | None, path: list[Key] | None, ref_index: RefMap, ref_outer_index: RefMap | None, leaves: list[StateLeaf | Variable[tp.Any]], paths: list[PathParts] | None, return_variables: bool, -) -> NodeDef[tp.Any] | NodeRef: +) -> NodeDef | NodeRef | VariableDef: is_pytree_node_ = isinstance(node_impl, PytreeNodeImpl) is_graph_node_ = isinstance(node_impl, GraphNodeImpl) + is_variable = isinstance(node, Variable) if not is_pytree_node_ and node in ref_index: return NodeRef(type(node), ref_index[node]) # only cache graph nodes - if is_graph_node_: + if is_graph_node_ or is_variable: index = len(ref_index) ref_index[node] = index else: index = -1 + if is_variable: + assert isinstance(node, Variable) + if return_variables: + leaf = node + elif path is None: + leaf = node.raw_value + else: + leaf = node.to_state() # type: ignore[assignment] + leaves.append(leaf) + if path is not None: + assert paths is not None + paths.append(tuple(path)) + variabledef = VariableDef( + type=type(node), + index=index, + outer_index=ref_outer_index.get(node, None) if ref_outer_index else None, + metadata=HashableMapping(node._var_metadata), + ) + return variabledef + + if node_impl is None: + raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') + attributes: list[ tuple[Key, Static[tp.Any] | NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any]] ] = [] @@ -548,7 +573,7 @@ def _graph_flatten( value_node_impl = get_node_impl(value) if path is not None: path.append(key) - if value_node_impl is not None: + if value_node_impl is not None or isinstance(value, Variable): nodedef = _graph_flatten( value, value_node_impl, @@ -560,30 +585,6 @@ def _graph_flatten( return_variables, ) attributes.append((key, nodedef)) - elif isinstance(value, Variable): - if value in ref_index: - attributes.append((key, NodeRef(type(value), ref_index[value]))) - else: - if return_variables: - leaf = value - elif path is 
None:
-          leaf = value.raw_value
-        else:
-          leaf = value.to_state()  # type: ignore[assignment]
-        leaves.append(leaf)
-        if path is not None:
-          assert paths is not None
-          paths.append(tuple(path))
-        variable_index = ref_index[value] = len(ref_index)
-        variabledef = VariableDef(
-          type=type(value),
-          index=variable_index,
-          outer_index=ref_outer_index.get(value, None)
-          if ref_outer_index
-          else None,
-          metadata=HashableMapping(value._var_metadata),
-        )
-        attributes.append((key, variabledef))
     else:
       if isinstance(value, (jax.Array, np.ndarray)):
         if path is not None:
@@ -867,10 +868,7 @@ def unflatten(
   if isinstance(graphdef, NodeRef):
     node = index_ref[graphdef.index]
   else:
-    assert isinstance(graphdef, NodeDef)
     node_impl = get_node_impl_for_type(graphdef.type)
-    if node_impl is None:
-      raise RuntimeError(f'Unsupported type: {graphdef.type}, this is a bug.')
     node = _graph_unflatten(
       graphdef, node_impl, leaves, index_ref, outer_index_outer_ref
     )
@@ -883,8 +881,8 @@


 def _graph_unflatten(
-  nodedef: NodeDef[Node] | NodeRef[Node],
-  node_impl: NodeImpl[Node, Leaf, AuxData],
+  nodedef: GraphDef[Node],
+  node_impl: NodeImpl[Node, Leaf, AuxData] | None,
   leaves: deque[tp.Any],
   index_ref: dict[Index, tp.Any],
   outer_index_outer_ref: dict[Index, tp.Any] | None,
@@ -904,7 +902,55 @@ def _graph_unflatten(
   """
   if type(nodedef) is NodeRef:
     return index_ref[nodedef.index]
+
+  def make_variable(key, variabledef: VariableDef[Variable]) -> tp.Any:
+    if not leaves:
+      raise ValueError('Not enough leaves to unflatten the graph')
+    # it's an unseen variable, create a new one
+    value = leaves.popleft()
+    # when an outer reference map is present, check if the Variable exists
+    # there and update the existing Variable if it does
+    if (
+      outer_index_outer_ref is not None
+      and variabledef.outer_index in outer_index_outer_ref
+    ):
+      # if the variable exists, update it
+      variable = outer_index_outer_ref[variabledef.outer_index]
+      if not isinstance(variable, Variable):
+        raise ValueError(
+          f'Expected a Variable type for {key!r}, but got {type(variable)}.'
+        )
+      elif isinstance(value, Variable):
+        raise ValueError(
+          f'Cannot unflatten flat_state containing Variables when using `outer_index_outer_ref`. '
+          f'Got {value!r} for {key!r}.'
+        )
+      elif isinstance(value, VariableState):
+        variable.update_from_state(value)
+      else:
+        variable.raw_value = value
+    else:  # variabledef.outer_index not in outer_index_outer_ref
+      # the variable reference does not exist outside, create a new one
+      if isinstance(value, Variable):
+        variable = value
+      elif isinstance(value, VariableState):
+        variable = value.to_variable()
+      else:
+        variable = variabledef.type.from_metadata(
+          value, dict(variabledef.metadata)
+        )
+    index_ref[variabledef.index] = variable
+    return variable
+
+  if type(nodedef) is VariableDef:
+    return make_variable(
+      None,
+      nodedef,  # type: ignore
+    )
+
   assert type(nodedef) is NodeDef
+  if node_impl is None:
+    raise RuntimeError(f'Unsupported type: {nodedef.type}, this is a bug.')
   if nodedef.index in index_ref:
     raise RuntimeError(f'GraphDef index {nodedef.index} already used.')
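The `make_variable` closure above centralizes the logic of the inline branch removed from `_get_children` below, and additionally serves the new root-Variable case (`type(nodedef) is VariableDef`). A sketch of the round trip it enables, mirroring `test_split_variable` from this patch:

```python
from flax import nnx

v = nnx.Param(1)
graphdef, state = nnx.split(v)   # VariableDef + VariableState
assert isinstance(graphdef, nnx.graph.VariableDef)
v2 = nnx.merge(graphdef, state)  # make_variable rebuilds the Param
assert isinstance(v2, nnx.Param) and v2.value == 1
```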
@@ -927,44 +973,8 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
         )
         children.append((key, subnode))
       elif type(value) is VariableDef:
-        variabledef = value
-        if not leaves:
-          raise ValueError('Not enough leaves to unflatten the graph')
-        # its a unseen variable, create a new one
-        value = leaves.popleft()
-        # when idxmap is present, check if the Varable exists there
-        # and update existing variables if it does
-        if (
-          outer_index_outer_ref is not None
-          and variabledef.outer_index in outer_index_outer_ref
-        ):
-          # if variable exists, update it
-          variable = outer_index_outer_ref[variabledef.outer_index]
-          if not isinstance(variable, Variable):
-            raise ValueError(
-              f'Expected a Variable type for {key!r}, but got {type(variable)}.'
-            )
-          elif isinstance(value, Variable):
-            raise ValueError(
-              f'Cannot unflatten flat_state containing Variables when using `outer_index_outer_ref`. '
-              f'Got {value!r} for {key!r}.'
-            )
-          elif isinstance(value, VariableState):
-            variable.update_from_state(value)
-          else:
-            variable.raw_value = value
-        else:  # variabledef.index not in index_ref_cache
-          # variable reference does not exist outside, create a new one
-          if isinstance(value, Variable):
-            variable = value
-          elif isinstance(value, VariableState):
-            variable = value.to_variable()
-          else:
-            variable = variabledef.type.from_metadata(
-              value, dict(variabledef.metadata)
-            )
+        variable = make_variable(key, value)
         children.append((key, variable))
-        index_ref[variabledef.index] = variable
       else:
         raise RuntimeError(f'Unknown static field: {key!r}')
@@ -1070,6 +1080,24 @@ def _graph_pop(


 def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]):
+  def _update_variable(node: Variable, value):
+    if isinstance(value, VariableState):
+      # updated from VariableState
+      node.update_from_state(value)
+    else:
+      # updated from raw value
+      if isinstance(value, State) and not value:
+        # NOTE: this is a special case when trying to update a Variable from state
+        # created when flattening into a NodeRef, which creates an empty State. 
This + # can happen when using standalone Variables with `grad` + pass + else: + node.raw_value = value + + if isinstance(node, Variable): + _update_variable(node, state) + return + if not is_node(node): raise RuntimeError(f'Unsupported type: {type(node)}') @@ -1090,7 +1118,6 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]): node_impl.set_key(node, key, value) continue - # check values are of the same type current_value = node_dict[key] # case 2: subgraph is being updated @@ -1105,12 +1132,7 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]): f'Trying to update a non-Variable attribute {key!r} with a Variable: ' f'{value!r}' ) - if isinstance(value, VariableState): - # updated from VariableState - current_value.update_from_state(value) - else: - # updated from raw value - current_value.raw_value = value + _update_variable(current_value, value) # -------------------------------------------------------- @@ -1134,8 +1156,8 @@ def create( new_ref_index: RefMap, ): new_index_ref = {index: obj for obj, index in new_ref_index.items()} - final_graphdef: NodeDef[tp.Any] | NodeRef[tp.Any] - if type(graphdef) is NodeDef: + final_graphdef: GraphDef[tp.Any] + if type(graphdef) is NodeDef or type(graphdef) is VariableDef: final_graphdef = graphdef.with_same_outer_index() else: final_graphdef = graphdef @@ -1310,9 +1332,7 @@ def split( node, ref_index=self.ref_index, ref_outer_index=inner_ref_outer_index ) flat_states = _split_state(flat_state, filters) - states = tuple( - statelib.from_flat_state(flat_state) for flat_state in flat_states - ) + states = _to_nested_state(graphdef, flat_states) return graphdef, *states @@ -1469,9 +1489,9 @@ class MergeContext: def merge( self, graphdef: GraphDef[A], - state: GraphState, + state: GraphState | VariableState, /, - *states: GraphState, + *states: GraphState | VariableState, ) -> A: ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None @@ -1480,7 +1500,7 @@ def merge( ctx.outer_index_outer_ref if ctx and ctx.outer_index_outer_ref else None ) - _state = statelib.merge_state(state, *states) + _state = _merge_to_flat_state((state, *states)) node = unflatten( graphdef, _state, @@ -1735,10 +1755,8 @@ def split( graphdef, flat_state = flatten( node, ref_index=ref_index, ref_outer_index=self.inner_ref_outer_index ) - states = tuple( - statelib.from_flat_state(flat_state) - for flat_state in _split_state(flat_state, filters) - ) + flat_states = _split_state(flat_state, filters) + states = _to_nested_state(graphdef, flat_states) assert len(states) >= 1 self.flatten_end(ref_index) return graphdef, *states # type: ignore[return-value] @@ -1765,11 +1783,11 @@ def merge( # inner merge (2) index_ref_cache = None - state = statelib.merge_state(state, *states) + _state = _merge_to_flat_state((state, *states)) index_ref: dict[Index, tp.Any] = {} node = unflatten( graphdef, - state, + _state, index_ref=index_ref, outer_index_outer_ref=index_ref_cache, ) @@ -1955,11 +1973,13 @@ def _split_state( @tp.overload -def split(graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... +def split( + graph_node: A, / +) -> tuple[GraphDef[A], GraphState | VariableState]: ... @tp.overload def split( graph_node: A, first: filterlib.Filter, / -) -> tuple[GraphDef[A], GraphState]: ... +) -> tuple[GraphDef[A], GraphState | VariableState]: ... 
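The widened overloads reflect that splitting a standalone Variable yields a bare `VariableState` for the matching filter and empty `State`s for the rest (see `_to_nested_state` below). A sketch mirroring `test_split_filter_variable` from this patch:

```python
from flax import nnx

v = nnx.Param(1)
graphdef, batch_stats, params, rest = nnx.split(
  v, nnx.BatchStat, nnx.Param, ...
)
assert isinstance(params, nnx.VariableState)
assert not batch_stats and not rest  # empty States for non-matching filters
v2 = nnx.merge(graphdef, batch_stats, params, rest)
assert isinstance(v2, nnx.Param)
```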
@tp.overload def split( graph_node: A, @@ -1967,10 +1987,18 @@ def split( second: filterlib.Filter, /, *filters: filterlib.Filter, -) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ... +) -> tuple[ + GraphDef[A], + GraphState | VariableState, + tpe.Unpack[tuple[GraphState | VariableState, ...]], +]: ... def split( node: A, *filters: filterlib.Filter -) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: +) -> tuple[ + GraphDef[A], + GraphState | VariableState, + tpe.Unpack[tuple[GraphState | VariableState, ...]], +]: """Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef contains all the static information needed to reconstruct a ``Module`` graph, it is analogous @@ -2041,15 +2069,43 @@ def split( """ graphdef, flat_state = flatten(node) flat_states = _split_state(flat_state, filters) - states = tuple(statelib.from_flat_state(flat_state) for flat_state in flat_states) + states = _to_nested_state(graphdef, flat_states) return graphdef, *states # type: ignore[return-value] +def _to_nested_state( + graphdef: GraphDef[A], flat_states: tp.Iterable[tp.Any] +) -> tuple[tp.Any, ...]: + if type(graphdef) is VariableDef: + states = tuple( + flat_state[0][1] if flat_state else EmptyState() + for flat_state in flat_states + ) + else: + states = tuple( + statelib.from_flat_state(flat_state) for flat_state in flat_states + ) + return states + + +def _merge_to_flat_state(states: tp.Iterable[tp.Any]): + flat_state: list[tuple[PathParts, tp.Any]] = [] + + for state in states: + if isinstance(state, dict | State): + flat_state.extend(traversals.flatten_to_sequence(state)) + else: + flat_state.append(((), state)) + + flat_state.sort() + return [value for _, value in flat_state] + + def merge( graphdef: GraphDef[A], - state: tp.Mapping[Key, tp.Any], + state: tp.Any, /, - *states: tp.Mapping[Key, tp.Any], + *states: tp.Any, ) -> A: """The inverse of :func:`flax.nnx.split`. @@ -2094,14 +2150,12 @@ def merge( Returns: The merged :class:`flax.nnx.Module`. """ - _state = statelib.merge_state(state, *states) + _state = _merge_to_flat_state((state, *states)) node = unflatten(graphdef, _state) return node -def update( - node, state: tp.Mapping[KeyT, tp.Any], /, *states: tp.Mapping[KeyT, tp.Any] -) -> None: +def update(node, state: tp.Any, /, *states: tp.Any) -> None: """Update the given graph node with a new state(s) in-place. Example usage:: @@ -2128,7 +2182,20 @@ def update( *states: Additional :class:`State` objects. 
""" if states: - state = statelib.merge_state(state, *states) + if isinstance(node, Variable): + non_empty_states = [ + _state + for _state in (state, *states) + if not isinstance(_state, tp.Mapping) or _state + ] + if len(non_empty_states) != 1: + all_states = (state, *states) + raise ValueError( + f'Expected exactly one non-empty state, got: {all_states!r}' + ) + state = non_empty_states[0] + else: + state = statelib.merge_state(state, *states) _graph_update_dynamic(node, state) diff --git a/flax/nnx/module.py b/flax/nnx/module.py index afa6a1512..b7dc87382 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -497,7 +497,9 @@ def __init_subclass__(cls, experimental_pytree: bool = False) -> None: # ------------------------- def _module_flatten(module: Module, *, with_keys: bool): graphdef, state = graph.split(module) - key_values = sorted(state.raw_mapping.items()) + key_values = sorted( + state.raw_mapping.items() # type: ignore + ) keys = tuple(key for key, _ in key_values) children: tuple[tp.Any, ...] diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index b903de001..0df5ef81e 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -445,6 +445,7 @@ def __init_subclass__(cls) -> None: ) + def _state_flatten_with_keys(x: State): items = sorted(x._mapping.items()) children = tuple((jtu.DictKey(key), value) for key, value in items) @@ -465,6 +466,9 @@ def _state_unflatten( partial(_state_unflatten, State), # type: ignore[arg-type] ) +class EmptyState(State): + def __init__(self): + super().__init__({}) def map_state(f: tp.Callable[[tuple, tp.Any], tp.Any], state: State) -> State: flat_state = to_flat_state(state) diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 164c6d237..2095f6bbe 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -25,7 +25,7 @@ graph, variablelib, ) -from flax.nnx.statelib import State +from flax.nnx.statelib import EmptyState, State import jax import jax.core import jax.stages @@ -64,7 +64,7 @@ class DiffState: class GradFn: f: tp.Callable[..., tp.Any] has_aux: bool - nondiff_states: deque[State | None] + nondiff_states: deque[State | variablelib.VariableState | None] def __post_init__(self): functools.update_wrapper(self, self.f) @@ -135,7 +135,7 @@ def _grad_general( def grad_wrapper(*args, **kwargs): args = resolve_kwargs(f, args, kwargs) del kwargs - nondiff_states: deque[State | None] = deque() + nondiff_states: deque[State | variablelib.VariableState | None] = deque() def _grad_split_fn( ctx: graph.SplitContext, path, prefix: DiffState | None, value @@ -412,7 +412,7 @@ def _custom_vjp_split_fn( # but we return a TreeNode.from_states which doesn't have a graphdef # in order to keep the gradients clean from any metadata graphdef, passed = ctx.split(value) - broadcast = State({}) + broadcast = EmptyState() nondiff_states.append(extract.GraphDefState(graphdef, broadcast)) return extract.NodeStates.from_states(passed) else: @@ -554,8 +554,8 @@ def state_to_node_states(is_differentiable: bool, x): if is_differentiable: if isinstance(x, jax.Array): return x - elif not isinstance(x, State): - raise ValueError(f'Expected State, got {type(x)}') + elif not isinstance(x, State | variablelib.VariableState): + raise ValueError(f'Expected State or VariableState, got {type(x)}') return extract.NodeStates.from_states(x) return x @@ -563,7 +563,7 @@ def state_to_node_states(is_differentiable: bool, x): state_to_node_states, self.tree_node_args, tangent, - is_leaf=lambda x: isinstance(x, 
State), + is_leaf=lambda x: isinstance(x, State | variablelib.VariableState), ) return pure_tangent diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 62c41a3f9..12d63782d 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -1328,7 +1328,7 @@ def __call__(self, pure_val): def _add_fake_index_mapping(tree: tp.Any): def per_node_state(node_state: extract.NodeStates | tp.Any): if not isinstance(node_state, extract.NodeStates) or not isinstance( - node_state._graphdef, graph.NodeDef + node_state._graphdef, graph.NodeDef | graph.VariableDef ): return node_state @@ -1345,10 +1345,10 @@ def _remove_index_mapping(tree: tp.Any): def per_node_state(node_state: extract.NodeStates | tp.Any): if not isinstance(node_state, extract.NodeStates) or not isinstance( - node_state._graphdef, graph.NodeDef + node_state._graphdef, graph.NodeDef | graph.VariableDef ): return node_state - assert isinstance(node_state._graphdef, graph.NodeDef) + assert isinstance(node_state._graphdef, graph.NodeDef | graph.VariableDef) node_state = dataclasses.replace( node_state, _graphdef=node_state._graphdef.with_no_outer_index() ) diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index af50ad61e..af4c5f53d 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -19,6 +19,7 @@ from typing import Any from absl.testing import absltest, parameterized +import numpy as np from flax import linen, nnx, struct import jax import jax.numpy as jnp @@ -886,6 +887,130 @@ def test_fingerprint_module_id_insensitive(self): self.assertFalse(nnx.graph.check_fingerprint(m1, fp2)) self.assertFalse(nnx.graph.check_fingerprint(m2, fp1)) + def test_split_variable(self): + v = nnx.Param(1) + graphdef, state = nnx.split(v) + + self.assertIsInstance(graphdef, nnx.graph.VariableDef) + self.assertIsInstance(state, nnx.VariableState) + + v2 = nnx.merge(graphdef, state) + self.assertIsInstance(v2, nnx.Param) + + def test_split_filter_variable(self): + v = nnx.Param(1) + graphdef, batch_stats, params, rest = nnx.split( + v, nnx.BatchStat, nnx.Param, ... + ) + + self.assertIsInstance(graphdef, nnx.graph.VariableDef) + self.assertIsInstance(params, nnx.VariableState) + self.assertIsInstance(batch_stats, nnx.State) + self.assertEmpty(batch_stats) + self.assertIsInstance(rest, nnx.State) + self.assertEmpty(rest) + + v2 = nnx.merge(graphdef, batch_stats, params, rest) + self.assertIsInstance(v2, nnx.Param) + + def test_split_update_variable(self): + v = nnx.Param(1) + graphdef, state = nnx.split(v) + + self.assertIsInstance(graphdef, nnx.graph.VariableDef) + self.assertIsInstance(state, nnx.VariableState) + + state.value = 2 + nnx.update(v, state) + + self.assertEqual(v.value, 2) + + def test_split_update_filter_variable(self): + v = nnx.Param(1) + graphdef, batch_stats, params, rest = nnx.split( + v, nnx.BatchStat, nnx.Param, ... 
+    )
+
+    self.assertIsInstance(graphdef, nnx.graph.VariableDef)
+    self.assertIsInstance(params, nnx.VariableState)
+    self.assertIsInstance(batch_stats, nnx.State)
+    self.assertEmpty(batch_stats)
+    self.assertIsInstance(rest, nnx.State)
+    self.assertEmpty(rest)
+
+    params.value = 2
+    nnx.update(v, batch_stats, params, rest)
+
+    self.assertEqual(v.value, 2)
+
+  def test_jit_variable(self):
+    v = nnx.Param(1)
+
+    @nnx.jit
+    def f(v):
+      v += 1
+
+    f(v)
+
+    np.testing.assert_allclose(v.value, 2)
+
+  def test_jit_pytree_of_variables(self):
+    v1 = nnx.Param(1)
+    v2 = nnx.Param(2)
+    vs = [v1, v1, v2]
+
+    @nnx.jit
+    def f(vs):
+      self.assertIs(vs[0], vs[1])
+      self.assertIsNot(vs[0], vs[2])
+      vs[0] += 10
+
+    f(vs)
+
+    self.assertIs(vs[0], vs[1])
+    self.assertIsNot(vs[0], vs[2])
+    np.testing.assert_allclose(vs[0].value, 11)
+    np.testing.assert_allclose(vs[2].value, 2)
+
+  def test_variable_reference_in_module(self):
+    class Foo(nnx.Module):
+      def __init__(self, var):
+        self.var = var
+
+    var = nnx.Param(1)
+    foo = Foo(var)
+
+    @nnx.jit
+    def increment_var(var, foo):
+      self.assertIs(var, foo.var)
+      var += 1
+
+    increment_var(var, foo)
+    self.assertEqual(foo.var.value, 2)
+
+  def test_variables_example(self):
+    def stateful_linear_init(din: int, dout: int, rngs: nnx.Rngs):
+      w = nnx.Param(jax.random.normal(rngs(), (din, dout)))
+      b = nnx.Param(jnp.zeros((dout,)))
+      count = nnx.Variable(jnp.array(0))
+      return w, b, count
+
+    rngs = nnx.Rngs(0)
+    w, b, count = stateful_linear_init(2, 3, rngs=rngs)
+
+    @nnx.jit
+    def stateful_linear(w, b, count, x):
+      count += 1
+      return x @ w + b[None]
+
+    x = jax.random.normal(rngs(), (1, 2))
+    y = stateful_linear(w, b, count, x)
+    self.assertEqual(count.value, 1)
+
+    y = stateful_linear(w, b, count, x)
+    self.assertEqual(count.value, 2)
+    self.assertEqual(y.shape, (1, 3))
+

 class SimpleModule(nnx.Module):
   pass
diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py
index 9c812227f..65256a088 100644
--- a/tests/nnx/transforms_test.py
+++ b/tests/nnx/transforms_test.py
@@ -762,6 +762,35 @@ def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]):
     self.assertNotIn('kernel', grads_m2[0])
     self.assertIn('bias', grads_m2[0])

+  def test_variables_in_grad(self):
+    p1 = nnx.Param(10.0)
+    p2 = nnx.Param(20.0)
+
+    m = dict(a=[p1, p2], b=p1, c=7, d=5.0)
+
+    @nnx.grad
+    def f(m: dict):
+      # sum all params
+      return m['a'][0].value + m['a'][1].value + m['b'].value
+
+    grads = f(m)
+
+    assert m['a'][0] is m['b']
+    assert isinstance(grads, dict)
+    assert issubclass(grads['a'][0].type, nnx.Variable)
+    assert grads['a'][1].value == 1.0
+    assert issubclass(grads['a'][1].type, nnx.Variable)
+    assert len(jax.tree.leaves(grads)) == 2
+
+    jax.tree.map(nnx.update, m, grads)
+
+    assert m['a'][0] is m['b']
+    assert m['a'][0].value == 2.0
+    assert m['a'][1].value == 1.0
+    assert m['b'].value == 2.0
+    assert m['c'] == 7
+    assert m['d'] == 5.0
+

 class TestCustomVJP(parameterized.TestCase):
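Two usage sketches follow; both extrapolate from `test_variables_in_grad` above and are assumptions rather than additional tested behavior. First, the empty-`State` special case handled by `_update_variable`: when one Param is reachable through two paths, the second occurrence flattens to a `NodeRef`, so its gradient slot comes back as an empty `State` that `nnx.update` skips instead of writing into the Variable:

```python
from flax import nnx
import jax

p = nnx.Param(10.0)
m = dict(a=p, b=p)  # two references to one Param

@nnx.grad
def f(m):
  return m['a'].value + m['b'].value

grads = f(m)
assert len(jax.tree.leaves(grads)) == 1  # one grad per unique Param
jax.tree.map(nnx.update, m, grads)       # the empty-State slot is skipped
assert m['a'].value == 2.0               # both uses accumulate into one grad
```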
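Second, the new `update()` branch for Variable targets selects the single non-empty state among several filtered states, mirroring `test_split_update_filter_variable`:

```python
from flax import nnx

v = nnx.Param(1)
graphdef, batch_stats, params, rest = nnx.split(
  v, nnx.BatchStat, nnx.Param, ...
)
params.value = 2
nnx.update(v, batch_stats, params, rest)  # empty States are ignored
assert v.value == 2
```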