All update methods can receive a params argument. It is None by default.
Most update methods can receive a value argument, but not all:
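import optax
from jax import numpy as jnp

params = jnp.zeros(5)
grads = jnp.ones_like(params)
value = 0.0  # dummy value passed as the `value` keyword argument

for opt in [
    optax.sgd(1e-3),
    optax.adam(1e-3),
    optax.adabelief(1e-3),
    optax.contrib.dadapt_adamw(),
    optax.contrib.prodigy(),
]:
    opt_state = opt.init(params)
    try:
        # Call every optimizer the same way, including the `value` keyword.
        opt.update(grads, opt_state, params, value=value)
    except Exception as e:
        # Print the error for optimizers that reject `value`.
        print(e)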
Output:
dadapt_adamw.<locals>.update_fn() got an unexpected keyword argument 'value'
prodigy.<locals>.update_fn() got an unexpected keyword argument 'value'
This is inconvenient because the interface is not uniform: update has to be called differently depending on the optimizer, which makes code that switches between optimizers more complex.
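Until every optimizer accepts value, one way to keep the calling code uniform is a small adapter on the caller side. The following is a minimal sketch, assuming a hypothetical helper call_update (not part of optax) that inspects the update function's signature with Python's inspect module and forwards value only when the function accepts it. The two update functions are toy pure-Python stand-ins, not the real optax ones.

```python
import inspect
from typing import Any, Callable, Optional


def call_update(update_fn: Callable[..., Any], updates: Any, state: Any,
                params: Optional[Any] = None, value: Optional[Any] = None) -> Any:
    """Hypothetical adapter: pass `value` only if `update_fn` can accept it."""
    parameters = inspect.signature(update_fn).parameters
    accepts_value = "value" in parameters or any(
        p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values()
    )
    if accepts_value and value is not None:
        return update_fn(updates, state, params, value=value)
    # Fall back to the three-argument call that every update function supports.
    return update_fn(updates, state, params)


# Toy stand-in for an update function that tolerates extra keyword arguments.
def update_with_extra_args(updates, state, params=None, **extra_args):
    return [-0.1 * g for g in updates], state


# Toy stand-in for an update function that rejects `value`.
def update_without_value(updates, state, params=None):
    return [-0.1 * g for g in updates], state


grads = [1.0, 2.0]
for fn in (update_with_extra_args, update_without_value):
    new_updates, _ = call_update(fn, grads, state=None, value=0.0)
    print(fn.__name__, new_updates)
```

The drawback is that every call site pays for a signature check and silently drops value for optimizers that do not take it, so fixing the signatures themselves is the cleaner option.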
Feature request: Allow all update methods to receive a value argument. It can be None by default.
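A minimal sketch of the requested signature, using a hypothetical pure-Python toy optimizer (toy_sgd and ToyState are illustrative names, not optax code). Its update_fn takes value with a default of None and ignores it, so callers may pass it or leave it out.

```python
from typing import List, NamedTuple


class ToyState(NamedTuple):
    count: int  # number of update steps taken so far


def toy_sgd(learning_rate: float):
    """Hypothetical toy optimizer whose update accepts an optional `value`."""

    def init_fn(params: List[float]) -> ToyState:
        return ToyState(count=0)

    def update_fn(updates: List[float], state: ToyState, params=None, value=None):
        # `params` and `value` default to None; this toy rule uses neither,
        # but accepting `value` keeps the call signature uniform.
        del params, value
        new_updates = [-learning_rate * g for g in updates]
        return new_updates, ToyState(count=state.count + 1)

    return init_fn, update_fn


init_fn, update_fn = toy_sgd(0.5)
state = init_fn([0.0, 0.0])
# Both call styles work with the same optimizer.
u1, state = update_fn([1.0, 2.0], state, [0.0, 0.0])
u2, state = update_fn([1.0, 2.0], state, [0.0, 0.0], value=0.0)
print(u1, u2, state.count)
```

For an optimizer that does not use the loss, adding an ignored value=None parameter like this is one straightforward way to meet the request.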
I can submit a PR editing dadapt_adamw and prodigy accordingly.