Merge pull request #1916 from berndbohnet:nlp_seq_optax
PiperOrigin-RevId: 455076620
Flax Authors committed Jun 15, 2022
2 parents 43c8057 + 3a3df8d commit 66b4a0e
Showing 2 changed files with 39 additions and 26 deletions.
2 changes: 1 addition & 1 deletion examples/nlp_seq/README.md
@@ -21,7 +21,7 @@ The model should run with other configurations and hardware, but explicitly test
| Hardware | Batch size | Learning rate | Training time | Accuracy | TensorBoard.dev |
|:---:|:---:|:---:|:---:|:---:|:---:|
| Nvidia V100 (16GB) | 64 | 0.05 | 5h 15m | 72.20% | [2020-03-22](https://tensorboard.dev/experiment/YkUAdwYaQ9OtYl2IVe3MvA/) |
| Nvidia Titan V (12GB) | 64 | 0.05 | 5:58h | 68.6% | [2022-05-01](https://tensorboard.dev/experiment/F5ULHlyzQlieVJn5PG8mRQ/) |
### Running
```
63 changes: 38 additions & 25 deletions examples/nlp_seq/train.py
@@ -27,12 +27,13 @@
from flax import jax_utils
from flax import linen as nn
from flax.metrics import tensorboard
from flax import optim
from flax.training import common_utils
from flax.training import train_state
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
import optax
import tensorflow as tf

import input_pipeline
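
The substance of this commit is replacing the deprecated `flax.optim` API with `optax` gradient transformations. As a point of reference, here is a minimal sketch of the bare optax update cycle that `train_state.TrainState` wraps in the changes below; the parameter pytree, loss, and hyperparameter values are made up for illustration.

```python
import jax
import jax.numpy as jnp
import optax

# Minimal sketch of the raw optax cycle that TrainState wraps below.
# Parameters, loss, and hyperparameters here are illustrative only.
params = {'w': jnp.ones((3,)), 'b': jnp.zeros((3,))}
tx = optax.adamw(learning_rate=1e-3, b1=0.9, b2=0.98, eps=1e-9,
                 weight_decay=1e-1)
opt_state = tx.init(params)

def loss_fn(p):
  return jnp.sum((p['w'] + p['b']) ** 2)

grads = jax.grad(loss_fn)(params)
# AdamW needs the current params to apply decoupled weight decay.
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```

`TrainState` bundles exactly these pieces (params, the transformation, its state) together with a step counter, which is what the rewritten `train_step` below operates on.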
@@ -191,35 +192,37 @@ def compute_metrics(logits, labels, weights):
return metrics


def train_step(optimizer, batch, learning_rate_fn, model, dropout_rng=None):
def train_step(state,
batch,
model,
learning_rate_fn,
dropout_rng=None):
"""Perform a single training step."""

train_keys = ['inputs', 'targets']
(inputs, targets) = (batch.get(k, None) for k in train_keys)

weights = jnp.where(targets > 0, 1, 0).astype(jnp.float32)
dropout_rng, new_dropout_rng = random.split(dropout_rng)

dropout_rng = jax.random.fold_in(dropout_rng, state.step)

def loss_fn(params):
"""Loss function used for training."""
"""loss function used for training."""
logits = model.apply({'params': params}, inputs=inputs, train=True,
rngs={'dropout': dropout_rng})
loss, weight_sum = compute_weighted_cross_entropy(logits, targets, weights)

mean_loss = loss / weight_sum
return mean_loss, logits

step = optimizer.state.step
lr = learning_rate_fn(step)
lr = learning_rate_fn(state.step)
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(_, logits), grad = grad_fn(optimizer.target)
grad = jax.lax.pmean(grad, 'batch')
new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)

(_, logits), grads = grad_fn(state.params)
grads = jax.lax.pmean(grads, "batch")
new_state = state.apply_gradients(grads=grads)
metrics = compute_metrics(logits, targets, weights)
metrics['learning_rate'] = lr

return new_optimizer, metrics, new_dropout_rng

metrics["learning_rate"] = lr

return new_state, metrics


def pad_examples(x, desired_batch_size):
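
Note how dropout randomness changes in the new `train_step`: instead of splitting `dropout_rng` and returning the fresh key from every step, a per-step key is derived with `jax.random.fold_in(dropout_rng, state.step)`. A small sketch of that pattern, with an arbitrary base key and step count:

```python
import jax

# Per-step dropout keys derived from a fixed base key, as in the new
# train_step: deterministic, and no key needs to be threaded back out.
base_key = jax.random.PRNGKey(0)  # stands in for dropout_rng
for step in range(3):
  step_key = jax.random.fold_in(base_key, step)
  # model.apply(..., rngs={'dropout': step_key}) would consume this key.
  print(step, step_key)
```

This is why `train_step` now returns only `(new_state, metrics)` and the training loop no longer receives an updated `dropout_rngs`.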
@@ -296,17 +299,27 @@ def initialize_variables(init_rng):
return init_variables
init_variables = initialize_variables(init_rng)

optimizer_def = optim.Adam(learning_rate, beta1=0.9, beta2=0.98,
eps=1e-9, weight_decay=1e-1)
optimizer = optimizer_def.create(init_variables['params'])
optimizer = jax_utils.replicate(optimizer)

learning_rate_fn = create_learning_rate_scheduler(
base_learning_rate=learning_rate)

optimizer = optax.adamw(
learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9,
weight_decay=1e-1)
state = train_state.TrainState.create(
apply_fn=model.apply,
params=init_variables["params"],
tx=optimizer)

# Replicate optimizer.
state = jax_utils.replicate(state)

p_train_step = jax.pmap(
functools.partial(train_step, model=model, learning_rate_fn=learning_rate_fn),
axis_name='batch')
functools.partial(
train_step,
model=model,
learning_rate_fn=learning_rate_fn),
axis_name='batch',
donate_argnums=(0,)) # pytype: disable=wrong-arg-types
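
The setup above passes `learning_rate_fn` directly to `optax.adamw`: optax accepts any callable mapping the step count to a learning rate as a schedule, and the same callable is reused inside `train_step` for logging. The body of `create_learning_rate_scheduler` is not part of this diff, so the sketch below uses a hypothetical linear-warmup schedule and toy parameters purely to show the wiring.

```python
import jax.numpy as jnp
import optax
from flax.training import train_state

# Hypothetical stand-in for create_learning_rate_scheduler; any callable
# mapping step -> learning rate works as an optax schedule.
def linear_warmup_schedule(base_learning_rate=0.05, warmup_steps=1000):
  def schedule(step):
    return base_learning_rate * jnp.minimum(1.0, step / warmup_steps)
  return schedule

learning_rate_fn = linear_warmup_schedule()
tx = optax.adamw(learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9,
                 weight_decay=1e-1)

# Toy params and apply_fn stand in for the example's Transformer model.
toy_params = {'w': jnp.zeros((4,))}
state = train_state.TrainState.create(
    apply_fn=lambda variables, x: x @ variables['params']['w'],
    params=toy_params,
    tx=tx)
assert int(state.step) == 0  # apply_gradients() increments this counter
```

`jax_utils.replicate(state)` then copies the whole `TrainState` (params, optimizer state, and step) across local devices, and `donate_argnums=(0,)` lets the pmapped step reuse the old state's device buffers for the new one.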

def eval_step(params, batch):
"""Calculate evaluation metrics on a batch."""
@@ -326,9 +339,8 @@ def eval_step(params, batch):
for step, batch in zip(range(num_train_steps), train_iter):
batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access

optimizer, metrics, dropout_rngs = p_train_step(optimizer, batch, dropout_rng=dropout_rngs)
state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs)
metrics_all.append(metrics)

if (step + 1) % eval_freq == 0:
metrics_all = common_utils.get_metrics(metrics_all)
lr = metrics_all.pop('learning_rate').mean()
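
In the loop above, each TF batch is converted to NumPy and split across local devices with `common_utils.shard` before the pmapped step; `common_utils.get_metrics` later stacks the per-step metric dicts for averaging. A small sketch of the sharding step, with made-up shapes:

```python
import jax
import numpy as np
from flax.training import common_utils

# common_utils.shard splits the leading batch axis of every array in the
# pytree into (local_device_count, per_device_batch, ...). Shapes here
# are made up for illustration.
n_devices = jax.local_device_count()
batch = {
    'inputs': np.zeros((8 * n_devices, 16), dtype=np.int32),
    'targets': np.zeros((8 * n_devices, 16), dtype=np.int32),
}
sharded = common_utils.shard(batch)
print(sharded['inputs'].shape)  # (n_devices, 8, 16)
```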
@@ -361,7 +373,8 @@ def eval_step(params, batch):
lambda x: pad_examples(x, batch_size), eval_batch)
eval_batch = common_utils.shard(eval_batch)

metrics = p_eval_step(optimizer.target, eval_batch)
metrics = p_eval_step(state.params, eval_batch)

eval_metrics.append(metrics)
eval_metrics = common_utils.get_metrics(eval_metrics)
eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
