From f334500b8ee776b826374ab60886cc6d35af04bb Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 7 Mar 2025 11:22:30 -0800 Subject: [PATCH] [nnx] add support for standalone Variables --- flax/nnx/graph.py | 209 ++++++++++++++++++++-------------- tests/nnx/graph_utils_test.py | 26 +++++ 2 files changed, 151 insertions(+), 84 deletions(-) diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 6fb73ed4..23e554a3 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -21,7 +21,7 @@ import threading import typing as tp -from flax.nnx import filterlib, reprlib, variablelib +from flax.nnx import filterlib, reprlib, traversals, variablelib from flax.nnx import statelib from flax.nnx.proxy_caller import ( ApplyCaller, @@ -40,6 +40,7 @@ B = tp.TypeVar('B') C = tp.TypeVar('C') F = tp.TypeVar('F', bound=tp.Callable) +V = tp.TypeVar('V', bound=Variable) HA = tp.TypeVar('HA', bound=tp.Hashable) HB = tp.TypeVar('HB', bound=tp.Hashable) @@ -266,8 +267,8 @@ def __treescope_repr__(self, path, subtree_renderer): @dataclasses.dataclass(frozen=True, repr=False) -class VariableDef(reprlib.Representable): - type: type[Variable] +class VariableDef(reprlib.Representable, tp.Generic[Node]): + type: type[Node] index: int outer_index: int | None metadata: HashableMapping[str, tp.Any] @@ -319,7 +320,8 @@ class NodeDef(tp.Generic[Node], reprlib.Representable): outer_index: int | None attributes: tuple[ tuple[ - Key, NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any] | Static[tp.Any] + Key, + NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any] | Static[tp.Any], ], ..., ] @@ -406,7 +408,7 @@ def _apply( jax.tree_util.register_static(NodeDef) -GraphDef = tp.Union[NodeDef[Node], NodeRef[Node]] +GraphDef = tp.Union[NodeDef[Node], NodeRef[Node], VariableDef[Node]] PureState = tuple[GraphDef[Node], GraphState] @@ -497,7 +499,7 @@ def flatten( path: list[Key] | None = [] if with_paths else None paths: list[PathParts] | None = [] if with_paths else None node_impl = get_node_impl(node) - if 
node_impl is None: + if node_impl is None and not isinstance(node, Variable): raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') graphdef = _graph_flatten( node, @@ -518,27 +520,51 @@ def flatten( def _graph_flatten( node: Node, - node_impl: NodeImpl[Node, Leaf, AuxData], + node_impl: NodeImpl[Node, Leaf, AuxData] | None, path: list[Key] | None, ref_index: RefMap, ref_outer_index: RefMap | None, leaves: list[StateLeaf | Variable[tp.Any]], paths: list[PathParts] | None, return_variables: bool, -) -> NodeDef[tp.Any] | NodeRef: +) -> NodeDef | NodeRef | VariableDef: is_pytree_node_ = isinstance(node_impl, PytreeNodeImpl) is_graph_node_ = isinstance(node_impl, GraphNodeImpl) + is_variable = isinstance(node, Variable) if not is_pytree_node_ and node in ref_index: return NodeRef(type(node), ref_index[node]) # only cache graph nodes - if is_graph_node_: + if is_graph_node_ or is_variable: index = len(ref_index) ref_index[node] = index else: index = -1 + if is_variable: + assert isinstance(node, Variable) + if return_variables: + leaf = node + elif path is None: + leaf = node.raw_value + else: + leaf = node.to_state() # type: ignore[assignment] + leaves.append(leaf) + if path is not None: + assert paths is not None + paths.append(tuple(path)) + variabledef = VariableDef( + type=type(node), + index=index, + outer_index=ref_outer_index.get(node, None) if ref_outer_index else None, + metadata=HashableMapping(node._var_metadata), + ) + return variabledef + + if node_impl is None: + raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') + attributes: list[ tuple[Key, Static[tp.Any] | NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any]] ] = [] @@ -548,7 +574,7 @@ def _graph_flatten( value_node_impl = get_node_impl(value) if path is not None: path.append(key) - if value_node_impl is not None: + if value_node_impl is not None or isinstance(value, Variable): nodedef = _graph_flatten( value, value_node_impl, @@ -560,30 +586,6 @@ def _graph_flatten( 
return_variables, ) attributes.append((key, nodedef)) - elif isinstance(value, Variable): - if value in ref_index: - attributes.append((key, NodeRef(type(value), ref_index[value]))) - else: - if return_variables: - leaf = value - elif path is None: - leaf = value.raw_value - else: - leaf = value.to_state() # type: ignore[assignment] - leaves.append(leaf) - if path is not None: - assert paths is not None - paths.append(tuple(path)) - variable_index = ref_index[value] = len(ref_index) - variabledef = VariableDef( - type=type(value), - index=variable_index, - outer_index=ref_outer_index.get(value, None) - if ref_outer_index - else None, - metadata=HashableMapping(value._var_metadata), - ) - attributes.append((key, variabledef)) else: if isinstance(value, (jax.Array, np.ndarray)): if path is not None: @@ -867,10 +869,7 @@ def unflatten( if isinstance(graphdef, NodeRef): node = index_ref[graphdef.index] else: - assert isinstance(graphdef, NodeDef) node_impl = get_node_impl_for_type(graphdef.type) - if node_impl is None: - raise RuntimeError(f'Unsupported type: {graphdef.type}, this is a bug.') node = _graph_unflatten( graphdef, node_impl, leaves, index_ref, outer_index_outer_ref ) @@ -883,8 +882,8 @@ def unflatten( def _graph_unflatten( - nodedef: NodeDef[Node] | NodeRef[Node], - node_impl: NodeImpl[Node, Leaf, AuxData], + nodedef: GraphDef[Node], + node_impl: NodeImpl[Node, Leaf, AuxData] | None, leaves: deque[tp.Any], index_ref: dict[Index, tp.Any], outer_index_outer_ref: dict[Index, tp.Any] | None, @@ -904,7 +903,55 @@ def _graph_unflatten( """ if type(nodedef) is NodeRef: return index_ref[nodedef.index] + + def make_variable(key, variabledef: VariableDef[Variable]) -> tp.Any: + if not leaves: + raise ValueError('Not enough leaves to unflatten the graph') + # it's an unseen variable, create a new one + value = leaves.popleft() + # when idxmap is present, check if the Variable exists there + # and update existing variables if it does + if ( + outer_index_outer_ref is not
None + and variabledef.outer_index in outer_index_outer_ref + ): + # if variable exists, update it + variable = outer_index_outer_ref[variabledef.outer_index] + if not isinstance(variable, Variable): + raise ValueError( + f'Expected a Variable type for {key!r}, but got {type(variable)}.' + ) + elif isinstance(value, Variable): + raise ValueError( + f'Cannot unflatten flat_state containing Variables when using `outer_index_outer_ref`. ' + f'Got {value!r} for {key!r}.' + ) + elif isinstance(value, VariableState): + variable.update_from_state(value) + else: + variable.raw_value = value + else: # variabledef.index not in index_ref_cache + # variable reference does not exist outside, create a new one + if isinstance(value, Variable): + variable = value + elif isinstance(value, VariableState): + variable = value.to_variable() + else: + variable = variabledef.type.from_metadata( + value, dict(variabledef.metadata) + ) + index_ref[variabledef.index] = variable + return variable + + if type(nodedef) is VariableDef: + return make_variable( + None, + nodedef, # type: ignore + ) + assert type(nodedef) is NodeDef + if node_impl is None: + raise RuntimeError(f'Unsupported type: {nodedef.type}, this is a bug.') if nodedef.index in index_ref: raise RuntimeError(f'GraphDef index {nodedef.index} already used.') @@ -927,44 +974,8 @@ def _get_children() -> list[tuple[Key, tp.Any]]: ) children.append((key, subnode)) elif type(value) is VariableDef: - variabledef = value - if not leaves: - raise ValueError('Not enough leaves to unflatten the graph') - # its a unseen variable, create a new one - value = leaves.popleft() - # when idxmap is present, check if the Varable exists there - # and update existing variables if it does - if ( - outer_index_outer_ref is not None - and variabledef.outer_index in outer_index_outer_ref - ): - # if variable exists, update it - variable = outer_index_outer_ref[variabledef.outer_index] - if not isinstance(variable, Variable): - raise ValueError( - 
f'Expected a Variable type for {key!r}, but got {type(variable)}.' - ) - elif isinstance(value, Variable): - raise ValueError( - f'Cannot unflatten flat_state containing Variables when using `outer_index_outer_ref`. ' - f'Got {value!r} for {key!r}.' - ) - elif isinstance(value, VariableState): - variable.update_from_state(value) - else: - variable.raw_value = value - else: # variabledef.index not in index_ref_cache - # variable reference does not exist outside, create a new one - if isinstance(value, Variable): - variable = value - elif isinstance(value, VariableState): - variable = value.to_variable() - else: - variable = variabledef.type.from_metadata( - value, dict(variabledef.metadata) - ) + variable = make_variable(key, value) children.append((key, variable)) - index_ref[variabledef.index] = variable else: raise RuntimeError(f'Unknown static field: {key!r}') @@ -1955,11 +1966,13 @@ def _split_state( @tp.overload -def split(graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... +def split( + graph_node: A, / +) -> tuple[GraphDef[A], GraphState | VariableState]: ... @tp.overload def split( graph_node: A, first: filterlib.Filter, / -) -> tuple[GraphDef[A], GraphState]: ... +) -> tuple[GraphDef[A], GraphState | VariableState]: ... @tp.overload def split( graph_node: A, @@ -1967,10 +1980,18 @@ def split( second: filterlib.Filter, /, *filters: filterlib.Filter, -) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ... +) -> tuple[ + GraphDef[A], + GraphState | VariableState, + tpe.Unpack[tuple[GraphState | VariableState, ...]], +]: ... def split( node: A, *filters: filterlib.Filter -) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: +) -> tuple[ + GraphDef[A], + GraphState | VariableState, + tpe.Unpack[tuple[GraphState | VariableState, ...]], +]: """Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. 
GraphDef contains all the static information needed to reconstruct a ``Module`` graph, it is analogous @@ -2041,15 +2062,35 @@ def split( """ graphdef, flat_state = flatten(node) flat_states = _split_state(flat_state, filters) - states = tuple(statelib.from_flat_state(flat_state) for flat_state in flat_states) + if type(graphdef) is VariableDef: + states = tuple( + flat_state[0][1] if flat_state else State({}) + for flat_state in flat_states + ) + else: + states = tuple( + statelib.from_flat_state(flat_state) for flat_state in flat_states + ) return graphdef, *states # type: ignore[return-value] +def _merge_to_flat_state(states: tp.Iterable[tp.Any]): + flat_state: list[tuple[PathParts, tp.Any]] = [] + + for state in states: + if isinstance(state, dict | State): + flat_state.extend(traversals.flatten_to_sequence(state)) + else: + flat_state.append(((), state)) + + flat_state.sort() + return flat_state + def merge( graphdef: GraphDef[A], - state: tp.Mapping[Key, tp.Any], + state: tp.Any, /, - *states: tp.Mapping[Key, tp.Any], + *states: tp.Any, ) -> A: """The inverse of :func:`flax.nnx.split`. @@ -2094,7 +2135,7 @@ def merge( Returns: The merged :class:`flax.nnx.Module`. 
""" - _state = statelib.merge_state(state, *states) + _state = _merge_to_flat_state((state, *states)) node = unflatten(graphdef, _state) return node diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index af50ad61..84bf2193 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -886,6 +886,32 @@ def test_fingerprint_module_id_insensitive(self): self.assertFalse(nnx.graph.check_fingerprint(m1, fp2)) self.assertFalse(nnx.graph.check_fingerprint(m2, fp1)) + def test_split_variable(self): + v = nnx.Param(1) + graphdef, state = nnx.split(v) + + self.assertIsInstance(graphdef, nnx.graph.VariableDef) + self.assertIsInstance(state, nnx.VariableState) + + v2 = nnx.merge(graphdef, state) + self.assertIsInstance(v2, nnx.Param) + + def test_split_filter_variable(self): + v = nnx.Param(1) + graphdef, batch_stats, params, rest = nnx.split( + v, nnx.BatchStat, nnx.Param, ... + ) + + self.assertIsInstance(graphdef, nnx.graph.VariableDef) + self.assertIsInstance(params, nnx.VariableState) + self.assertIsInstance(batch_stats, nnx.State) + self.assertEmpty(batch_stats) + self.assertIsInstance(rest, nnx.State) + self.assertEmpty(rest) + + v2 = nnx.merge(graphdef, batch_stats, params, rest) + self.assertIsInstance(v2, nnx.Param) + class SimpleModule(nnx.Module): pass