Description
At the moment, using an `optax.GradientTransformationExtraArgs` in NNX feels very hacky.
Consider the following example:
```python
import optax
from flax import nnx

# `model` is an nnx.Module and `batch` a dict with 'inputs'/'targets', defined elsewhere.
optimizer = optax.chain(
    optax.sgd(learning_rate=1.0),
    optax.scale_by_backtracking_linesearch(max_backtracking_steps=10),  # this is an optax.GradientTransformationExtraArgs
)
state = nnx.Optimizer(model, optimizer)

# some loss function, e.g.
def l2_loss(model, inputs, targets):
    predictions = model(inputs)
    return optax.l2_loss(predictions, targets).mean()  # reduced to a scalar for value_and_grad

def train_step(model, loss_fn, state, batch):
    grad_loss_fn = nnx.value_and_grad(loss_fn)
    loss, grad = grad_loss_fn(model, batch['inputs'], batch['targets'])

    # hack for optax.GradientTransformationExtraArgs to work:
    # the linesearch needs a value_fn that maps the NNX graph state to the loss
    graph_def, _ = nnx.split(model)

    def loss_fn_wrapped(graph_state):
        m = nnx.merge(graph_def, graph_state)
        return loss_fn(m, batch['inputs'], batch['targets'])

    state.update(grads=grad, grad=grad, value=loss, value_fn=loss_fn_wrapped)
    return loss

# train_step is called like this:
train_step(model, l2_loss, state, batch)
```
I have two issues with this:

1. Having to wrap the loss function by splitting and then merging the model feels very hacky and is not something I want to do every time I use an `optax.GradientTransformationExtraArgs` (in this case, `optax.scale_by_backtracking_linesearch`). This hack is necessary because, as far as I can see, the optax optimizer expects to pass a `state` to the `value_fn`. The same workaround was also used in Issue #4144.
2. Passing `grads=grad, grad=grad` feels redundant and not really clean (maybe I should also pass `gradients=grad` just to be sure? ^^). I know that it is currently necessary because the `**kwargs` are passed through to the `optax.GradientTransformationExtraArgs`, and the keyword is named `grad` in optax. The plain-optax sketch after this list shows where these keyword names come from.
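
For context, here is a minimal plain-optax sketch (no NNX; the toy loss `f` and params below are just placeholders I made up) showing where the `value`, `grad`, and `value_fn` keyword names, and the params-based `value_fn`, come from:

```python
import jax
import jax.numpy as jnp
import optax

def f(params):  # toy scalar loss over a raw parameter pytree
    return jnp.sum(params['w'] ** 2)

params = {'w': jnp.ones((3,))}
solver = optax.chain(
    optax.sgd(learning_rate=1.0),
    optax.scale_by_backtracking_linesearch(max_backtracking_steps=10),
)
opt_state = solver.init(params)

value, grad = jax.value_and_grad(f)(params)
# The linesearch's update expects exactly these keyword names, and it calls
# value_fn with candidate params (a pytree), not a model object.
updates, opt_state = solver.update(
    grad, opt_state, params, value=value, grad=grad, value_fn=f
)
params = optax.apply_updates(params, updates)
```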
Maybe there is a way to avoid the hack in 1. that I missed?
If not, the "wrapping" could possibly be moved inside `nnx.Optimizer` and applied when required, instead of doing it manually; a rough sketch of what that could look like follows below. I'm unsure whether this is a good way of solving it, though.
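
To make the idea concrete, here is a rough sketch of the logic that `nnx.Optimizer.update` could perform internally, written as a standalone helper for illustration (the helper name and signature are made up, this is not existing Flax API):

```python
from flax import nnx

def update_with_value_fn(optimizer, model, grads, value, loss_fn, *loss_args):
    # Hypothetical helper: builds the graph-state-based value_fn internally,
    # so user code never has to split/merge the model itself, and forwards
    # the gradients under the keyword names that optax expects.
    graph_def, _ = nnx.split(model)

    def value_fn(graph_state):
        return loss_fn(nnx.merge(graph_def, graph_state), *loss_args)

    optimizer.update(grads=grads, grad=grads, value=value, value_fn=value_fn)

# train_step would then reduce to:
def train_step(model, loss_fn, state, batch):
    loss, grad = nnx.value_and_grad(loss_fn)(model, batch['inputs'], batch['targets'])
    update_with_value_fn(state, model, grad, loss, loss_fn,
                         batch['inputs'], batch['targets'])
    return loss
```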
Concerning 2., I think this could be solved with consistent naming conventions between optax and flax/NNX, which might be difficult.
What do you think?