Skip to content

Commit

Permalink
Add pytrees as graph nodes again.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 734367472
  • Loading branch information
Cristian Garcia authored and Flax Authors committed Mar 7, 2025
1 parent e3789de commit 728556a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 45 deletions.
67 changes: 25 additions & 42 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,18 +526,11 @@ def _graph_flatten(
paths: list[PathParts] | None,
return_variables: bool,
) -> NodeDef[tp.Any] | NodeRef:
is_pytree_node_ = isinstance(node_impl, PytreeNodeImpl)
is_graph_node_ = isinstance(node_impl, GraphNodeImpl)

if not is_pytree_node_ and node in ref_index:
if node in ref_index:
return NodeRef(type(node), ref_index[node])

# only cache graph nodes
if is_graph_node_:
index = len(ref_index)
ref_index[node] = index
else:
index = -1
# assign index
ref_index[node] = index = len(ref_index)

attributes: list[
tuple[Key, Static[tp.Any] | NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any]]
Expand Down Expand Up @@ -603,7 +596,7 @@ def _graph_flatten(
type=node_impl.type, # type: ignore[arg-type]
index=index,
outer_index=ref_outer_index[node]
if is_graph_node_ and ref_outer_index and node in ref_outer_index
if ref_outer_index and node in ref_outer_index
else None,
attributes=tuple(attributes),
metadata=metadata,
Expand Down Expand Up @@ -646,23 +639,18 @@ def _graph_fingerprint(
ref_index: RefMap,
new_ref_index: RefMap,
):
is_pytree_node_ = type(node_impl) is PytreeNodeImpl
is_graph_node_ = type(node_impl) is GraphNodeImpl

append_fn(type(node))

if is_graph_node_:
append_fn(id(node))
if node in ref_index:
append_fn(ref_index[node])
return
elif node in new_ref_index:
append_fn(new_ref_index[node])
return
index = new_ref_index[node] = ctx.next_index
ctx.next_index += 1
else:
index = -1
append_fn(id(node))
if node in ref_index:
append_fn(ref_index[node])
return
elif node in new_ref_index:
append_fn(new_ref_index[node])
return
index = new_ref_index[node] = ctx.next_index
ctx.next_index += 1

values, metadata = node_impl.flatten(node)

Expand Down Expand Up @@ -732,26 +720,20 @@ def _check_graph_fingerprint(
ref_index: RefMap,
new_ref_index: RefMap,
) -> bool:
is_pytree_node_ = type(node_impl) is PytreeNodeImpl
is_graph_node_ = type(node_impl) is GraphNodeImpl

if type(node) != next(fp_iterator):
return False

if is_graph_node_:
# append_fn(id(node))
if id(node) != next(fp_iterator):
return False
if node in ref_index:
# append_fn(ref_index[node])
return ref_index[node] == next(fp_iterator)
elif node in new_ref_index:
# append_fn(new_ref_index[node])
return new_ref_index[node] == next(fp_iterator)
index = new_ref_index[node] = ctx.next_index
ctx.next_index += 1
else:
index = -1
# append_fn(id(node))
if id(node) != next(fp_iterator):
return False
if node in ref_index:
# append_fn(ref_index[node])
return ref_index[node] == next(fp_iterator)
elif node in new_ref_index:
# append_fn(new_ref_index[node])
return new_ref_index[node] == next(fp_iterator)
index = new_ref_index[node] = ctx.next_index
ctx.next_index += 1

values, metadata = node_impl.flatten(node)

Expand Down Expand Up @@ -993,6 +975,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
# if the node type does not support the creation of an empty object it means
# that it cannot reference itself, so we can create its children first
node = node_impl.unflatten(_get_children(), nodedef.metadata)
index_ref[nodedef.index] = node

return node

Expand Down
6 changes: 3 additions & 3 deletions tests/nnx/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_flatten(self):
assert flat_state[0][1].value == 2
assert flat_state[1][1].value == 4

assert len(refmap) == 2
assert len(refmap) == 4
assert a['b'] in refmap
assert g[3] in refmap

Expand All @@ -85,7 +85,7 @@ def test_flatten_no_paths(self):
assert flat_state[0] == 2
assert flat_state[1] == 4

assert len(refmap) == 2
assert len(refmap) == 4
assert a['b'] in refmap
assert g[3] in refmap

Expand Down Expand Up @@ -116,7 +116,7 @@ def test_unflatten_pytree(self):
graphdef, state = nnx.split(g)
g = nnx.merge(graphdef, state)

assert g[0] is not g[2]
assert g[0] is g[2]

def test_unflatten_empty(self):
a = Dict({'a': 1, 'b': nnx.Param(2)})
Expand Down

0 comments on commit 728556a

Please sign in to comment.