Skip to content

Commit

Permalink
[nnx] add support for standalone Variables
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Mar 8, 2025
1 parent e3789de commit a3f07ad
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 119 deletions.
2 changes: 1 addition & 1 deletion flax/nnx/bridge/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def init(self, *, rngs: tp.Optional[Rngs] = None) -> State:
graphdef, state = nnx.split(module)
assert type(graphdef) is graph.NodeDef
self.graphdef = graphdef
return state
return state # type: ignore

def apply(self, *states: tp.Any):
assert self.graphdef is not None
Expand Down
27 changes: 20 additions & 7 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ def map_prefix(
) -> tp.Any: ...

def check_consistent_aliasing(
node: tuple[tp.Any, ...],
prefix: tuple[tp.Any, ...],
node: tp.Any,
prefix: tp.Any,
/,
*,
node_prefixes: dict[tp.Any, list[tuple[PathParts, tp.Any]]] | None = None,
Expand Down Expand Up @@ -279,7 +279,9 @@ def to_tree(
with graph.split_context(ctxtag) as split_ctx:
return jax.tree.map(
lambda x: split_fn(split_ctx, (), prefix, x)
if map_non_graph_nodes or graph.is_graph_node(x)
if map_non_graph_nodes
or graph.is_graph_node(x)
or isinstance(x, variablelib.Variable)
else x,
tree,
)
Expand All @@ -296,7 +298,7 @@ def to_tree(

with graph.split_context(ctxtag) as split_ctx:
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
if graph.is_graph_node(leaf):
if graph.is_graph_node(leaf) or isinstance(leaf, variablelib.Variable):
if check_aliasing:
check_consistent_aliasing(
leaf, leaf_prefix, node_prefixes=node_prefixes
Expand Down Expand Up @@ -343,7 +345,9 @@ def from_tree(
with graph.merge_context(is_inner, ctxtag) as merge_ctx:
return jax.tree.map(
lambda x: merge_fn(merge_ctx, (), prefix, x)
if map_non_graph_nodes or is_node_leaf(x)
if map_non_graph_nodes
or is_node_leaf(x)
or isinstance(x, variablelib.Variable)
else x,
tree,
is_leaf=is_leaf,
Expand All @@ -362,12 +366,21 @@ def from_tree(

with graph.merge_context(is_inner, ctxtag) as merge_ctx:
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
if map_non_graph_nodes or is_node_leaf(leaf):
if (
map_non_graph_nodes
or is_node_leaf(leaf)
or isinstance(leaf, variablelib.Variable)
):
leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf)
leaves_out.append(leaf)

pytree_out = jax.tree.unflatten(treedef, leaves_out)
return pytree_out

def clear_non_graph_nodes(tree):
return jax.tree.map(lambda x: x if graph.is_graph_node(x) else None, tree)
return jax.tree.map(
lambda x: x
if graph.is_graph_node(x) or isinstance(x, variablelib.Variable)
else None,
tree,
)
Loading

0 comments on commit a3f07ad

Please sign in to comment.