Util to let bridge module work with NNX submodules #4584

Merged
merged 1 commit into from
Mar 6, 2025

@IvyZX IvyZX commented Feb 27, 2025

  • Add current_module to track the current parent
  • Add wrap_nnx_module to automate RNG distribution into pure-NNX territory, and to avoid buffer allocation

Comment on lines 324 to 325
layer = nnx.bridge.wrap_nnx_module(lambda r: NNXLayer(8, 0.3, rngs=r))
setattr(self, f'layer{i}', layer)
Collaborator

I think we can do a bit better and try to mimic compact, with auto-naming or an optional name parameter, like this:

Suggested change
- layer = nnx.bridge.wrap_nnx_module(lambda r: NNXLayer(8, 0.3, rngs=r))
- setattr(self, f'layer{i}', layer)
+ layer = nnx.bridge.compact_init(NNXLayer, 8, 0.3, rngs=self.scope.rngs, name=f'layer{i}')

@IvyZX IvyZX force-pushed the linx-misc branch 5 times, most recently from b174427 to 71eba05 Compare March 5, 2025 02:21
if parent.scope.rngs:
  for _, stream in graph.iter_graph(module):
    if isinstance(stream, rnglib.RngStream):
      stream.key.value = rngs[stream.key.tag].key.value
Collaborator

This will break Modules that have keys and counts with additional dimensions e.g.

class FooStack(nnx.Module):
  def __init__(self, n_layers, key):
    keys = jax.random.split(key, n_layers)
    self.rngs = nnx.Rngs(keys)

Looks extreme but users could actually construct something equivalent via vmap.
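
To make the shape concern concrete, here is a small sketch in plain JAX (not code from this PR; the variable names are illustrative). A key produced by `jax.random.split` carries a leading dimension, so overwriting it with a scalar key would drop the per-layer axis that a vmapped consumer relies on.

```python
import jax

n_layers = 4
key = jax.random.key(0)                  # scalar typed key, shape ()
keys = jax.random.split(key, n_layers)   # stacked keys, shape (n_layers,)

# A vmapped consumer expects one key per layer along the leading axis.
noise = jax.vmap(lambda k: jax.random.normal(k, (2,)))(keys)

print(key.shape, keys.shape, noise.shape)
```

Replacing `keys` with `key` here would make the `vmap` call fail, because there would be no leading axis left to map over.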

Collaborator

Maybe the policy could be that we only take RNG state from apply, and try to replace it here if the keys and counts are scalars.
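
A rough sketch of what such a scalar-only policy might look like, in plain JAX. The helper `replace_key_if_scalar` is hypothetical and not part of Flax or this PR; the same shape check could presumably be applied to counts.

```python
import jax

def replace_key_if_scalar(old_key, new_key):
  """Return new_key only when old_key is an unbatched (scalar) key."""
  if old_key.shape == ():
    return new_key
  # Stacked keys (e.g. built via split or vmap) are left untouched.
  return old_key

scalar = jax.random.key(0)
stacked = jax.random.split(jax.random.key(1), 3)
fresh = jax.random.key(42)

replaced = replace_key_if_scalar(scalar, fresh)   # swapped for `fresh`
kept = replace_key_if_scalar(stacked, fresh)      # still shape (3,)
```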

Collaborator Author

This is indeed tricky. We supply non-abstract RNG state here essentially for any code in the nnx.Module that assumes a complete nnx.Rngs instance and calls its methods at call time.
I think the cleanest solution is to use nnx.jit to partially init the NNX module and keep only its RngState. Updated the PR and added a test for it.

@copybara-service copybara-service bot merged commit 8254dd0 into google:main Mar 6, 2025
16 checks passed