Improve support for optax.GradientTransformationExtraArgs in NNX #4545

Open · NKlug opened this issue Feb 13, 2025 · 3 comments
NKlug commented Feb 13, 2025

At the moment, using an optax.GradientTransformationExtraArgs in NNX feels very hacky.

Consider the following example:

import optax
from flax import nnx

optimizer = optax.chain(
    optax.sgd(learning_rate=1.),
    optax.scale_by_backtracking_linesearch(max_backtracking_steps=10)  # this is an optax.GradientTransformationExtraArgs
)

state = nnx.Optimizer(model, optimizer)

# some loss function, e.g.
def l2_loss(model, inputs, targets):
    predictions = model(inputs)
    return optax.l2_loss(predictions, targets).mean()  # reduce to a scalar for value_and_grad

def train_step(model, loss_fn, state, batch):
    grad_loss_fn = nnx.value_and_grad(loss_fn)
    loss, grad = grad_loss_fn(model, batch['inputs'], batch['targets'])

    # hack for optax.GradientTransformationExtraArgs to work
    graph_def, _ = nnx.split(model)
    def loss_fn_wrapped(graph_state):
        m = nnx.merge(graph_def, graph_state)
        return loss_fn(m, batch['inputs'], batch['targets'])
    
    state.update(grads=grad, grad=grad, value=loss, value_fn=loss_fn_wrapped)
    return loss

# train_step is called like this:
train_step(model, l2_loss, state, batch)

I have two issues with this:

  1. Having to wrap the loss function by splitting and then merging the model feels very hacky and is not something I want to do every time I use an optax.GradientTransformationExtraArgs (in this case, optax.scale_by_backtracking_linesearch). This hack is necessary because, as far as I can see, the optax optimizer expects to pass a state to the value_fn. This approach was also used in issue #4144.

  2. Passing grads=grad, grad=grad feels redundant and not really clean (maybe I should also pass gradients=grad just to be sure? ^^). I know it is necessary at the moment because the **kwargs are passed straight through to the optax.GradientTransformationExtraArgs, and the argument name in optax is grad.

Maybe there is a way to avoid the hack in 1. that I missed?
If not, possibly the "wrapping" could be moved inside nnx.Optimizer when required, instead of doing this manually.
I'm unsure if this is a good way of solving this, though.
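For illustration, the closest thing to a workaround I can think of is factoring the wrapping into a small helper, so the split/merge only lives in one place (make_value_fn is just a made-up name here, not an existing NNX API):

def make_value_fn(model, loss_fn, *args):
    # capture the model's graphdef once; optax will call value_fn with a state
    graph_def, _ = nnx.split(model)
    def value_fn(graph_state):
        return loss_fn(nnx.merge(graph_def, graph_state), *args)
    return value_fn

# inside train_step this replaces the manual wrapping:
# state.update(grads=grad, grad=grad, value=loss,
#              value_fn=make_value_fn(model, loss_fn, batch['inputs'], batch['targets']))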

Concerning 2., I think this could be solved with consistent naming conventions between optax and flax/NNX, which might be difficult.

What do you think?

jlperla (Contributor) commented Feb 13, 2025

A related issue I have had is that loss_fn_wrapped etc. becomes more complicated if you use the split/merge trick suggested in #4045 (comment) and in https://flax.readthedocs.io/en/latest/guides/performance.html#functional-training-loop

The coding pattern there makes things harder because it does graphdef, state = nnx.split((model, optimizer, metrics)), so the graphdef is no longer just for the model. This means you can't simply reuse the graph_def for the nested loss_fn_wrapped above. Not insurmountable, but the efficient way to code this is unclear to me.

I am not sure how to structure my code to use that faster pattern and also support linesearch.
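
For concreteness, the best structure I can come up with is to re-split just the model inside the jitted step, roughly like this (only a sketch, reusing l2_loss from above and the functional pattern from the performance guide; I have not verified it end to end):

import jax
from flax import nnx

graphdef, state = nnx.split((model, optimizer, metrics))

@jax.jit
def train_step(graphdef, state, batch):
    model, optimizer, metrics = nnx.merge(graphdef, state)

    loss, grad = nnx.value_and_grad(l2_loss)(model, batch['inputs'], batch['targets'])

    # re-derive a model-only graphdef for the wrapped loss
    model_graphdef, _ = nnx.split(model)
    def loss_fn_wrapped(graph_state):
        return l2_loss(nnx.merge(model_graphdef, graph_state), batch['inputs'], batch['targets'])

    optimizer.update(grads=grad, grad=grad, value=loss, value_fn=loss_fn_wrapped)

    _, state = nnx.split((model, optimizer, metrics))
    return loss, state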

@cgarciae Is that approach (i.e., https://flax.readthedocs.io/en/latest/guides/performance.html#functional-training-loop) still recommended for performance? I am a little confused by the new https://flax.readthedocs.io/en/latest/guides/performance.html#caching-graph-node-traversals

Is it one or the other? If the new cache_args is generally preferred, then maybe a unit test and example showing the line-search with it might help? Or, alternatively, if the functional style is preferred then maybe an example using that with linesearch might be useful?

cgarciae (Collaborator) commented Feb 13, 2025

@NKlug thanks for bringing this up, @jlperla was also interested in this a while back. The conclusion back then was that, for now, nnx.Optimizer would leave it to the user to implement any of the variants of GradientTransformationExtraArgs. Given that there is no unified behavior here (a new transform could introduce new arguments with entirely new callbacks), what would make more sense is:

  1. Implement new Optimizer types for some of the common use cases. For example, we could add an nnx.optimizers.BacktrackingLinesearch type that would handle the common patterns for line search, e.g. it could give you a Module in the value_fn.
  2. Encourage users to extend Optimizer by creating their own types; a sketch of what this could look like follows the code below. Excluding docstrings and helper functions, the implementation of Optimizer is not very complex:
class Optimizer(Object):
  def __init__(
    self,
    model: nnx.Module,
    tx: optax.GradientTransformation,
    wrt: filterlib.Filter = nnx.Param,
  ):
    self.step = OptState(jnp.array(0, dtype=jnp.uint32))
    self.model = model
    self.tx = tx
    self.opt_state = _wrap_optimizer_state(tx.init(nnx.state(model, wrt)))
    self.wrt = wrt

  def update(self, grads, **kwargs):
    params = nnx.state(self.model, self.wrt)
    opt_state = _opt_state_variables_to_state(self.opt_state)

    # any extra kwargs (e.g. grad, value, value_fn) are forwarded to the optax update
    updates, new_opt_state = self.tx.update(grads, opt_state, params, **kwargs)
    new_params = optax.apply_updates(params, updates)
    assert isinstance(new_params, nnx.State)

    self.step.value += 1
    nnx.update(self.model, new_params)
    _update_opt_state(self.opt_state, new_opt_state)
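
As a rough sketch of what point 2 could look like for the linesearch case (the class name and the update signature are made up, and I have not tested this), a user-defined subclass could build the value_fn from the model internally:

class LinesearchOptimizer(Optimizer):
  def update(self, grads, *, value, loss_fn, loss_args=()):
    # build the value_fn from the model so users never touch split/merge
    graphdef, _ = nnx.split(self.model)
    def value_fn(graph_state):
      return loss_fn(nnx.merge(graphdef, graph_state), *loss_args)
    # forward both `grads` and optax's expected `grad` name
    super().update(grads, grad=grads, value=value, value_fn=value_fn)

# usage:
# optimizer = LinesearchOptimizer(model, tx)
# loss, grad = nnx.value_and_grad(l2_loss)(model, inputs, targets)
# optimizer.update(grad, value=loss, loss_fn=l2_loss, loss_args=(inputs, targets))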

That said, we don't have plans for developing new optimizers for now but I'd be happy to review a PR if you want to tackle the line search use case.

cgarciae (Collaborator) commented Feb 13, 2025

@jlperla the docs for nnx.cached_partial include an example, and we also have the TestJIT.test_cache_args unit test. The API is very simple, so I think this may be enough, but let me know if you have any questions. The updated Performance guide lists some of the available options for completeness; I think the cached_partial API is what I would recommend for most users, and if that is not good enough, then using split / merge would be the last resort.
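
For reference, the usage pattern from the caching guide is roughly the following (a minimal sketch; train_step and dataset stand in for your own step function and data iterator, and I'm assuming a plain optax transformation here, so with the linesearch chain you would pass the extra kwargs as above):

@nnx.jit
def train_step(model, optimizer, batch):
  loss, grad = nnx.value_and_grad(l2_loss)(model, batch['inputs'], batch['targets'])
  optimizer.update(grad)
  return loss

# cache the graph-node traversals for `model` and `optimizer` once
cached_train_step = nnx.cached_partial(train_step, model, optimizer)

for batch in dataset:
  loss = cached_train_step(batch)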

Also, if you have examples of the functional API becoming too painfully verbose, please do bring them up. I'd like to learn what kind of workloads users are running so we can consider simplifying them.
