Util to let bridge module work with NNX submodules #4584
tests/nnx/bridge/module_test.py
Outdated
layer = nnx.bridge.wrap_nnx_module(lambda r: NNXLayer(8, 0.3, rngs=r))
setattr(self, f'layer{i}', layer)
I think we can do a bit better and try to mimic compact, with auto-naming or an optional name parameter, like this:
- layer = nnx.bridge.wrap_nnx_module(lambda r: NNXLayer(8, 0.3, rngs=r))
- setattr(self, f'layer{i}', layer)
+ layer = nnx.bridge.compact_init(NNXLayer, 8, 0.3, rngs=self.scope.rngs, name=f'layer{i}')
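For illustration only, the auto-naming behaviour could look roughly like the plain-Python sketch below. It assumes the Linen-style `ClassName_N` scheme; `AutoNamer`, `register`, and the empty `NNXLayer` placeholder are hypothetical names, not part of the bridge API or of this PR.

```python
from collections import defaultdict


class AutoNamer:
  """Hypothetical stand-in for a parent module's naming scope."""

  def __init__(self):
    self._counts = defaultdict(int)  # per-class counter for auto names
    self.children = {}               # name -> registered child

  def register(self, child, name=None):
    # With no explicit name, derive one as ClassName_N (assumed scheme).
    if name is None:
      prefix = type(child).__name__
      name = f'{prefix}_{self._counts[prefix]}'
      self._counts[prefix] += 1
    if name in self.children:
      raise ValueError(f'Duplicate submodule name: {name!r}')
    self.children[name] = child
    return name


class NNXLayer:
  """Placeholder for the real NNX layer used in the test."""


parent = AutoNamer()
auto_a = parent.register(NNXLayer())                    # auto-named
auto_b = parent.register(NNXLayer())                    # auto-named
explicit = parent.register(NNXLayer(), name='layer0')   # explicit name wins
```

With something like this, the explicit `setattr(self, f'layer{i}', layer)` in the test would no longer be needed, since the parent records the child under the chosen name.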
b174427 to 71eba05
flax/nnx/bridge/module.py
Outdated
if parent.scope.rngs:
  for _, stream in graph.iter_graph(module):
    if isinstance(stream, rnglib.RngStream):
      stream.key.value = rngs[stream.key.tag].key.value
This will break Modules that have keys and counts with additional dimensions, e.g.:
class FooStack(nnx.Module):
  def __init__(self, n_layers, key):
    keys = jax.random.split(key, n_layers)
    self.rngs = nnx.Rngs(keys)
Looks extreme, but users could actually construct something equivalent via vmap.
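To make the concern concrete, here is a minimal JAX-only sketch (not taken from the PR) of how key arrays pick up a leading dimension through split and vmap. The helper name `derive_layer_key` and the constant it folds in are hypothetical.

```python
import jax

base_key = jax.random.key(0)               # typed PRNG key, shape ()
layer_keys = jax.random.split(base_key, 4)  # one key per layer, shape (4,)


def derive_layer_key(k):
  # Hypothetical per-layer derivation; any per-key function would do.
  return jax.random.fold_in(k, 1)


# vmap maps the function over the leading axis, so the result is batched too.
batched_keys = jax.vmap(derive_layer_key)(layer_keys)
```

A stream holding `batched_keys` has shape `(4,)`, so assigning a scalar parent key to it would change the shape of its state.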
Maybe the policy could be that we only take RNG state from apply, and try to replace it here if the keys and counts are scalars.
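A rough sketch of that scalar-only policy, written against plain JAX arrays rather than the real RngStream objects. `maybe_replace_key` is a hypothetical helper, and the `uint32` count dtype is an assumption made for the example.

```python
import jax
import jax.numpy as jnp


def maybe_replace_key(module_key, module_count, parent_key):
  """Return parent_key only when the module's RNG state is scalar.

  Hypothetical helper: batched keys or counts (for example from vmap)
  are left untouched so their shapes are preserved.
  """
  if module_key.shape == () and module_count.shape == ():
    return parent_key
  return module_key


parent_key = jax.random.key(42)

# Scalar state: safe to replace with the parent's key.
scalar_key = jax.random.key(0)
scalar_count = jnp.zeros((), dtype=jnp.uint32)
replaced = maybe_replace_key(scalar_key, scalar_count, parent_key)

# Batched state: keep the module's own keys.
stacked_keys = jax.random.split(jax.random.key(0), 3)
stacked_counts = jnp.zeros((3,), dtype=jnp.uint32)
kept = maybe_replace_key(stacked_keys, stacked_counts, parent_key)
```

Checking shapes rather than tags keeps the guard independent of how the batched state was produced.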
This is indeed tricky. We supply non-abstract RNG state here essentially for any code in the nnx.Module that assumes a complete nnx.Rngs instance and calls its methods at call time.
I think the cleanest solution is to use nnx.jit to partially init the NNX module and keep only the RngState. Updated the PR and added a test for it.
current_module to track the current parent
wrap_nnx_module to automate RNG distribution into pure-NNX territory, and to avoid buffer allocation