From e3789de21db2a83d77df90d17595554714965189 Mon Sep 17 00:00:00 2001 From: Flax Team Date: Thu, 6 Mar 2025 13:22:21 -0800 Subject: [PATCH] Copybara import of the project: -- 4d8c9bbeff254016a81b5a286eb97dd02e9b7a0e by Cristian Garcia : [nnx] pytrees are graph nodes PiperOrigin-RevId: 734264536 --- flax/nnx/graph.py | 67 ++++++++++++++++++++++------------- tests/nnx/graph_utils_test.py | 6 ++-- 2 files changed, 45 insertions(+), 28 deletions(-) diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index e8930a78..6fb73ed4 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -526,11 +526,18 @@ def _graph_flatten( paths: list[PathParts] | None, return_variables: bool, ) -> NodeDef[tp.Any] | NodeRef: - if node in ref_index: + is_pytree_node_ = isinstance(node_impl, PytreeNodeImpl) + is_graph_node_ = isinstance(node_impl, GraphNodeImpl) + + if not is_pytree_node_ and node in ref_index: return NodeRef(type(node), ref_index[node]) - # assign index - ref_index[node] = index = len(ref_index) + # only cache graph nodes + if is_graph_node_: + index = len(ref_index) + ref_index[node] = index + else: + index = -1 attributes: list[ tuple[Key, Static[tp.Any] | NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any]] @@ -596,7 +603,7 @@ def _graph_flatten( type=node_impl.type, # type: ignore[arg-type] index=index, outer_index=ref_outer_index[node] - if ref_outer_index and node in ref_outer_index + if is_graph_node_ and ref_outer_index and node in ref_outer_index else None, attributes=tuple(attributes), metadata=metadata, @@ -639,18 +646,23 @@ def _graph_fingerprint( ref_index: RefMap, new_ref_index: RefMap, ): + is_pytree_node_ = type(node_impl) is PytreeNodeImpl + is_graph_node_ = type(node_impl) is GraphNodeImpl append_fn(type(node)) - append_fn(id(node)) - if node in ref_index: - append_fn(ref_index[node]) - return - elif node in new_ref_index: - append_fn(new_ref_index[node]) - return - index = new_ref_index[node] = ctx.next_index - ctx.next_index += 1 + if is_graph_node_: + append_fn(id(node)) + if node in ref_index: + append_fn(ref_index[node]) + return + elif node in new_ref_index: + append_fn(new_ref_index[node]) + return + index = new_ref_index[node] = ctx.next_index + ctx.next_index += 1 + else: + index = -1 values, metadata = node_impl.flatten(node) @@ -720,20 +732,26 @@ def _check_graph_fingerprint( ref_index: RefMap, new_ref_index: RefMap, ) -> bool: + is_pytree_node_ = type(node_impl) is PytreeNodeImpl + is_graph_node_ = type(node_impl) is GraphNodeImpl + if type(node) != next(fp_iterator): return False - # append_fn(id(node)) - if id(node) != next(fp_iterator): - return False - if node in ref_index: - # append_fn(ref_index[node]) - return ref_index[node] == next(fp_iterator) - elif node in new_ref_index: - # append_fn(new_ref_index[node]) - return new_ref_index[node] == next(fp_iterator) - index = new_ref_index[node] = ctx.next_index - ctx.next_index += 1 + if is_graph_node_: + # append_fn(id(node)) + if id(node) != next(fp_iterator): + return False + if node in ref_index: + # append_fn(ref_index[node]) + return ref_index[node] == next(fp_iterator) + elif node in new_ref_index: + # append_fn(new_ref_index[node]) + return new_ref_index[node] == next(fp_iterator) + index = new_ref_index[node] = ctx.next_index + ctx.next_index += 1 + else: + index = -1 values, metadata = node_impl.flatten(node) @@ -975,7 +993,6 @@ def _get_children() -> list[tuple[Key, tp.Any]]: # if the node type does not support the creation of an empty object it means # that it cannot reference itself, so we can create its children first node = node_impl.unflatten(_get_children(), nodedef.metadata) - index_ref[nodedef.index] = node return node diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index 8298d5a8..af50ad61 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -69,7 +69,7 @@ def test_flatten(self): assert flat_state[0][1].value == 2 assert flat_state[1][1].value == 4 - assert len(refmap) == 4 + assert len(refmap) == 2 assert a['b'] in refmap assert g[3] in refmap @@ -85,7 +85,7 @@ def test_flatten_no_paths(self): assert flat_state[0] == 2 assert flat_state[1] == 4 - assert len(refmap) == 4 + assert len(refmap) == 2 assert a['b'] in refmap assert g[3] in refmap @@ -116,7 +116,7 @@ def test_unflatten_pytree(self): graphdef, state = nnx.split(g) g = nnx.merge(graphdef, state) - assert g[0] is g[2] + assert g[0] is not g[2] def test_unflatten_empty(self): a = Dict({'a': 1, 'b': nnx.Param(2)})