Here's a strategy for defining per-parameter learning rates using `optax.multi_transform` and leveraging `Variable` metadata.
```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(
    2,
    5,
    kernel_init=nnx.with_metadata(
        nnx.initializers.lecun_normal(),
        learning_rate=0.2,  # add learning rate metadata
    ),
    bias_init=nnx.with_metadata(
        nnx.initializers.zeros,
        learning_rate=0.1,  # add learning rate metadata
    ),
    rngs=nnx.Rngs(42),
)

state = nnx.state(model, nnx.Param)
optimizers = {}

def leaf_optimizer(path, value):
    # here we both populate `optimizers` and return the corresponding tag,
    # which is the learning rate
    optimizers.setdefault(value.learning_rate, optax.adamw(value.learning_rate))
    return value.learning_rate

state_optimizer = nnx.map_state(leaf_optimizer, state)

optimizer = nnx.Optimizer(
    model,
    tx=optax.multi_transform(optimizers, state_optimizer),
    wrt=nnx.Param,
)

@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        predictions = model(x)
        loss = jnp.mean((predictions - y) ** 2)
        return loss

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return loss

x = jax.random.normal(jax.random.key(42), (32, 2))
y = jax.random.normal(jax.random.key(43), (32, 5))

losses = []
for _ in range(50):
    loss = train_step(model, optimizer, x, y)
    losses.append(loss)
```
Using `nnx.with_metadata` is not necessary; `leaf_optimizer` could also just decide the learning rate for each value on its own. If you control the creation of the Variables, it's easier to just assign the metadata directly, e.g.: