
Allow all optimizer update methods to receive an optional value argument #1131

Open · carlosgmartin opened this issue Nov 9, 2024 · 3 comments

carlosgmartin (Contributor) commented Nov 9, 2024

All update methods accept an optional params argument, which is None by default.

Most update methods also accept a value argument, but not all:

import optax
from jax import numpy as jnp

params = jnp.zeros(5)
grads = jnp.ones_like(params)
value = 0.0

# Try passing `value` to each optimizer's update; the contrib optimizers below reject it.
for opt in [
    optax.sgd(1e-3),
    optax.adam(1e-3),
    optax.adabelief(1e-3),
    optax.contrib.dadapt_adamw(),
    optax.contrib.prodigy(),
]:
    opt_state = opt.init(params)
    try:
        opt.update(grads, opt_state, params, value=value)
    except Exception as e:
        print(e)

Output:

dadapt_adamw.<locals>.update_fn() got an unexpected keyword argument 'value'
prodigy.<locals>.update_fn() got an unexpected keyword argument 'value'

This is inconvenient: the interface is non-uniform, so update has to be called differently depending on the optimizer, which complicates code.

Feature request: allow all update methods to receive a value argument, defaulting to None.

I can submit a PR editing dadapt_adamw and prodigy accordingly.
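
For illustration, here is a minimal sketch of the pattern such a change would follow, using a hypothetical toy transformation (scale_by_constant is made up; the real dadapt_adamw and prodigy internals differ):

import jax
import optax

def scale_by_constant(factor):
    """Toy transformation whose update_fn tolerates an optional value argument."""

    def init_fn(params):
        del params
        return optax.EmptyState()

    def update_fn(updates, state, params=None, value=None, **extra_args):
        # `value` and any other extra arguments are accepted but unused,
        # so callers can pass them uniformly across optimizers.
        del params, value, extra_args
        updates = jax.tree_util.tree_map(lambda g: factor * g, updates)
        return updates, state

    return optax.GradientTransformationExtraArgs(init_fn, update_fn)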

vroulet (Collaborator) commented Nov 11, 2024

Hello @carlosgmartin,

You can simply make them support extra_args using with_extra_args_support:

import optax
from jax import numpy as jnp

params = jnp.zeros(5)
grads = jnp.ones_like(params)
value = 0.0

for opt in [
    optax.sgd(1e-3),
    optax.adam(1e-3),
    optax.adabelief(1e-3),
    optax.contrib.dadapt_adamw(),
    optax.contrib.prodigy(),
]:
    opt = optax.with_extra_args_support(opt)  # update now accepts extra keyword arguments
    opt_state = opt.init(params)
    opt.update(grads, opt_state, params, value=value)
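
Roughly speaking, the wrapper turns a plain GradientTransformation into a GradientTransformationExtraArgs whose update accepts and ignores extra keyword arguments. A simplified sketch of that idea (with_extra_args_support_sketch is hypothetical, not optax's actual implementation):

import optax

def with_extra_args_support_sketch(tx: optax.GradientTransformation):
    def update_fn(updates, state, params=None, **extra_args):
        del extra_args  # silently dropped by transformations that do not use them
        return tx.update(updates, state, params)
    return optax.GradientTransformationExtraArgs(tx.init, update_fn)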

carlosgmartin (Contributor, Author) commented:

@vroulet Thanks! Just out of curiosity, from a design POV, what's the reason for having the with_extra_args_support wrapper, rather than just letting all optimizers receive extra args by default? That would eliminate the need to have a GradientTransformationExtraArgs separate from GradientTransformation.

vroulet (Collaborator) commented Nov 12, 2024

I believe it was for backward compatibility.
I fully agree that, ideally, the gradient transformation API should be

def init(grads_like, **extra_args):
  ...
def update(grads, state, **extra_args):
  ...

Unfortunately, I don't know whether a revamp of the API is possible at this stage.
