How can I make non-trainable variable using nnx.Module? #4533
Unanswered
SangminLee0828 asked this question in Q&A
Replies: 2 comments 3 replies
-
My quick take is: by assigning those parameters to self, they are found by jax.grad as parameters to differentiate with respect to (although I am surprised, because nnx.Param exists for a good reason). However, if you want something purely static, maybe you don't need nnx.Module, or a class at all (if you just want to compute the variance and mean on the fly)?
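(As a sketch of that default behaviour, assuming flax.nnx: nnx.grad differentiates only nnx.Param leaves unless told otherwise, so statistics stored in another Variable collection such as nnx.BatchStat are left alone even though they live on self. The Scale module below is a made-up example, not library code.)

import jax.numpy as jnp
from flax import nnx

class Scale(nnx.Module):
    def __init__(self):
        self.w = nnx.Param(jnp.ones((3,)))           # trainable parameter
        self.shift = nnx.BatchStat(jnp.zeros((3,)))  # non-trainable statistic

    def __call__(self, x):
        return self.w.value * x + self.shift.value

model = Scale()
loss_fn = lambda m: (m(jnp.ones((3,))) ** 2).sum()
grads = nnx.grad(loss_fn)(model)  # contains only the nnx.Param leaf ('w'), not 'shift'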
-
The solution is to create a filter for the trainable Variables and pass it to both the nnx.Optimizer (via wrt) and nnx.grad (via nnx.DiffState):

import jax.numpy as jnp
import optax
from flax import nnx

class Classifier(nnx.Module):
    def __init__(self, embed_dim, num_classes, backbone, rngs):
        self.backbone = backbone
        self.head = nnx.Linear(embed_dim, num_classes, rngs=rngs)

    def __call__(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x

def load_model():
    return nnx.Linear(784, 1024, rngs=nnx.Rngs(0))

backbone = load_model()
classifier = Classifier(1024, 10, backbone, rngs=nnx.Rngs(1))

# filter to select only Params on the 'head' path
head_params = nnx.All(nnx.Param, nnx.PathContains('head'))

optimizer = nnx.Optimizer(
    classifier,
    tx=optax.adamw(3e-4),
    wrt=head_params,  # only optimize the head params
)

# simple train step
@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        logits = model(x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    diff_state = nnx.DiffState(0, head_params)  # differentiate only the head params of the first argument
    grads = nnx.grad(loss_fn, argnums=diff_state)(model)
    optimizer.update(grads)

x = jnp.ones((1, 784))
y = jnp.ones((1,), jnp.int32)
train_step(classifier, optimizer, x, y)
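(If it helps to double-check, nnx.state accepts the same filters, so you can inspect exactly which variables head_params selects before training; a small sketch using the objects defined above.)

# Only the head's kernel and bias should show up here:
print(nnx.state(classifier, head_params))

# The backbone params stay out of both the optimizer and the gradients:
print(nnx.state(classifier, nnx.All(nnx.Param, nnx.PathContains('backbone'))))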
-
Hi,
I am trying to create a normalization layer. This layer stores 'mean' and 'variance' internally, so when values come in, the output is the normalized value computed from the stored mean and variance.
https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization
How can I make 'self.mean' and 'self.variance' non-trainable?
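Roughly, I want something like the sketch below (hypothetical names; right now I only know nnx.Param, which makes them trainable):

import jax.numpy as jnp
from flax import nnx

class Normalization(nnx.Module):
    def __init__(self, num_features: int):
        # Currently stored as Params, but these should not be trained:
        self.mean = nnx.Param(jnp.zeros((num_features,)))
        self.variance = nnx.Param(jnp.ones((num_features,)))

    def __call__(self, x):
        return (x - self.mean.value) / jnp.sqrt(self.variance.value + 1e-6)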