Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pytrees as graph nodes again. #4605

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 25 additions & 42 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,18 +526,11 @@ def _graph_flatten(
paths: list[PathParts] | None,
return_variables: bool,
) -> NodeDef[tp.Any] | NodeRef:
is_pytree_node_ = isinstance(node_impl, PytreeNodeImpl)
is_graph_node_ = isinstance(node_impl, GraphNodeImpl)

if not is_pytree_node_ and node in ref_index:
if node in ref_index:
return NodeRef(type(node), ref_index[node])

# only cache graph nodes
if is_graph_node_:
index = len(ref_index)
ref_index[node] = index
else:
index = -1
# assign index
ref_index[node] = index = len(ref_index)

attributes: list[
tuple[Key, Static[tp.Any] | NodeDef[tp.Any] | VariableDef | NodeRef[tp.Any]]
Expand Down Expand Up @@ -603,7 +596,7 @@ def _graph_flatten(
type=node_impl.type, # type: ignore[arg-type]
index=index,
outer_index=ref_outer_index[node]
if is_graph_node_ and ref_outer_index and node in ref_outer_index
if ref_outer_index and node in ref_outer_index
else None,
attributes=tuple(attributes),
metadata=metadata,
Expand Down Expand Up @@ -646,23 +639,18 @@ def _graph_fingerprint(
ref_index: RefMap,
new_ref_index: RefMap,
):
is_pytree_node_ = type(node_impl) is PytreeNodeImpl
is_graph_node_ = type(node_impl) is GraphNodeImpl

append_fn(type(node))

if is_graph_node_:
append_fn(id(node))
if node in ref_index:
append_fn(ref_index[node])
return
elif node in new_ref_index:
append_fn(new_ref_index[node])
return
index = new_ref_index[node] = ctx.next_index
ctx.next_index += 1
else:
index = -1
append_fn(id(node))
if node in ref_index:
append_fn(ref_index[node])
return
elif node in new_ref_index:
append_fn(new_ref_index[node])
return
index = new_ref_index[node] = ctx.next_index
ctx.next_index += 1

values, metadata = node_impl.flatten(node)

Expand Down Expand Up @@ -732,26 +720,20 @@ def _check_graph_fingerprint(
ref_index: RefMap,
new_ref_index: RefMap,
) -> bool:
is_pytree_node_ = type(node_impl) is PytreeNodeImpl
is_graph_node_ = type(node_impl) is GraphNodeImpl

if type(node) != next(fp_iterator):
return False

if is_graph_node_:
# append_fn(id(node))
if id(node) != next(fp_iterator):
return False
if node in ref_index:
# append_fn(ref_index[node])
return ref_index[node] == next(fp_iterator)
elif node in new_ref_index:
# append_fn(new_ref_index[node])
return new_ref_index[node] == next(fp_iterator)
index = new_ref_index[node] = ctx.next_index
ctx.next_index += 1
else:
index = -1
# append_fn(id(node))
if id(node) != next(fp_iterator):
return False
if node in ref_index:
# append_fn(ref_index[node])
return ref_index[node] == next(fp_iterator)
elif node in new_ref_index:
# append_fn(new_ref_index[node])
return new_ref_index[node] == next(fp_iterator)
index = new_ref_index[node] = ctx.next_index
ctx.next_index += 1

values, metadata = node_impl.flatten(node)

Expand Down Expand Up @@ -993,6 +975,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
# if the node type does not support the creation of an empty object it means
# that it cannot reference itself, so we can create its children first
node = node_impl.unflatten(_get_children(), nodedef.metadata)
index_ref[nodedef.index] = node

return node

Expand Down
6 changes: 3 additions & 3 deletions tests/nnx/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_flatten(self):
assert flat_state[0][1].value == 2
assert flat_state[1][1].value == 4

assert len(refmap) == 2
assert len(refmap) == 4
assert a['b'] in refmap
assert g[3] in refmap

Expand All @@ -85,7 +85,7 @@ def test_flatten_no_paths(self):
assert flat_state[0] == 2
assert flat_state[1] == 4

assert len(refmap) == 2
assert len(refmap) == 4
assert a['b'] in refmap
assert g[3] in refmap

Expand Down Expand Up @@ -116,7 +116,7 @@ def test_unflatten_pytree(self):
graphdef, state = nnx.split(g)
g = nnx.merge(graphdef, state)

assert g[0] is not g[2]
assert g[0] is g[2]

def test_unflatten_empty(self):
a = Dict({'a': 1, 'b': nnx.Param(2)})
Expand Down
Loading