Skip to content

Commit

Permalink
[nnx] add support for standalone Variables
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Mar 7, 2025
1 parent e3789de commit f334500
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 84 deletions.
209 changes: 125 additions & 84 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import threading
import typing as tp

from flax.nnx import filterlib, reprlib, variablelib
from flax.nnx import filterlib, reprlib, traversals, variablelib
from flax.nnx import statelib
from flax.nnx.proxy_caller import (
ApplyCaller,
Expand All @@ -40,6 +40,7 @@
B = tp.TypeVar('B')
C = tp.TypeVar('C')
F = tp.TypeVar('F', bound=tp.Callable)
V = tp.TypeVar('V', bound=Variable)

HA = tp.TypeVar('HA', bound=tp.Hashable)
HB = tp.TypeVar('HB', bound=tp.Hashable)
Expand Down Expand Up @@ -266,8 +267,8 @@ def __treescope_repr__(self, path, subtree_renderer):


@dataclasses.dataclass(frozen=True, repr=False)
class VariableDef(reprlib.Representable):
type: type[Variable]
class VariableDef(reprlib.Representable, tp.Generic[Node]):
type: type[Node]
index: int
outer_index: int | None
metadata: HashableMapping[str, tp.Any]
Expand Down Expand Up @@ -319,7 +320,8 @@ class NodeDef(tp.Generic[Node], reprlib.Representable):
outer_index: int | None
attributes: tuple[
tuple[
Key, NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any] | Static[tp.Any]
Key,
NodeDef[tp.Any] | VariableDef[tp.Any] | NodeRef[tp.Any] | Static[tp.Any],
],
...,
]
Expand Down Expand Up @@ -406,7 +408,7 @@ def _apply(

jax.tree_util.register_static(NodeDef)

GraphDef = tp.Union[NodeDef[Node], NodeRef[Node]]
GraphDef = tp.Union[NodeDef[Node], NodeRef[Node], VariableDef[Node]]
PureState = tuple[GraphDef[Node], GraphState]


Expand Down Expand Up @@ -497,7 +499,7 @@ def flatten(
path: list[Key] | None = [] if with_paths else None
paths: list[PathParts] | None = [] if with_paths else None
node_impl = get_node_impl(node)
if node_impl is None:
if node_impl is None and not isinstance(node, Variable):
raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
graphdef = _graph_flatten(
node,
Expand All @@ -518,27 +520,51 @@ def flatten(

def _graph_flatten(
node: Node,
node_impl: NodeImpl[Node, Leaf, AuxData],
node_impl: NodeImpl[Node, Leaf, AuxData] | None,
path: list[Key] | None,
ref_index: RefMap,
ref_outer_index: RefMap | None,
leaves: list[StateLeaf | Variable[tp.Any]],
paths: list[PathParts] | None,
return_variables: bool,
) -> NodeDef[tp.Any] | NodeRef:
) -> NodeDef | NodeRef | VariableDef:
is_pytree_node_ = isinstance(node_impl, PytreeNodeImpl)
is_graph_node_ = isinstance(node_impl, GraphNodeImpl)
is_variable = isinstance(node, Variable)

if not is_pytree_node_ and node in ref_index:
return NodeRef(type(node), ref_index[node])

# only cache graph nodes
if is_graph_node_:
if is_graph_node_ or is_variable:
index = len(ref_index)
ref_index[node] = index
else:
index = -1

if is_variable:
assert isinstance(node, Variable)
if return_variables:
leaf = node
elif path is None:
leaf = node.raw_value
else:
leaf = node.to_state() # type: ignore[assignment]
leaves.append(leaf)
if path is not None:
assert paths is not None
paths.append(tuple(path))
variabledef = VariableDef(
type=type(node),
index=index,
outer_index=ref_outer_index.get(node, None) if ref_outer_index else None,
metadata=HashableMapping(node._var_metadata),
)
return variabledef

if node_impl is None:
raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')

attributes: list[
tuple[Key, Static[tp.Any] | NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any]]
] = []
Expand All @@ -548,7 +574,7 @@ def _graph_flatten(
value_node_impl = get_node_impl(value)
if path is not None:
path.append(key)
if value_node_impl is not None:
if value_node_impl is not None or isinstance(value, Variable):
nodedef = _graph_flatten(
value,
value_node_impl,
Expand All @@ -560,30 +586,6 @@ def _graph_flatten(
return_variables,
)
attributes.append((key, nodedef))
elif isinstance(value, Variable):
if value in ref_index:
attributes.append((key, NodeRef(type(value), ref_index[value])))
else:
if return_variables:
leaf = value
elif path is None:
leaf = value.raw_value
else:
leaf = value.to_state() # type: ignore[assignment]
leaves.append(leaf)
if path is not None:
assert paths is not None
paths.append(tuple(path))
variable_index = ref_index[value] = len(ref_index)
variabledef = VariableDef(
type=type(value),
index=variable_index,
outer_index=ref_outer_index.get(value, None)
if ref_outer_index
else None,
metadata=HashableMapping(value._var_metadata),
)
attributes.append((key, variabledef))
else:
if isinstance(value, (jax.Array, np.ndarray)):
if path is not None:
Expand Down Expand Up @@ -867,10 +869,7 @@ def unflatten(
if isinstance(graphdef, NodeRef):
node = index_ref[graphdef.index]
else:
assert isinstance(graphdef, NodeDef)
node_impl = get_node_impl_for_type(graphdef.type)
if node_impl is None:
raise RuntimeError(f'Unsupported type: {graphdef.type}, this is a bug.')
node = _graph_unflatten(
graphdef, node_impl, leaves, index_ref, outer_index_outer_ref
)
Expand All @@ -883,8 +882,8 @@ def unflatten(


def _graph_unflatten(
nodedef: NodeDef[Node] | NodeRef[Node],
node_impl: NodeImpl[Node, Leaf, AuxData],
nodedef: GraphDef[Node],
node_impl: NodeImpl[Node, Leaf, AuxData] | None,
leaves: deque[tp.Any],
index_ref: dict[Index, tp.Any],
outer_index_outer_ref: dict[Index, tp.Any] | None,
Expand All @@ -904,7 +903,55 @@ def _graph_unflatten(
"""
if type(nodedef) is NodeRef:
return index_ref[nodedef.index]

def make_variable(key, variabledef: VariableDef[Variable]) -> tp.Any:
if not leaves:
raise ValueError('Not enough leaves to unflatten the graph')
# its a unseen variable, create a new one
value = leaves.popleft()
# when idxmap is present, check if the Varable exists there
# and update existing variables if it does
if (
outer_index_outer_ref is not None
and variabledef.outer_index in outer_index_outer_ref
):
# if variable exists, update it
variable = outer_index_outer_ref[variabledef.outer_index]
if not isinstance(variable, Variable):
raise ValueError(
f'Expected a Variable type for {key!r}, but got {type(variable)}.'
)
elif isinstance(value, Variable):
raise ValueError(
f'Cannot unflatten flat_state containing Variables when using `outer_index_outer_ref`. '
f'Got {value!r} for {key!r}.'
)
elif isinstance(value, VariableState):
variable.update_from_state(value)
else:
variable.raw_value = value
else: # variabledef.index not in index_ref_cache
# variable reference does not exist outside, create a new one
if isinstance(value, Variable):
variable = value
elif isinstance(value, VariableState):
variable = value.to_variable()
else:
variable = variabledef.type.from_metadata(
value, dict(variabledef.metadata)
)
index_ref[variabledef.index] = variable
return variable

if type(nodedef) is VariableDef:
return make_variable(
None,
nodedef, # type: ignore
)

assert type(nodedef) is NodeDef
if node_impl is None:
raise RuntimeError(f'Unsupported type: {nodedef.type}, this is a bug.')
if nodedef.index in index_ref:
raise RuntimeError(f'GraphDef index {nodedef.index} already used.')

Expand All @@ -927,44 +974,8 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
)
children.append((key, subnode))
elif type(value) is VariableDef:
variabledef = value
if not leaves:
raise ValueError('Not enough leaves to unflatten the graph')
# its a unseen variable, create a new one
value = leaves.popleft()
# when idxmap is present, check if the Varable exists there
# and update existing variables if it does
if (
outer_index_outer_ref is not None
and variabledef.outer_index in outer_index_outer_ref
):
# if variable exists, update it
variable = outer_index_outer_ref[variabledef.outer_index]
if not isinstance(variable, Variable):
raise ValueError(
f'Expected a Variable type for {key!r}, but got {type(variable)}.'
)
elif isinstance(value, Variable):
raise ValueError(
f'Cannot unflatten flat_state containing Variables when using `outer_index_outer_ref`. '
f'Got {value!r} for {key!r}.'
)
elif isinstance(value, VariableState):
variable.update_from_state(value)
else:
variable.raw_value = value
else: # variabledef.index not in index_ref_cache
# variable reference does not exist outside, create a new one
if isinstance(value, Variable):
variable = value
elif isinstance(value, VariableState):
variable = value.to_variable()
else:
variable = variabledef.type.from_metadata(
value, dict(variabledef.metadata)
)
variable = make_variable(key, value)
children.append((key, variable))
index_ref[variabledef.index] = variable
else:
raise RuntimeError(f'Unknown static field: {key!r}')

Expand Down Expand Up @@ -1955,22 +1966,32 @@ def _split_state(


@tp.overload
def split(graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ...
def split(
graph_node: A, /
) -> tuple[GraphDef[A], GraphState | VariableState]: ...
@tp.overload
def split(
graph_node: A, first: filterlib.Filter, /
) -> tuple[GraphDef[A], GraphState]: ...
) -> tuple[GraphDef[A], GraphState | VariableState]: ...
@tp.overload
def split(
graph_node: A,
first: filterlib.Filter,
second: filterlib.Filter,
/,
*filters: filterlib.Filter,
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ...
) -> tuple[
GraphDef[A],
GraphState | VariableState,
tpe.Unpack[tuple[GraphState | VariableState, ...]],
]: ...
def split(
node: A, *filters: filterlib.Filter
) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]:
) -> tuple[
GraphDef[A],
GraphState | VariableState,
tpe.Unpack[tuple[GraphState | VariableState, ...]],
]:
"""Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is
a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef
contains all the static information needed to reconstruct a ``Module`` graph, it is analogous
Expand Down Expand Up @@ -2041,15 +2062,35 @@ def split(
"""
graphdef, flat_state = flatten(node)
flat_states = _split_state(flat_state, filters)
states = tuple(statelib.from_flat_state(flat_state) for flat_state in flat_states)
if type(graphdef) is VariableDef:
states = tuple(
flat_state[0][1] if flat_state else State({})
for flat_state in flat_states
)
else:
states = tuple(
statelib.from_flat_state(flat_state) for flat_state in flat_states
)
return graphdef, *states # type: ignore[return-value]

def _merge_to_flat_state(states: tp.Iterable[tp.Any]):
  """Combine several states into one sorted flat list of ``(path, value)`` pairs.

  Mapping-like entries (``dict`` or ``State``) are flattened into their leaf
  paths; any other value (presumably a standalone Variable/VariableState —
  see the standalone-Variable support in this module) is recorded under the
  empty path ``()``.
  """
  entries: list[tuple[PathParts, tp.Any]] = []
  for item in states:
    if isinstance(item, dict | State):
      # Nested mappings contribute one entry per leaf.
      entries.extend(traversals.flatten_to_sequence(item))
    else:
      # A bare leaf lives at the root path.
      entries.append(((), item))
  # Deterministic path ordering for the consumer (unflatten).
  return sorted(entries)


def merge(
graphdef: GraphDef[A],
state: tp.Mapping[Key, tp.Any],
state: tp.Any,
/,
*states: tp.Mapping[Key, tp.Any],
*states: tp.Any,
) -> A:
"""The inverse of :func:`flax.nnx.split`.
Expand Down Expand Up @@ -2094,7 +2135,7 @@ def merge(
Returns:
The merged :class:`flax.nnx.Module`.
"""
_state = statelib.merge_state(state, *states)
_state = _merge_to_flat_state((state, *states))
node = unflatten(graphdef, _state)
return node

Expand Down
26 changes: 26 additions & 0 deletions tests/nnx/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,32 @@ def test_fingerprint_module_id_insensitive(self):
self.assertFalse(nnx.graph.check_fingerprint(m1, fp2))
self.assertFalse(nnx.graph.check_fingerprint(m2, fp1))

def test_split_variable(self):
v = nnx.Param(1)
graphdef, state = nnx.split(v)

self.assertIsInstance(graphdef, nnx.graph.VariableDef)
self.assertIsInstance(state, nnx.VariableState)

v2 = nnx.merge(graphdef, state)
self.assertIsInstance(v2, nnx.Param)

def test_split_filter_variable(self):
v = nnx.Param(1)
graphdef, batch_stats, params, rest = nnx.split(
v, nnx.BatchStat, nnx.Param, ...
)

self.assertIsInstance(graphdef, nnx.graph.VariableDef)
self.assertIsInstance(params, nnx.VariableState)
self.assertIsInstance(batch_stats, nnx.State)
self.assertEmpty(batch_stats)
self.assertIsInstance(rest, nnx.State)
self.assertEmpty(rest)

v2 = nnx.merge(graphdef, batch_stats, params, rest)
self.assertIsInstance(v2, nnx.Param)


class SimpleModule(nnx.Module):
  """Minimal Module with no attributes; a trivial graph-node test fixture."""
  pass
Expand Down

0 comments on commit f334500

Please sign in to comment.