Error: vmapped nnx.Module initialization with selective Variable broadcasting
#4526
I'm trying to use … . The two commented-out lines are modifications that succeed, but they are not quite what I want.
```python
import jax
import jax.numpy as jnp
from flax import nnx


class MyVar(nnx.Variable):
    pass


# RNG state and MyVar are vectorized along axis 0; everything else (including Param) is broadcast.
state_axes = nnx.StateAxes({(nnx.RngState, MyVar): 0, ...: None})
# Alternative that succeeds: Param is also vectorized along axis 0.
# state_axes = nnx.StateAxes({(nnx.RngState, nnx.Param, MyVar): 0, ...: None})


class MyModule(nnx.Module):
    @nnx.split_rngs(splits=2)
    @nnx.vmap(in_axes=(state_axes, 0))
    def __init__(self, rngs):
        self.param = nnx.Param(jax.random.uniform(rngs()))
        # Alternative that succeeds: Param is initialized from a constant instead of rngs().
        # self.param = nnx.Param(jnp.float32(0.0))
        self.var = MyVar(jax.random.uniform(rngs()))


rngs = nnx.Rngs(123)
model = MyModule(rngs)
```

The error is …

Any suggestions, anyone?

@cgarciae Shameless ping. Seems like you ~= NNX 😄
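A hedged, pure-JAX analogy of the conflict in the question (an assumption about the mechanism, not NNX's actual code path): with `split_rngs(splits=2)` and no `only=`, each key from `rngs()` differs along the vmapped axis, so a value drawn from it is batched and cannot also be returned unbatched. In the sketch below, `out_axes=None` stands in for broadcasting `Param` and axis 0 for vectorizing `MyVar`; the function `init` is hypothetical.

```python
import jax


def init(key):
    # Both values come from a per-lane (split) key, like rngs() after split_rngs(splits=2).
    param = jax.random.uniform(key)
    var = jax.random.uniform(jax.random.fold_in(key, 1))
    return param, var


keys = jax.random.split(jax.random.key(123), 2)

try:
    # Request `param` unbatched (out_axes=None) and `var` stacked along axis 0.
    jax.vmap(init, out_axes=(None, 0))(keys)
    failed = False
except ValueError as err:
    # jax.vmap rejects out_axes=None for an output that depends on a mapped input.
    failed = True
    print(err)
```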
I'm guessing I don't need that …
Hey @rademacher-p, you need two different RNG keys: one that initializes `Param`s broadcast, and one that initializes `MyVar`s vectorized. The easiest way to do this is to use named streams in `Rngs`: create a `params` stream and a `vars` stream, tell `split_rngs` to only split the `vars` keys, and use the stream names to sample keys. Here's a running sample:

```python
import jax
import jax.numpy as jnp
from flax import nnx


class MyVar(nnx.Variable):
    pass


# MyVar and the 'vars' RNG stream are vectorized along axis 0;
# everything else (Param, the 'params' stream) is broadcast.
state_axes = nnx.StateAxes({(MyVar, 'vars'): 0, ...: None})


class MyModule(nnx.Module):
    @nnx.split_rngs(splits=2, only='vars')  # split only the 'vars' stream
    @nnx.vmap(in_axes=(state_axes, state_axes))
    def __init__(self, rngs):
        # Drawn from the unsplit 'params' stream, so it is the same across the vmapped axis.
        self.param = nnx.Param(jax.random.uniform(rngs.params()))
        # self.param = nnx.Param(jnp.float32(0.0))
        # Drawn from the split 'vars' stream, so each vmapped instance gets its own value.
        self.var = MyVar(jax.random.uniform(rngs.vars()))


rngs = nnx.Rngs(params=1, vars=2)
model = MyModule(rngs)
```

Also updated …
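The same two-stream idea in plain JAX, as a minimal sketch (not NNX's implementation; the `init` function and the seeds are hypothetical). An unsplit key plays the role of the `params` stream and is passed with `in_axes=None`. A key split in two plays the role of the `vars` stream and is mapped over axis 0. Because `param` depends only on the unmapped key, `jax.vmap` can return it unbatched with `out_axes=None`, while `var` comes back stacked.

```python
import jax


def init(param_key, var_key):
    param = jax.random.uniform(param_key)  # same key in every lane, so the result is unbatched
    var = jax.random.uniform(var_key)      # per-lane key, so the result is batched
    return param, var


param_key = jax.random.key(1)                      # like the unsplit 'params' stream
var_keys = jax.random.split(jax.random.key(2), 2)  # like the 'vars' stream split into 2

param, var = jax.vmap(init, in_axes=(None, 0), out_axes=(None, 0))(param_key, var_keys)
print(param.shape, var.shape)  # () (2,)
```

If `param` were instead drawn from one of the split keys, the `out_axes=None` request would fail, which is likely the same reason `split_rngs` needs `only='vars'` here.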