
Improve support for optax.GradientTransformationExtraArgs in NNX #4545

Open
@NKlug

At the moment, using an optax.GradientTransformationExtraArgs in NNX feels very hacky.

Consider the following example:

```python
import optax
from flax import nnx

optimizer = optax.chain(
    optax.sgd(learning_rate=1.),
    optax.scale_by_backtracking_linesearch(max_backtracking_steps=10)  # this is an optax.GradientTransformationExtraArgs
)

state = nnx.Optimizer(model, optimizer)

# some loss function, e.g.
def l2_loss(model, inputs, targets):
    predictions = model(inputs)
    # reduce to a scalar: value_and_grad and the line search both need one
    return optax.l2_loss(predictions, targets).mean()

def train_step(model, loss_fn, state, batch):
    grad_loss_fn = nnx.value_and_grad(loss_fn)
    loss, grad = grad_loss_fn(model, batch['inputs'], batch['targets'])

    # hack for optax.GradientTransformationExtraArgs to work:
    # optax calls value_fn with a parameter state, so rebuild the model from it
    graph_def, _ = nnx.split(model)
    def loss_fn_wrapped(graph_state):
        m = nnx.merge(graph_def, graph_state)
        return loss_fn(m, batch['inputs'], batch['targets'])

    state.update(grads=grad, grad=grad, value=loss, value_fn=loss_fn_wrapped)
    return loss

# train_step is called like this (model and batch defined elsewhere):
train_step(model, l2_loss, state, batch)
```

I have two issues with this:

  1. Having to wrap the loss function by splitting and then merging the model feels very hacky, and it is not something I want to do every time I use an optax.GradientTransformationExtraArgs (in this case, optax.scale_by_backtracking_linesearch). As far as I can see, the hack is necessary because optax calls value_fn with a state (the candidate parameters) rather than with the model. The same workaround was also used in Issue #4144.

  2. Passing grads=grad, grad=grad feels redundant and not really clean (maybe I should also pass gradients=grad just to be sure? ^^). I know it is necessary at the moment: nnx.Optimizer.update takes the gradients as grads and passes its remaining **kwargs through to the optax.GradientTransformationExtraArgs, which expects them under the name grad.

Is there perhaps a way to avoid the hack in point 1 that I missed?
If not, the "wrapping" could possibly be moved inside nnx.Optimizer and applied when required, instead of being done manually.
I'm unsure whether this is a good way of solving it, though.

Concerning point 2, I think this could be solved with consistent naming conventions between optax and Flax/NNX, although that might be difficult to achieve.

What do you think?
