diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 6fb73ed4..e8930a78 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -526,18 +526,11 @@ def _graph_flatten( paths: list[PathParts] | None, return_variables: bool, ) -> NodeDef[tp.Any] | NodeRef: - is_pytree_node_ = isinstance(node_impl, PytreeNodeImpl) - is_graph_node_ = isinstance(node_impl, GraphNodeImpl) - - if not is_pytree_node_ and node in ref_index: + if node in ref_index: return NodeRef(type(node), ref_index[node]) - # only cache graph nodes - if is_graph_node_: - index = len(ref_index) - ref_index[node] = index - else: - index = -1 + # assign index + ref_index[node] = index = len(ref_index) attributes: list[ tuple[Key, Static[tp.Any] | NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any]] @@ -603,7 +596,7 @@ def _graph_flatten( type=node_impl.type, # type: ignore[arg-type] index=index, outer_index=ref_outer_index[node] - if is_graph_node_ and ref_outer_index and node in ref_outer_index + if ref_outer_index and node in ref_outer_index else None, attributes=tuple(attributes), metadata=metadata, @@ -646,23 +639,18 @@ def _graph_fingerprint( ref_index: RefMap, new_ref_index: RefMap, ): - is_pytree_node_ = type(node_impl) is PytreeNodeImpl - is_graph_node_ = type(node_impl) is GraphNodeImpl append_fn(type(node)) - if is_graph_node_: - append_fn(id(node)) - if node in ref_index: - append_fn(ref_index[node]) - return - elif node in new_ref_index: - append_fn(new_ref_index[node]) - return - index = new_ref_index[node] = ctx.next_index - ctx.next_index += 1 - else: - index = -1 + append_fn(id(node)) + if node in ref_index: + append_fn(ref_index[node]) + return + elif node in new_ref_index: + append_fn(new_ref_index[node]) + return + index = new_ref_index[node] = ctx.next_index + ctx.next_index += 1 values, metadata = node_impl.flatten(node) @@ -732,26 +720,20 @@ def _check_graph_fingerprint( ref_index: RefMap, new_ref_index: RefMap, ) -> bool: - is_pytree_node_ = type(node_impl) is PytreeNodeImpl - is_graph_node_ = type(node_impl) is GraphNodeImpl - if type(node) != next(fp_iterator): return False - if is_graph_node_: - # append_fn(id(node)) - if id(node) != next(fp_iterator): - return False - if node in ref_index: - # append_fn(ref_index[node]) - return ref_index[node] == next(fp_iterator) - elif node in new_ref_index: - # append_fn(new_ref_index[node]) - return new_ref_index[node] == next(fp_iterator) - index = new_ref_index[node] = ctx.next_index - ctx.next_index += 1 - else: - index = -1 + # append_fn(id(node)) + if id(node) != next(fp_iterator): + return False + if node in ref_index: + # append_fn(ref_index[node]) + return ref_index[node] == next(fp_iterator) + elif node in new_ref_index: + # append_fn(new_ref_index[node]) + return new_ref_index[node] == next(fp_iterator) + index = new_ref_index[node] = ctx.next_index + ctx.next_index += 1 values, metadata = node_impl.flatten(node) @@ -993,6 +975,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]: # if the node type does not support the creation of an empty object it means # that it cannot reference itself, so we can create its children first node = node_impl.unflatten(_get_children(), nodedef.metadata) + index_ref[nodedef.index] = node return node diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index af50ad61..8298d5a8 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -69,7 +69,7 @@ def test_flatten(self): assert flat_state[0][1].value == 2 assert flat_state[1][1].value == 4 - assert len(refmap) == 2 + assert len(refmap) == 4 assert a['b'] in refmap assert g[3] in refmap @@ -85,7 +85,7 @@ def test_flatten_no_paths(self): assert flat_state[0] == 2 assert flat_state[1] == 4 - assert len(refmap) == 2 + assert len(refmap) == 4 assert a['b'] in refmap assert g[3] in refmap @@ -116,7 +116,7 @@ def test_unflatten_pytree(self): graphdef, state = nnx.split(g) g = nnx.merge(graphdef, state) - assert g[0] is not g[2] + assert g[0] is g[2] def test_unflatten_empty(self): a = Dict({'a': 1, 'b': nnx.Param(2)})