Skip to content

[nnx] add support for standalone Variables #4606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 11, 2025

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Mar 7, 2025

What does this PR do?

Adds support for using Variables outside of graph nodes, meaning they can now be passed directly to transforms, including inside pytrees without a parent graph node, and the ability to call split, merge, and update on them directly.

rngs = nnx.Rngs(0)
w = nnx.Param(jax.random.normal(rngs(), (2, 3)))
b = nnx.Param(jnp.zeros((3,)))
count = nnx.Variable(jnp.array(0))

@nnx.jit
def linear(w, b, count, x):
  count += 1
  return x @ w + b[None]

x = jax.random.normal(rngs(), (1, 2))
y = linear(w, b, count, x)

assert count.value == 1
assert y.shape == (1, 3)

@cgarciae cgarciae force-pushed the nnx-standalone-variables branch 3 times, most recently from 21852a6 to fc95f37 Compare March 8, 2025 00:59
@cgarciae cgarciae marked this pull request as ready for review March 8, 2025 01:04
@cgarciae cgarciae force-pushed the nnx-standalone-variables branch 4 times, most recently from 0100223 to ff38c52 Compare March 10, 2025 05:57
@github-advanced-security
Copy link

This pull request sets up GitHub code scanning for this repository. Once the scans have completed and the checks have passed, the analysis results for this pull request branch will appear on this overview. Once you merge this pull request, the 'Security' tab will show more code scanning analysis results (for example, for the default branch). Depending on your configuration and choice of analysis tool, future pull requests will be annotated with code scanning analysis results. For more information about GitHub code scanning, check out the documentation.

@cgarciae cgarciae force-pushed the nnx-standalone-variables branch 3 times, most recently from 3f5354d to d9e59e0 Compare March 10, 2025 21:12
@@ -343,7 +345,9 @@ def from_tree(
with graph.merge_context(is_inner, ctxtag) as merge_ctx:
return jax.tree.map(
lambda x: merge_fn(merge_ctx, (), prefix, x)
if map_non_graph_nodes or is_node_leaf(x)
if map_non_graph_nodes
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This lambda becomes kinda long & unreadable. Can you refactor it out to a function? Then you can also use this for the merge_fn call in line 374.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great idea! done

@cgarciae cgarciae force-pushed the nnx-standalone-variables branch 2 times, most recently from ca2d03e to 25ad2b6 Compare March 10, 2025 23:41
@cgarciae cgarciae force-pushed the nnx-standalone-variables branch from 25ad2b6 to d4aa248 Compare March 10, 2025 23:48
@copybara-service copybara-service bot merged commit fab37ad into main Mar 11, 2025
20 checks passed
@copybara-service copybara-service bot deleted the nnx-standalone-variables branch March 11, 2025 03:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants