Error: vmapped nnx.Module initialization with selective Variable broadcasting #4526

Closed. Answered by cgarciae.
rademacher-p asked this question in Q&A

Hey @rademacher-p, you need two different RNG keys: one that initializes the Params as broadcast (the same value for every vmapped copy) and one that initializes the MyVar variables vectorized (a different value per copy). The easiest way to do this is to use named streams in Rngs: create a params stream and a vars stream, tell split_rngs to split only the vars keys, and sample keys by stream name. Here's a working sample:

import jax
import jax.numpy as jnp
from flax import nnx

class MyVar(nnx.Variable):
  pass

state_axes = nnx.StateAxes({(MyVar, 'vars'): 0, ...: None})

class MyModule(nnx.Module):
  @nnx.split_rngs(splits=2, only='vars')
  @nnx.vmap(in_axes=(state_axes, state_axes))
  def __init__(self, rngs):
    self.param = nnx.Param(jax.random.uniform(rngs.

Replies: 2 comments, 4 replies (from @rademacher-p and @cgarciae)

Answer selected by rademacher-p
2 participants