Skip to content

Commit a3f07ad

Browse files
committed
[nnx] add support for standalone Variables
1 parent e3789de commit a3f07ad

File tree

6 files changed

+320
-119
lines changed

6 files changed

+320
-119
lines changed

flax/nnx/bridge/wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def init(self, *, rngs: tp.Optional[Rngs] = None) -> State:
4949
graphdef, state = nnx.split(module)
5050
assert type(graphdef) is graph.NodeDef
5151
self.graphdef = graphdef
52-
return state
52+
return state # type: ignore
5353

5454
def apply(self, *states: tp.Any):
5555
assert self.graphdef is not None

flax/nnx/extract.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ def map_prefix(
127127
) -> tp.Any: ...
128128

129129
def check_consistent_aliasing(
130-
node: tuple[tp.Any, ...],
131-
prefix: tuple[tp.Any, ...],
130+
node: tp.Any,
131+
prefix: tp.Any,
132132
/,
133133
*,
134134
node_prefixes: dict[tp.Any, list[tuple[PathParts, tp.Any]]] | None = None,
@@ -279,7 +279,9 @@ def to_tree(
279279
with graph.split_context(ctxtag) as split_ctx:
280280
return jax.tree.map(
281281
lambda x: split_fn(split_ctx, (), prefix, x)
282-
if map_non_graph_nodes or graph.is_graph_node(x)
282+
if map_non_graph_nodes
283+
or graph.is_graph_node(x)
284+
or isinstance(x, variablelib.Variable)
283285
else x,
284286
tree,
285287
)
@@ -296,7 +298,7 @@ def to_tree(
296298

297299
with graph.split_context(ctxtag) as split_ctx:
298300
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
299-
if graph.is_graph_node(leaf):
301+
if graph.is_graph_node(leaf) or isinstance(leaf, variablelib.Variable):
300302
if check_aliasing:
301303
check_consistent_aliasing(
302304
leaf, leaf_prefix, node_prefixes=node_prefixes
@@ -343,7 +345,9 @@ def from_tree(
343345
with graph.merge_context(is_inner, ctxtag) as merge_ctx:
344346
return jax.tree.map(
345347
lambda x: merge_fn(merge_ctx, (), prefix, x)
346-
if map_non_graph_nodes or is_node_leaf(x)
348+
if map_non_graph_nodes
349+
or is_node_leaf(x)
350+
or isinstance(x, variablelib.Variable)
347351
else x,
348352
tree,
349353
is_leaf=is_leaf,
@@ -362,12 +366,21 @@ def from_tree(
362366

363367
with graph.merge_context(is_inner, ctxtag) as merge_ctx:
364368
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
365-
if map_non_graph_nodes or is_node_leaf(leaf):
369+
if (
370+
map_non_graph_nodes
371+
or is_node_leaf(leaf)
372+
or isinstance(leaf, variablelib.Variable)
373+
):
366374
leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf)
367375
leaves_out.append(leaf)
368376

369377
pytree_out = jax.tree.unflatten(treedef, leaves_out)
370378
return pytree_out
371379

372380
def clear_non_graph_nodes(tree):
373-
return jax.tree.map(lambda x: x if graph.is_graph_node(x) else None, tree)
381+
return jax.tree.map(
382+
lambda x: x
383+
if graph.is_graph_node(x) or isinstance(x, variablelib.Variable)
384+
else None,
385+
tree,
386+
)

0 commit comments

Comments
 (0)