Improve support for optax.GradientTransformationExtraArgs in NNX #4545
A related issue I have had is with the functional training loop pattern described in the performance guide (https://flax.readthedocs.io/en/latest/guides/performance.html#functional-training-loop). The coding pattern there makes things harder because it splits the model and works with the functional state directly, and I am not sure how to structure my code to use that faster pattern and also support linesearch. @cgarciae is it one or the other? If the new cache_args is generally preferred, then maybe a unit test and an example showing the line search with it would help? Or, alternatively, if the functional style is preferred, then maybe an example using that with linesearch would be useful?
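For concreteness, the kind of structure I am trying to write is roughly the following sketch, which combines the guide's functional train step with the linesearch extra args. The toy model, data, and loss are placeholders, and I'm assuming optimizer.update(grads, **kwargs) forwards the extra keyword arguments to tx.update, as in the implementation quoted below:

import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(2, 1, rngs=nnx.Rngs(0))  # toy stand-in for the real model
tx = optax.chain(
  optax.sgd(learning_rate=1.0),
  optax.scale_by_backtracking_linesearch(max_backtracking_steps=15),
)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

# Functional pattern from the performance guide: split once outside the loop.
graphdef, state = nnx.split((model, optimizer))

@jax.jit
def train_step(graphdef, state, x, y):
  model, optimizer = nnx.merge(graphdef, state)

  def loss_fn(model):
    return jnp.mean((model(x) - y) ** 2)

  # The linesearch evaluates candidate parameters via value_fn(params), so the
  # model-based loss has to be re-wrapped over the raw Param state.
  model_graphdef, _, rest = nnx.split(model, nnx.Param, ...)
  def value_fn(params):
    return loss_fn(nnx.merge(model_graphdef, params, rest))

  value, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads, grad=grads, value=value, value_fn=value_fn)

  _, state = nnx.split((model, optimizer))
  return value, state

x, y = jnp.ones((8, 2)), jnp.zeros((8, 1))  # toy data
for _ in range(10):
  value, state = train_step(graphdef, state, x, y)

nnx.update((model, optimizer), state)  # sync the live objects at the end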
@NKlug thanks for bringing this up, @jlperla was also interested in this a while back. The conclusion back then was that for now you can work with (or adapt) the current Optimizer implementation, which looks roughly like this:
# OptState and the _* helpers below are internal to the NNX Optimizer implementation.
class Optimizer(Object):
  def __init__(
    self,
    model: nnx.Module,
    tx: optax.GradientTransformation,
    wrt: filterlib.Filter = nnx.Param,
  ):
    self.step = OptState(jnp.array(0, dtype=jnp.uint32))
    self.model = model
    self.tx = tx
    self.opt_state = _wrap_optimizer_state(tx.init(nnx.state(model, wrt)))
    self.wrt = wrt

  def update(self, grads, **kwargs):
    params = nnx.state(self.model, self.wrt)
    opt_state = _opt_state_variables_to_state(self.opt_state)

    # extra keyword arguments (e.g. value, grad, value_fn) are forwarded to optax
    updates, new_opt_state = self.tx.update(grads, opt_state, params, **kwargs)
    new_params = optax.apply_updates(params, updates)
    assert isinstance(new_params, nnx.State)

    self.step.value += 1
    nnx.update(self.model, new_params)
    _update_opt_state(self.opt_state, new_opt_state)

That said, we don't have plans for developing new optimizers for now but I'd be happy to review a PR if you want to tackle the line search use case.
@jlperla there are docs for cache_args. Also, if you have examples of the functional API becoming too painful or verbose, please do bring them up; I'd like to learn what type of workloads users are running so we can consider simplifying them.
At the moment, using an optax.GradientTransformationExtraArgs in NNX feels very hacky.
Consider the following example.
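Roughly, the pattern looks like the following sketch; the toy model, data, and loss are placeholders standing in for the real ones, and optimizer.update(grads, **kwargs) is assumed to forward the extra keyword arguments to tx.update:

import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(2, 1, rngs=nnx.Rngs(0))  # toy stand-in for the real model
x, y = jnp.ones((8, 2)), jnp.zeros((8, 1))  # toy data

def loss_fn(model):
  return jnp.mean((model(x) - y) ** 2)

tx = optax.chain(
  optax.sgd(learning_rate=1.0),
  optax.scale_by_backtracking_linesearch(max_backtracking_steps=15),
)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

# Hack: the linesearch calls value_fn(params) on the raw parameter state, so the
# model-based loss has to be wrapped by splitting and then merging the model.
graphdef, _, rest = nnx.split(model, nnx.Param, ...)
def value_fn(params):
  return loss_fn(nnx.merge(graphdef, params, rest))

value, grad = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads=grad, grad=grad, value=value, value_fn=value_fn)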
I have two issues with this:

1. Having to wrap the loss function by splitting and then merging the model feels very hacky and is not something I want to do every time I use an optax.GradientTransformationExtraArgs (in this case, optax.scale_by_backtracking_linesearch). This hack is necessary because the optax optimizer expects to pass a state to the value_fn, as far as I can see. This was also used in Issue #4144.
2. Passing grads=grad, grad=grad feels redundant and not really clean (maybe I should also pass gradients=grad just to be sure? ^^). I know that it is necessary at the moment because the **kwargs are passed through to the optax.GradientTransformationExtraArgs and the naming in optax is grad.

Maybe there is a way to avoid the hack in 1. that I missed? If not, possibly the "wrapping" could be moved inside nnx.Optimizer when required, instead of doing this manually; a hypothetical sketch of what that could look like follows below. I'm unsure if this is a good way of solving this, though.
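For illustration only, here is a hypothetical sketch (not an existing Flax API) of such internal wrapping, written against the Optimizer implementation quoted earlier in this thread; the name LinesearchFriendlyOptimizer and the model_value_fn keyword are made up:

from flax import nnx

class LinesearchFriendlyOptimizer(nnx.Optimizer):
  # Hypothetical sketch only (not part of Flax): wraps a model-level loss into
  # the params-level value_fn that optax's linesearch expects.
  def update(self, grads, **kwargs):
    model_value_fn = kwargs.pop('model_value_fn', None)
    if model_value_fn is not None:
      # Re-split the live model so optax can evaluate candidate param states.
      graphdef, _, rest = nnx.split(self.model, self.wrt, ...)

      def value_fn(params):
        return model_value_fn(nnx.merge(graphdef, params, rest))

      kwargs['value_fn'] = value_fn
      kwargs.setdefault('grad', grads)  # optax's linesearch also wants `grad`
    super().update(grads, **kwargs)

With something along these lines, the call site would reduce to optimizer.update(grad, value=value, model_value_fn=loss_fn).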
Concerning 2., I think this could be solved with consistent naming conventions between optax and flax/NNX, which might be difficult.
What do you think?