diff --git a/CHANGELOG.md b/CHANGELOG.md index 0dad54899..bbe1c8b68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,8 @@ vNext Breaking changes: - flax.deprecated.nn is removed. Please pin to flax==0.3.6 if you are still using it. - PixelCNN++ example is removed. It was not working well on TPU. +- linen Normalization layers no longer downcast double and complex floats tofloat32 + when computing the mean and variance. New features: - Added `flax.linen.custom_vjp` for custom derivatives inside a `Module`. diff --git a/docs/design_notes/arguments.md b/docs/design_notes/arguments.md index 25eb8f5c4..40212e9ac 100644 --- a/docs/design_notes/arguments.md +++ b/docs/design_notes/arguments.md @@ -85,9 +85,9 @@ It also avoids a default value which would probably cause either the train step -## Functional Core and flax.nn +## Functional Core -The old NN api and functional core define functions rather than classes. -Therefore, there is no clear distinction between hyper parameters and call time arguments. -The only way to pre-determine the hyper parameters is by using `partial`. +Functional core defines functions rather than classes. +Therefore, there is no clear distinction between hyperparameters and call-time arguments. +The only way to pre-determine the hyperparameters is by using `partial`. On the upside, there are no ambiguous cases where method arguments could also be attributes. diff --git a/docs/flax.nn.rst b/docs/flax.nn.rst deleted file mode 100644 index d853712d9..000000000 --- a/docs/flax.nn.rst +++ /dev/null @@ -1,115 +0,0 @@ - -.. warning:: - **This package is deprecated**. See :mod:`flax.linen` for our new module API. - -flax.nn package (deprecated) -================= - -.. currentmodule:: flax.nn - - -Core: Module abstraction ------------------------- - -.. autoclass:: Module - :members: init, init_by_shape, partial, shared, apply, param, get_param, state, is_stateful, is_initializing - -Core: Additional ------------------------- - -.. autosummary:: - :toctree: _autosummary - - module - Model - Collection - capture_module_outputs - stateful - get_state - module_method - - -Linear modules ------------------------- - -.. autosummary:: - :toctree: _autosummary - - Dense - DenseGeneral - Conv - Embed - - -Normalization ------------------------- - -.. autosummary:: - :toctree: _autosummary - - BatchNorm - LayerNorm - GroupNorm - - -Pooling ------------------------- - -.. autosummary:: - :toctree: _autosummary - - max_pool - avg_pool - - -Activation functions ------------------------- - -.. autosummary:: - :toctree: _autosummary - - celu - elu - gelu - glu - log_sigmoid - log_softmax - relu - sigmoid - soft_sign - softmax - softplus - swish - - -Stochastic functions ------------------------- - -.. autosummary:: - :toctree: _autosummary - - make_rng - stochastic - is_stochastic - dropout - - -Attention primitives ------------------------- - -.. autosummary:: - :toctree: _autosummary - - dot_product_attention - SelfAttention - - -RNN primitives ------------------------- - -.. autosummary:: - :toctree: _autosummary - - LSTMCell - OptimizedLSTMCell - GRUCell diff --git a/docs/howtos/ensembling.rst b/docs/howtos/ensembling.rst index 8f04fc0ca..2ab72074f 100644 --- a/docs/howtos/ensembling.rst +++ b/docs/howtos/ensembling.rst @@ -1,307 +1,274 @@ Ensembling on multiple devices -============================= +============================== We show how to train an ensemble of CNNs on the MNIST dataset, where the size of the ensemble is equal to the number of available devices. In short, this change be described as: -* make a number of functions parallel using ``jax.pmap``, -* replicate the inputs carefully, -* make sure the parallel and non-parallel logic interacts correctly. +* make a number of functions parallel using |jax.pmap()|_, +* split the random seed to obtain different parameter initialization, +* replicate the inputs and unreplicate the outputs where necessary, +* average probabilities across devices to compute the predictions. In this HOWTO we omit some of the code such as imports, the CNN module, and metrics computation, but they can be found in the `MNIST example`_. .. testsetup:: - # Since this HOWTO's code is part of our tests (which are often ran locally on - # CPU), we use a very small CNN, we only run for 1 epoch, and we make sure we - # are using mock data. + import functools + from flax import jax_utils + # Copied from examples/mnist/train.py from absl import logging - from flax import jax_utils from flax import linen as nn - from flax import optim - from flax.metrics import tensorboard + from flax.training import train_state import jax import jax.numpy as jnp - from jax import random - import ml_collections import numpy as np - import tensorflow_datasets as tfds - import functools - - num_epochs = 1 + import optax class CNN(nn.Module): + """A simple CNN model.""" + @nn.compact def __call__(self, x): - x = nn.Conv(features=1, kernel_size=(3, 3))(x) + x = nn.Conv(features=32, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = nn.Conv(features=1, kernel_size=(3, 3))(x) + x = nn.Conv(features=64, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) # flatten - x = nn.Dense(features=1)(x) + x = nn.Dense(features=256)(x) x = nn.relu(x) - x = nn.Dense(features=1)(x) - x = nn.log_softmax(x) + x = nn.Dense(features=10)(x) return x + # Fake data for faster execution. def get_datasets(): - """Load fake MNIST data.""" - # Converts dataset from list of dicts to dict of lists. - to_dict = lambda x: {k: np.array([d[k] for d in x]) for k in ['image', 'label']} - with tfds.testing.mock_data(num_examples=100): - train_ds = to_dict(tfds.as_numpy(tfds.load('mnist', split='train'))) - test_ds = to_dict(tfds.as_numpy(tfds.load('mnist', split='test'))) - train_ds['image'] = jnp.float32(train_ds['image']) / 255. - test_ds['image'] = jnp.float32(test_ds['image']) / 255. + train_ds = test_ds = { + 'image': jnp.zeros([64, 28, 28, 1]), + 'label': jnp.zeros([64], jnp.int32), + } return train_ds, test_ds + # Modified from examples/mnist/configs.default.py + learning_rate = 0.1 + momentum = 0.9 + batch_size = 32 + num_epochs = 1 - def onehot(labels, num_classes=10): - x = (labels[..., None] == jnp.arange(num_classes)[None]) - return x.astype(jnp.float32) - - - def cross_entropy_loss(logits, labels): - return -jnp.mean(jnp.sum(onehot(labels) * logits, axis=-1)) - - - def compute_metrics(logits, labels): - loss = cross_entropy_loss(logits, labels) - accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) - metrics = { - 'loss': loss, - 'accuracy': accuracy, - } - return metrics Parallel functions --------------------------------- +------------------ -We start by creating a parallel version of ``get_initial_params``, which -retrieves the initial parameters of the models. We do this using `jax.pmap`_. +We start by creating a parallel version of ``create_train_state()``, which +retrieves the initial parameters of the models. We do this using |jax.pmap()|_. The effect of "pmapping" a function is that it will compile the function with -XLA (similar to `jax.jit`_), but execute it in parallel on XLA devices (e.g., +XLA (similar to |jax.jit()|_), but execute it in parallel on XLA devices (e.g., GPUs/TPUs). .. codediff:: :title_left: Single-model :title_right: Ensemble - @jax.jit #! - def get_initial_params(key): - init_val = jnp.ones((1, 28, 28, 1), jnp.float32) - initial_params = CNN().init(key, init_val)['params'] - return initial_params - + #! + def create_train_state(rng, learning_rate, momentum): + cnn = CNN() + params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'] + tx = optax.sgd(learning_rate, momentum) + return train_state.TrainState.create( + apply_fn=cnn.apply, params=params, tx=tx) --- - @jax.pmap #! - def get_initial_params(key): - init_val = jnp.ones((1, 28, 28, 1), jnp.float32) - initial_params = CNN().init(key, init_val)['params'] - return initial_params - -Note that for the single-model code above, we use `jax.jit`_ to lazily + @functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2)) #! + def create_train_state(rng, learning_rate, momentum): + cnn = CNN() + params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'] + tx = optax.sgd(learning_rate, momentum) + return train_state.TrainState.create( + apply_fn=cnn.apply, params=params, tx=tx) + +Note that for the single-model code above, we use |jax.jit()|_ to lazily initialize the model (see `Module.init`_'s documentation for more details). -For the ensembling case, `jax.pmap`_ will map over the first axis of the -provided argument ``key`` by default, so we should make sure that we provide -one key for each device when we call this function later on. - -Next we simply do the same for the functions ``create_optimizer``, -``train_step``, and ``eval_step``. We also make a minor change to -``eval_model``, which ensures the metrics are used correctly in the parallel -setting. +For the ensembling case, |jax.pmap()|_ will map over the first axis of the +provided argument ``rng`` by default, so we should make sure that we provide +a different value for each device when we call this function later on. + +Note also how we specify that ``learning_rate`` and ``momentum`` are static +arguments, which means the concrete values of these arguments will be used, +rather than abstract shapes. This is necessary because the provided arguments +will be scalar values. For more details see `JIT mechanics: tracing and static +variables`_. + +Next we simply do the same for the functions ``apply_model()`` and +``update_model()``. To compute the predictions from the ensemble, we take the +average of the individual probabilities. We use |jax.lax.pmean()|_ to compute +the average *across devices*. This also requires us to specify the +``axis_name`` to both |jax.pmap()|_ and |jax.lax.pmean()|_. .. codediff:: :title_left: Single-model :title_right: Ensemble - # #! - def create_optimizer(params, learning_rate=0.1, beta=0.9): - optimizer_def = optim.Momentum(learning_rate=learning_rate, - beta=beta) - optimizer = optimizer_def.create(params) - return optimizer - - @jax.jit #! - def train_step(optimizer, batch): - """Train for a single step.""" + @jax.jit #! + def apply_model(state, images, labels): def loss_fn(params): - logits = CNN().apply({'params': params}, batch['image']) - loss = cross_entropy_loss(logits, batch['label']) + logits = CNN().apply({'params': params}, images) + one_hot = jax.nn.one_hot(labels, 10) + loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean() return loss, logits + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (_, logits), grad = grad_fn(optimizer.target) - optimizer = optimizer.apply_gradient(grad) - metrics = compute_metrics(logits, batch['label']) - return optimizer, metrics - - @jax.jit #! - def eval_step(params, batch): - logits = CNN().apply({'params': params}, batch['image']) - return compute_metrics(logits, batch['label']) - - def eval_model(params, test_ds): - metrics = eval_step(params, test_ds) - metrics = jax.device_get(metrics) - summary = jax.tree_map(lambda x: x.item(), metrics) #! - return summary['loss'], summary['accuracy'] + (loss, logits), grads = grad_fn(state.params) + #! + accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) #! + return grads, loss, accuracy + + @jax.jit #! + def update_model(state, grads): + return state.apply_gradients(grads=grads) --- - @functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2)) #! - def create_optimizer(params, learning_rate=0.1, beta=0.9): - optimizer_def = optim.Momentum(learning_rate=learning_rate, - beta=beta) - optimizer = optimizer_def.create(params) - return optimizer - - @jax.pmap #! - def train_step(optimizer, batch): - """Train for a single step.""" + @functools.partial(jax.pmap, axis_name='ensemble') #! + def apply_model(state, images, labels): def loss_fn(params): - logits = CNN().apply({'params': params}, batch['image']) - loss = cross_entropy_loss(logits, batch['label']) + logits = CNN().apply({'params': params}, images) + one_hot = jax.nn.one_hot(labels, 10) + loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean() return loss, logits + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (_, logits), grad = grad_fn(optimizer.target) - optimizer = optimizer.apply_gradient(grad) - metrics = compute_metrics(logits, batch['label']) - return optimizer, metrics - - @jax.pmap #! - def eval_step(params, batch): - logits = CNN().apply({'params': params}, batch['image']) - return compute_metrics(logits, batch['label']) - - def eval_model(params, test_ds): - metrics = eval_step(params, test_ds) - metrics = jax.device_get(metrics) - summary = metrics #! - return summary['loss'], summary['accuracy'] - -Note that for ``create_optimizer`` we also specify that ``learning_rate`` -and ``beta`` are static arguments, which means the concrete values of these -arguments will be used, rather than abstract shapes. This is necessary because -the provided arguments will be scalar values. For more details see -`JIT mechanics: tracing and static variables`_. + (loss, logits), grads = grad_fn(state.params) + probs = jax.lax.pmean(jax.nn.softmax(logits), axis_name='ensemble') #! + accuracy = jnp.mean(jnp.argmax(probs, -1) == labels) #! + return grads, loss, accuracy + + @jax.pmap #! + def update_model(state, grads): + return state.apply_gradients(grads=grads) Training the Ensemble --------------------------------- +--------------------- -Next we transform the ``train_epoch`` function. +Next we transform the ``train_epoch()`` function. When calling the pmapped +functions from above, we mainly need to take care of duplicating the arguments +for all devices where necessary, and de-duplicating the return values. .. codediff:: :title_left: Single-model :title_right: Ensemble - def train_epoch(optimizer, train_ds, rng, batch_size=10): + def train_epoch(state, train_ds, batch_size, rng): train_ds_size = len(train_ds['image']) steps_per_epoch = train_ds_size // batch_size - perms = random.permutation(rng, len(train_ds['image'])) + perms = jax.random.permutation(rng, len(train_ds['image'])) perms = perms[:steps_per_epoch * batch_size] perms = perms.reshape((steps_per_epoch, batch_size)) - batch_metrics = [] - for perm in perms: - batch = {k: v[perm, ...] for k, v in train_ds.items()} - optimizer, metrics = train_step(optimizer, batch) - batch_metrics.append(metrics) + epoch_loss = [] + epoch_accuracy = [] - batch_metrics_np = jax.device_get(batch_metrics) - - - epoch_metrics_np = { - k: np.mean([metrics[k] for metrics in batch_metrics_np]) #! - for k in batch_metrics_np[0]} #! - - return optimizer, epoch_metrics_np + for perm in perms: + batch_images = train_ds['image'][perm, ...] #! + batch_labels = train_ds['label'][perm, ...] #! + grads, loss, accuracy = apply_model(state, batch_images, batch_labels) + state = update_model(state, grads) + epoch_loss.append(loss) #! + epoch_accuracy.append(accuracy) #! + train_loss = np.mean(epoch_loss) + train_accuracy = np.mean(epoch_accuracy) + return state, train_loss, train_accuracy --- - def train_epoch(optimizer, train_ds, rng, batch_size=10): + def train_epoch(state, train_ds, batch_size, rng): train_ds_size = len(train_ds['image']) steps_per_epoch = train_ds_size // batch_size - perms = random.permutation(rng, len(train_ds['image'])) - perms = perms[:steps_per_epoch * batch_size] + perms = jax.random.permutation(rng, len(train_ds['image'])) + perms = perms[:steps_per_epoch * batch_size] perms = perms.reshape((steps_per_epoch, batch_size)) - batch_metrics = [] - for perm in perms: - batch = {k: v[perm, ...] for k, v in train_ds.items()} - batch = jax_utils.replicate(batch) #! - optimizer, metrics = train_step(optimizer, batch) - batch_metrics.append(metrics) - batch_metrics_np = jax.device_get(batch_metrics) - batch_metrics_np = jax.tree_multimap(lambda *xs: np.array(xs), #! - *batch_metrics_np) #! - epoch_metrics_np = { - k: np.mean(batch_metrics_np[k], axis=0) #! - for k in batch_metrics_np} #! + epoch_loss = [] + epoch_accuracy = [] - return optimizer, epoch_metrics_np + for perm in perms: + batch_images = jax_utils.replicate(train_ds['image'][perm, ...]) #! + batch_labels = jax_utils.replicate(train_ds['label'][perm, ...]) #! + grads, loss, accuracy = apply_model(state, batch_images, batch_labels) + state = update_model(state, grads) + epoch_loss.append(jax_utils.unreplicate(loss)) #! + epoch_accuracy.append(jax_utils.unreplicate(accuracy)) #! + train_loss = np.mean(epoch_loss) + train_accuracy = np.mean(epoch_accuracy) + return state, train_loss, train_accuracy As can be seen, we do not have to make any changes to the logic around the -``optimizer``. This is because, as we will see below in our training code, -the optimizer is replicated already, so when we pass it to ``train_step``, -things will just work fine since ``train_step`` is pmapped. However, +``state``. This is because, as we will see below in our training code, +the train state is replicated already, so when we pass it to ``train_step()``, +things will just work fine since ``train_step()`` is pmapped. However, the train dataset is not yet replicated, so we do that here. Since replicating the entire train dataset is too memory intensive we do it at the batch level. -The rest of the changes relate to making sure the batch metrics are stored -correctly for all devices. We use ``jax.tree_multimap`` to stack all of the -metrics from each device into numpy arrays, such that e.g., -``batch_metrics_np['loss']`` has shape ``(steps_per_epoch, jax.device_count())``. - We can now rewrite the actual training logic. This consists of two simple changes: making sure the RNGs are replicate when we pass them to -``get_initial_params``, and replicating the test dataset, which is much smaller -than the train dataset so we can do this for the entire dataset directly. +``create_train_state()``, and replicating the test dataset, which is much +smaller than the train dataset so we can do this for the entire dataset +directly. .. codediff:: :title_left: Single-model :title_right: Ensemble train_ds, test_ds = get_datasets() + #! + rng = jax.random.PRNGKey(0) + rng, init_rng = jax.random.split(rng) + state = create_train_state(init_rng, learning_rate, momentum) #! + #! - rng, init_rng = random.split(random.PRNGKey(0)) - params = get_initial_params(init_rng) #! - optimizer = create_optimizer(params, learning_rate=0.1, #! - momentum=0.9) #! - - for epoch in range(num_epochs): - rng, input_rng = random.split(rng) - optimizer, _ = train_epoch(optimizer, train_ds, input_rng) - loss, accuracy = eval_model(optimizer.target, test_ds) + for epoch in range(1, num_epochs + 1): + rng, input_rng = jax.random.split(rng) + state, train_loss, train_accuracy = train_epoch( + state, train_ds, batch_size, input_rng) - logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', #! - epoch, loss, accuracy * 100) + _, test_loss, test_accuracy = apply_model( #! + state, test_ds['image'], test_ds['label']) #! + + logging.info( + 'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, ' + 'test_loss: %.4f, test_accuracy: %.2f' + % (epoch, train_loss, train_accuracy * 100, test_loss, + test_accuracy * 100)) --- train_ds, test_ds = get_datasets() - test_ds = jax_utils.replicate(test_ds) #! - - rng, init_rng = random.split(random.PRNGKey(0)) - params = get_initial_params(random.split(rng, #! - jax.device_count())) #! - optimizer = create_optimizer(params, 0.1, 0.9) #! + test_ds = jax_utils.replicate(test_ds) #! + rng = jax.random.PRNGKey(0) - for epoch in range(num_epochs): - rng, input_rng = random.split(rng) - optimizer, _ = train_epoch(optimizer, train_ds, input_rng) - loss, accuracy = eval_model(optimizer.target, test_ds) + rng, init_rng = jax.random.split(rng) + state = create_train_state(jax.random.split(init_rng, jax.device_count()), #! + learning_rate, momentum) #! - logging.info('eval epoch: %d, loss: %s, accuracy: %s', #! - epoch, loss, accuracy * 100) + for epoch in range(1, num_epochs + 1): + rng, input_rng = jax.random.split(rng) + state, train_loss, train_accuracy = train_epoch( + state, train_ds, batch_size, input_rng) -Note that ``create_optimizer`` is using positional arguments in the ensembling -case. This is because we defined those arguments as static broadcasted -arguments, and those should be positional rather then keyword arguments. + _, test_loss, test_accuracy = jax_utils.unreplicate( #! + apply_model(state, test_ds['image'], test_ds['label'])) #! + + logging.info( + 'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, ' + 'test_loss: %.4f, test_accuracy: %.2f' + % (epoch, train_loss, train_accuracy * 100, test_loss, + test_accuracy * 100)) -.. _jax.jit: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#To-JIT-or-not-to-JIT -.. _jax.pmap: https://jax.readthedocs.io/en/latest/jax.html#jax.pmap + +.. |jax.jit()| replace:: ``jax.jit()`` +.. _jax.jit(): https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#To-JIT-or-not-to-JIT +.. |jax.pmap()| replace:: ``jax.pmap()`` +.. _jax.pmap(): https://jax.readthedocs.io/en/latest/jax.html#jax.pmap +.. |jax.lax.pmean()| replace:: ``jax.lax.pmean()`` +.. _jax.lax.pmean(): https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.pmean.html .. _Module.init: https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.init .. _`JIT mechanics: tracing and static variables`: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#JIT-mechanics:-tracing-and-static-variables .. _`MNIST example`: https://github.com/google/flax/blob/main/examples/mnist/train.py diff --git a/docs/index.rst b/docs/index.rst index fbb51bca0..e932af5cf 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -74,9 +74,3 @@ For a quick introduction and short example snippets, see our `README flax.training flax.config flax.errors - -.. toctree:: - :maxdepth: 1 - :caption: (deprecated) - - flax.nn (deprecated) diff --git a/docs/notebooks/jax_for_the_impatient.ipynb b/docs/notebooks/jax_for_the_impatient.ipynb index fc1c305bc..4f631738c 100644 --- a/docs/notebooks/jax_for_the_impatient.ipynb +++ b/docs/notebooks/jax_for_the_impatient.ipynb @@ -492,7 +492,7 @@ } }, "source": [ - "key, *subkeys = random.split(key, 4) # TODO Question : why is there no overlapping result with random.split(key) (so only one number generated with the same key)\n", + "key, *subkeys = random.split(key, 4)\n", "key, subkeys" ], "execution_count": 13, diff --git a/examples/README.md b/examples/README.md index bcfd59fd8..2cd572e25 100644 --- a/examples/README.md +++ b/examples/README.md @@ -57,8 +57,6 @@ Generative models - [Variational auto-encoder](https://github.com/google/flax/tree/main/examples/vae/): Trained on binarized MNIST (featuring simple code, vmap). -- [PixelCNN++](https://github.com/google/flax/tree/main/examples/pixelcnn/): - Trained on cifar10 (featuring single host SPMD, checkpointing, Polyak decay). Graph modeling @@ -67,6 +65,17 @@ Graph modeling [#231]: https://github.com/google/flax/issues/231 +## Repositories Using Flax + +The following code bases use Flax and provide training frameworks and a wealth +of examples, in many cases with pre-trained weights: + +- https://github.com/google-research/scenic: *Scenic* is a codebase/library + for computer vision research and beyond. Scenic's main focus is around + attention-based models. Scenic has been successfully used to develop + classification, segmentation, and detection models for multiple modalities + including images, video, audio, and multimodal combinations of them. + ## Community Examples In addition to the curated list of official Flax examples, there is a growing @@ -77,7 +86,7 @@ official Flax example, and start from there. | Link | Author | Task type | Reference | | ---------------------------- | ------------------ | --------------------------------- | --------------------------------------------------------------------- | -| [matthias-wright/flaxmodels] | [@matthias-wright] | Various | Various | +| [matthias-wright/flaxmodels] | [@matthias-wright] | Various | GPT-2, ResNet, StyleGAN-2, VGG, ... | | [google/vision_transformer] | [@andsteing] | Image classification, fine-tuning | https://arxiv.org/abs/2010.11929 and https://arxiv.org/abs/2105.01601 | | [JAX-RL] | [@henry-prior] | Reinforcement learning | N/A | | [DCGAN] Colab | [@bkkaggle] | Image Synthesis | https://arxiv.org/abs/1511.06434 | @@ -95,4 +104,4 @@ official Flax example, and start from there. [@henry-prior]: https://github.com/henry-prior [@bkkaggle]: https://github.com/bkkaggle [@vasudevgupta7]: https://github.com/vasudevgupta7 -[@n2cholas]: https://github.com/n2cholas \ No newline at end of file +[@n2cholas]: https://github.com/n2cholas diff --git a/examples/linen_design_test/attention_simple.py b/examples/linen_design_test/attention_simple.py index e83cad55f..9e3ac997c 100644 --- a/examples/linen_design_test/attention_simple.py +++ b/examples/linen_design_test/attention_simple.py @@ -17,7 +17,7 @@ from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type, Union from flax.core import Scope from flax.core.frozen_dict import freeze, unfreeze -from flax.deprecated.nn import initializers +from flax.linen import initializers from flax.linen import Module, compact, vmap import jax from jax import lax, numpy as jnp, random diff --git a/examples/mnist/train.py b/examples/mnist/train.py index 0e7bf0c8c..20ee09001 100644 --- a/examples/mnist/train.py +++ b/examples/mnist/train.py @@ -21,6 +21,7 @@ # See issue #620. # pytype: disable=wrong-keyword-args +from absl import logging from flax import linen as nn from flax.metrics import tensorboard from flax.training import train_state @@ -31,6 +32,7 @@ import optax import tensorflow_datasets as tfds + class CNN(nn.Module): """A simple CNN model.""" @@ -141,7 +143,7 @@ def train_and_evaluate(config: ml_collections.ConfigDict, _, test_loss, test_accuracy = apply_model(state, test_ds['image'], test_ds['label']) - print( + logging.info( 'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f' % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)) diff --git a/examples/pixelcnn/README.md b/examples/pixelcnn/README.md deleted file mode 100644 index 46040b0aa..000000000 --- a/examples/pixelcnn/README.md +++ /dev/null @@ -1,43 +0,0 @@ -## PixelCNN++ image modelling -Trains a PixelCNN++ model [(Salimans et al., -2017)](https://arxiv.org/abs/1701.05517) for image generation on the CIFAR-10 dataset. -Only unconditional image generation is implemented, trained using ADAM -on the negative log-likelihood. As in the original [OpenAI implementation](https://github.com/openai/pixel-cnn) -we use weightnorm parameterization with data-dependent initialization. - -Code for sampling is also provided. The following image, containing 256 samples, was generated in 4m 24s -on an 8 x Nvidia V100 machine. - -![alt text](sample.png "PixelCNN++ samples.") -### Requirements (Training) -* [TF datasets](https://www.tensorflow.org/datasets), which will download and cache the CIFAR-10 dataset the first time you - run `train.py`. - -### Requirements (Sampling) -* [Pillow](https://pillow.readthedocs.io/en/stable/) for saving samples as PNG files. - -### Supported setups -The model should run with other configurations and hardware, but was tested on the following. - -| Hardware | Batch size | Training time | Log-likelihood (bits/dimension) | TensorBoard.dev | -| --- | --- | --- | --- | --- | -| 8 x Nvidia V100 (16GB) | 320 | 1d 14h | 2.923 | [2020-04-23](https://tensorboard.dev/experiment/t8fM3u2zSJG7tAx6YbXHkQ/) | -| 8 x TPUv3 (16GB) | 320 | 4d4h | 2.927 | [2020-08-15](https://tensorboard.dev/experiment/6rTypNzlSN2o7pfNWJOjMw/) | - -### How to run -#### 8 x Nvidia V100 (16GB), 8 x TPUv3 (16GB) - -To run training: - -``` -python main.py --workdir=/tmp/pixelcnn --config=configs/default.py \ - --config.batch_size=320 -``` - -To run sampling (this will automatically load model parameters from the most -recent trained checkpoint): - -``` -python main.py --workdir=/tmp/pixelcnn --sample --config=configs/default.py \ - --config.sample_batch_size=256 -``` diff --git a/examples/pixelcnn/configs/default.py b/examples/pixelcnn/configs/default.py deleted file mode 100644 index 31a0e09cd..000000000 --- a/examples/pixelcnn/configs/default.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Default Hyperparameter configuration.""" - -import ml_collections - - -def get_config(): - """Get the default hyperparameter configuration.""" - config = ml_collections.ConfigDict() - - # The initial learning rate. - config.learning_rate = 0.001 - - # Learning rate decay, applied each optimization step. - config.lr_decay = 0.999995 - - # Batch size to use for data-dependent initialization. - config.init_batch_size = 16 - - # Batch size for training. - config.batch_size = 64 - - # Number of training epochs. - config.num_epochs = 200 - - # Dropout rate. - config.dropout_rate = 0.5 - - # Number of resnet layers per block. - config.n_resnet = 5 - - # Number of features in each conv layer. - config.n_feature = 160 - - # Number of components in the output distribution. - config.n_logistic_mix = 10 - - # Exponential decay rate of the sum of previous model iterates during Polyak - # averaging. - config.polyak_decay = 0.9995 - - # Batch size for sampling. - config.sample_batch_size = 256 - # Random number generator seed for sampling. - config.sample_rng_seed = 0 - - # Integer for PRNG random seed. - config.seed = 0 - - return config diff --git a/examples/pixelcnn/input_pipeline.py b/examples/pixelcnn/input_pipeline.py deleted file mode 100644 index 2916f8295..000000000 --- a/examples/pixelcnn/input_pipeline.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Cifar 10 input pipeline.""" - -import ml_collections -import tensorflow as tf -import tensorflow_datasets as tfds - - -class DataSource(object): - """CIFAR10 data source.""" - - TRAIN_IMAGES = 50000 - EVAL_IMAGES = 10000 - - def __init__(self, config: ml_collections.ConfigDict, shuffle_seed: int = 1): - - dataset_builder = tfds.builder('cifar10') - dataset_builder.download_and_prepare() - self.ds_info = dataset_builder.info - - # Training set - train_ds = dataset_builder.as_dataset(split='train').cache() - train_ds = train_ds.repeat(config.num_epochs) - train_ds = train_ds.shuffle(16 * config.batch_size, seed=shuffle_seed) - - def process_sample(x): - image = tf.cast(x['image'], tf.float32) - image = image / 127.5 - 1 - batch = {'image': image, 'label': x['label']} - return batch - - train_ds = train_ds.map(process_sample, num_parallel_calls=128) - train_ds = train_ds.batch(config.batch_size, drop_remainder=True) - train_ds = train_ds.prefetch(10) - self.train_ds = train_ds - - # Test set - eval_ds = dataset_builder.as_dataset(split='test').cache() - eval_ds = eval_ds.map(process_sample, num_parallel_calls=128) - # Note: samples will be dropped if the number of test samples is not - # divisible by the evaluation batch size - eval_ds = eval_ds.batch(config.batch_size, drop_remainder=True) - eval_ds = eval_ds.prefetch(10) - self.eval_ds = eval_ds diff --git a/examples/pixelcnn/main.py b/examples/pixelcnn/main.py deleted file mode 100644 index 6eae925ce..000000000 --- a/examples/pixelcnn/main.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Main file for running the PixelCNN example. - -This file is intentionally kept short. The majority for logic is in libraries -that can be easily tested and imported in Colab. -""" - -from absl import app -from absl import flags -from absl import logging -from clu import platform -import jax -from ml_collections import config_flags -import tensorflow as tf - -import sample -import train - - -FLAGS = flags.FLAGS - -flags.DEFINE_string('workdir', None, 'Directory to store model data.') -config_flags.DEFINE_config_file( - 'config', - None, - 'File path to the training hyperparameter configuration.', - lock_config=True) -flags.DEFINE_bool('sample', False, 'Sample from a model in workdir.') -flags.mark_flags_as_required(['config', 'workdir']) - - -def main(argv): - if len(argv) > 1: - raise app.UsageError('Too many command-line arguments.') - - # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make - # it unavailable to JAX. - tf.config.experimental.set_visible_devices([], 'GPU') - - logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) - logging.info('JAX local devices: %r', jax.local_devices()) - - # Add a note so that we can tell which task is which JAX host. - # (Depending on the platform task 0 is not guaranteed to be host 0) - platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, ' - f'process_count: {jax.process_count()}') - platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, - FLAGS.workdir, 'workdir') - - if FLAGS.sample: - sample.save_images( - sample.generate_sample(FLAGS.config, FLAGS.workdir), 'sample.png') - else: - train.train_and_evaluate(FLAGS.config, FLAGS.workdir) - - -if __name__ == '__main__': - app.run(main) diff --git a/examples/pixelcnn/model_test.py b/examples/pixelcnn/model_test.py deleted file mode 100644 index b3cf23fa9..000000000 --- a/examples/pixelcnn/model_test.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Lint as: python3 -"""Tests for PixelCNN Modules.""" - -from absl.testing import absltest -from absl.testing import parameterized -from flax import linen as nn -from jax import random -import jax.numpy as jnp -from jax.config import config -import numpy.testing as np_testing - -import pixelcnn - - -class ModelTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self.rng = random.PRNGKey(0) - self.x = jnp.arange(24).reshape(1, 4, 3, 2) - - - def get_weightnorm(self, params): - return [params[k] for k in ('direction', 'scale', 'bias')] - - - def assert_mean_and_variance(self, out): - # Weightnorm should ensure that, at initialization time, the outputs of the - # module have mean 0 and variance 1 over the non-feature dimensions. - np_testing.assert_allclose(jnp.mean(out, (0, 1, 2)), 0., atol=1e-5) - np_testing.assert_allclose(jnp.var(out, (0, 1, 2)), 1., atol=1e-5) - - - def test_conv(self): - model = pixelcnn.ConvWeightNorm(features=4, kernel_size=(3, 2)) - out, variables = model.init_with_output(self.rng, self.x) - params = variables['params']['weightnorm_params'] - direction, scale, bias = self.get_weightnorm(params) - - self.assertEqual(direction.shape, (3, 2, 2, 4)) - self.assertEqual(scale.shape, (4,)) - self.assertEqual(bias.shape, (4,)) - self.assertEqual(out.shape, (1, 2, 2, 4)) - self.assert_mean_and_variance(out) - - - def test_conv_down(self): - model = pixelcnn.ConvDown(features=4) - out, variables = model.init_with_output(self.rng, self.x) - params = variables['params']['ConvWeightNorm_0']['weightnorm_params'] - direction, scale, bias = self.get_weightnorm(params) - - self.assertEqual(direction.shape, (2, 3, 2, 4)) - self.assertEqual(scale.shape, (4,)) - self.assertEqual(bias.shape, (4,)) - self.assertEqual(out.shape, (1, 4, 3, 4)) - self.assert_mean_and_variance(out) - - - def test_conv_down_right(self): - model = pixelcnn.ConvDownRight(features=4) - out, variables = model.init_with_output(self.rng, self.x) - params = variables['params']['ConvWeightNorm_0']['weightnorm_params'] - direction, scale, bias = self.get_weightnorm(params) - - self.assertEqual(direction.shape, (2, 2, 2, 4)) - self.assertEqual(scale.shape, (4,)) - self.assertEqual(bias.shape, (4,)) - self.assertEqual(out.shape, (1, 4, 3, 4)) - self.assert_mean_and_variance(out) - - - def test_conv_transpose(self): - model = pixelcnn.ConvTranspose(features=4, kernel_size = (3, 2)) - out, variables = model.init_with_output(self.rng, self.x) - params = variables['params']['weightnorm_params'] - direction, scale, bias = self.get_weightnorm(params) - - self.assertEqual(direction.shape, (3, 2, 2, 4)) - self.assertEqual(scale.shape, (4,)) - self.assertEqual(bias.shape, (4,)) - self.assertEqual(out.shape, (1, 6, 4, 4)) - self.assert_mean_and_variance(out) - - - def test_conv_transpose_down(self): - model = pixelcnn.ConvTransposeDown(features=4) - out, variables = model.init_with_output(self.rng, self.x) - params = variables['params']["ConvWeightNorm_0"]["weightnorm_params"] - direction, scale, bias = self.get_weightnorm(params) - - self.assertEqual(direction.shape, (2, 3, 2, 4)) - self.assertEqual(scale.shape, (4,)) - self.assertEqual(bias.shape, (4,)) - self.assertEqual(out.shape, (1, 8, 6, 4)) - - - def test_conv_transpose_down_right(self): - model = pixelcnn.ConvTransposeDownRight(features=4) - out, variables = model.init_with_output(self.rng, self.x) - params = variables['params']['ConvWeightNorm_0']['weightnorm_params'] - direction, scale, bias = self.get_weightnorm(params) - - self.assertEqual(direction.shape, (2, 2, 2, 4)) - self.assertEqual(scale.shape, (4,)) - self.assertEqual(bias.shape, (4,)) - self.assertEqual(out.shape, (1, 8, 6, 4)) - - - def test_pcnn_shape(self): - x = random.normal(self.rng, (2, 4, 4, 3)) - model = pixelcnn.PixelCNNPP(depth=0, features=2, dropout_p=0) - out, _ = model.init_with_output(self.rng, x, train=False) - self.assertEqual(out.shape, (2, 4, 4, 100)) - - -if __name__ == '__main__': - absltest.main() diff --git a/examples/pixelcnn/pixelcnn.py b/examples/pixelcnn/pixelcnn.py deleted file mode 100644 index cba75c44e..000000000 --- a/examples/pixelcnn/pixelcnn.py +++ /dev/null @@ -1,430 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Flax implementation of PixelCNN++ - -Based on the paper - - PixelCNN++: Improving the PixelCNN with discretized logistic mixture - likelihood and other modifications - -published at ICLR '17 (https://openreview.net/forum?id=BJrFC6ceg). -""" - -# See issue #620. -# pytype: disable=wrong-arg-count - -from functools import partial -from typing import Any, Callable, Iterable, Tuple, Optional, Union - -import flax.linen as nn -import jax -from jax import lax -import jax.numpy as jnp -from jax.scipy.special import logsumexp -import numpy as np - - -class PixelCNNPP(nn.Module): - """PixelCNN++ module.""" - depth: int = 5 - features: int = 160 - logistic_components: int = 10 - dropout_p: float = 0.5 - - @nn.compact - def __call__(self, images, *, train): - # Special convolutional and resnet blocks which allow information flow - # downwards and to the right. - conv_down = partial(ConvDown, features=self.features) - conv_down_right = partial(ConvDownRight, features=self.features) - - dropout = partial( - nn.Dropout, rate=self.dropout_p, deterministic=not train) - - res_down = partial(ResDown, dropout=dropout) - res_down_right = partial(ResDownRight, dropout=dropout) - - # Conv Modules which halve or double the spatial dimensions - halve_down = partial(conv_down, strides=(2, 2)) - halve_down_right = partial(conv_down_right, strides=(2, 2)) - - double_down = partial(ConvTransposeDown, features=self.features) - double_down_right = partial(ConvTransposeDownRight, features=self.features) - - # Add channel of ones to distinguish image from padding later on - images = jnp.pad(images, ((0, 0), (0, 0), (0, 0), (0, 1)), constant_values=1) - - # Stack of `(down, down_right)` pairs, where information flows downwards - # through `down` and downwards and to the right through `down_right`. - # We refer to the building of the stack as the 'forward pass' and the - # undoing of the stack as the 'reverse pass'. - stack = [] - - # -------------------------- FORWARD PASS ---------------------------------- - down = shift_down(conv_down(kernel_size=(2, 3))(images)) - down_right = ( - shift_down(conv_down(kernel_size=(1, 3))(images)) - + shift_right(conv_down_right(kernel_size=(2, 1))(images))) - - stack.append((down, down_right)) - for _ in range(self.depth): - down, down_right = res_down()(down), res_down_right()(down_right, down) - stack.append((down, down_right)) - - # Resize spatial dims 32 x 32 --> 16 x 16 - down, down_right = halve_down()(down), halve_down_right()(down_right) - stack.append((down, down_right)) - - for _ in range(self.depth): - down, down_right = res_down()(down), res_down_right()(down_right, down) - stack.append((down, down_right)) - - # Resize spatial dims 16 x 16 --> 8 x 8 - down, down_right = halve_down()(down), halve_down_right()(down_right) - stack.append((down, down_right)) - - for _ in range(self.depth): - down, down_right = res_down()(down), res_down_right()(down_right, down) - stack.append((down, down_right)) - - # The stack now contains (in order from last appended): - # - # Number of layers Spatial dims - # depth + 1 8 x 8 - # depth + 1 16 x 16 - # depth + 1 32 x 32 - - # -------------------------- REVERSE PASS ---------------------------------- - down, down_right = stack.pop() - - for _ in range(self.depth): - down_fwd, down_right_fwd = stack.pop() - down = res_down()(down, down_fwd) - down_right = res_down_right()( - down_right, jnp.concatenate((down, down_right_fwd), -1)) - - # Resize spatial dims 8 x 8 --> 16 x 16 - down, down_right = double_down()(down), double_down_right()(down_right) - - for _ in range(self.depth + 1): - down_fwd, down_right_fwd = stack.pop() - down = res_down()(down, down_fwd) - down_right = res_down_right()( - down_right, jnp.concatenate((down, down_right_fwd), -1)) - - # Resize spatial dims 16 x 16 --> 32 x 32 - down, down_right = double_down()(down), double_down_right()(down_right) - - for _ in range(self.depth + 1): - down_fwd, down_right_fwd = stack.pop() - down = res_down()(down, down_fwd) - down_right = res_down_right()( - down_right, jnp.concatenate((down, down_right_fwd), -1)) - - assert not stack - - # Note init_scale=0.1 on this layer was not in the original implementation, - # but seems to make training more stable. - return ConvOneByOne(10 * self.logistic_components, - init_scale=0.1)(nn.elu(down_right)) - - -def concat_elu(x): - return nn.elu(jnp.concatenate((x, -x), -1)) - - -def spatial_pad(pad_vertical, pad_horizontal, operand): - """Wrapper around lax.pad which pads spatial dimensions (horizontal and - vertical) with zeros, without any interior padding.""" - zero = (0, 0, 0) - return lax.pad(operand, jnp.zeros((), operand.dtype), - (zero, pad_vertical + (0,), pad_horizontal + (0,), zero)) - - -shift_down = partial(spatial_pad, (1, -1), (0, 0)) -shift_right = partial(spatial_pad, (0, 0), (1, -1)) - - -# Weightnorm utils -def _l2_normalize(v): - """Normalize a convolution kernel direction over the in_features and spatial - dimensions.""" - return v / jnp.sqrt(jnp.sum(jnp.square(v), (0, 1, 2), keepdims=True)) - - -def _make_kernel(direction, scale): - """Maps weightnorm parameterization (direction, scale) to standard - parameterization. The direction has shape (spatial..., in_features, - out_features), scale has shape (out_features,).""" - scale = scale.reshape((1,) * (direction.ndim - 1) + (-1,)) - return scale * _l2_normalize(direction) - - -# 2D convolution Modules with weightnorm -class ConvWeightNorm(nn.Module): - """2D convolution Modules with weightnorm.""" - features: int - kernel_size: Tuple[int, int] - strides: Optional[Tuple[int, int]] = None - padding: Union[str, Iterable[Iterable[int]]] = 'VALID' - transpose: bool = False - init_scale: float = 1. - dtype: Any = jnp.float32 - precision: Any = None - - @nn.compact - def __call__(self, inputs): - inputs = jnp.asarray(inputs, self.dtype) - strides = self.strides or (1,) * (inputs.ndim - 2) - - if self.transpose: - conv = partial(lax.conv_transpose, strides=strides, padding=self.padding, - precision=self.precision) - else: - conv = partial(lax.conv_general_dilated, window_strides=strides, - padding=self.padding, - dimension_numbers=('NHWC', 'HWIO', 'NHWC'), - precision=self.precision) - - in_features = inputs.shape[-1] - kernel_shape = self.kernel_size + (in_features, self.features) - - def initializer(key): - # A weightnorm initializer generating a (direction, scale, bias) tuple. - direction = nn.initializers.normal()(key, kernel_shape, self.dtype) - unnormed_out = conv(inputs, _l2_normalize(direction)) - mean = jnp.mean(unnormed_out, (0, 1, 2)) - var = jnp.std(unnormed_out, (0, 1, 2)) - return dict( - direction=direction, scale=self.init_scale / var, bias=-mean / var) - - params = self.param('weightnorm_params', initializer) - direction, scale, bias = [params[k] for k in ('direction', 'scale', 'bias')] - y = conv(inputs, _make_kernel(direction, scale)) - y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) - return y - - -ConvOneByOne = partial(ConvWeightNorm, kernel_size=(1, 1)) -ConvTranspose = partial(ConvWeightNorm, transpose=True) - - -class ConvDown(nn.Module): - """Convolution with padding so that information cannot flow upwards.""" - features: int - kernel_size: Tuple[int, int] = (2, 3) - strides: Optional[Tuple[int, int]] = None - init_scale: float = 1. - - @nn.compact - def __call__(self, inputs): - k_h, k_w = self.kernel_size - assert k_w % 2 == 1, 'kernel width must be odd.' - padding = ((k_h - 1, 0), # Vertical padding - (k_w // 2, k_w // 2)) # Horizontal padding - - return ConvWeightNorm( - self.features, self.kernel_size, self.strides, padding, - init_scale=self.init_scale)(inputs) - - -class ConvDownRight(nn.Module): - """Convolution with padding so that information cannot flow left/upwards.""" - features: Any - kernel_size: Tuple[int, int] = (2, 2) - strides: Optional[Tuple[int, int]] = None - init_scale: float = 1.0 - - @nn.compact - def __call__(self, inputs): - k_h, k_w = self.kernel_size - padding = ((k_h - 1, 0), # Vertical padding - (k_w - 1, 0)) # Horizontal padding - - return ConvWeightNorm( - self.features, self.kernel_size, self.strides, padding, - init_scale=self.init_scale)(inputs) - - -class ConvTransposeDown(nn.Module): - """Transpose convolution with output slicing so that information cannot flow - upwards. Strides are (2, 2) by default which implies the spatial dimensions - of the output shape are double those of the input shape. - """ - features: Any - kernel_size: Tuple[int, int] = (2, 3) - strides: Optional[Tuple[int, int]] = (2, 2) - - @nn.compact - def __call__(self, inputs): - _, k_w = self.kernel_size - out_h, out_w = np.multiply(self.strides, inputs.shape[1:3]) - return ConvTranspose(self.features, self.kernel_size, self.strides)(inputs)[ - :, :out_h, (k_w - 1) // 2:out_w + (k_w - 1) // 2, :] - -class ConvTransposeDownRight(nn.Module): - """Transpose convolution with output slicing so that information cannot flow. - - to the left or upwards. Strides are (2, 2) by default which implies the - spatial dimensions of the output shape are double those of the input shape. - """ - features: Any - kernel_size: Tuple[int, int] = (2, 2) - strides: Optional[Tuple[int, int]] = (2, 2) - - @nn.compact - def __call__(self, inputs): - out_h, out_w = np.multiply(self.strides, inputs.shape[1:3]) - return ConvTranspose(self.features, self.kernel_size, - self.strides)(inputs)[:, :out_h, :out_w] - - -# Resnet modules -class GatedResnet(nn.Module): - conv_module: Callable[..., Any] - dropout: Callable[..., Any] - nonlinearity: Callable[..., Any] = concat_elu - - @nn.compact - def __call__(self, inputs, aux=None): - c = inputs.shape[-1] - y = self.conv_module(c)(self.nonlinearity(inputs)) - if aux is not None: - y = self.nonlinearity(y + ConvOneByOne(c)(self.nonlinearity(aux))) - - y = self.dropout()(y) - - # Set init_scale=0.1 so that the res block is close to the identity at - # initialization. - a, b = jnp.split(self.conv_module(2 * c, init_scale=0.1)(y), 2, axis=-1) - return inputs + a * nn.sigmoid(b) - - -ResDown = partial(GatedResnet, conv_module=ConvDown) -ResDownRight = partial(GatedResnet, conv_module=ConvDownRight) - - -# Logistic mixture distribution utils -def conditional_params_from_outputs(theta, img): - """Maps an image `img` and the PixelCNN++ convnet output `theta` to - conditional parameters for a mixture of k logistics over each pixel. - - Returns a tuple `(means, inverse_scales, logit_weights)` where `means` and - `inverse_scales` are the conditional means and inverse scales of each mixture - component (for each pixel-channel) and `logit_weights` are the logits of the - mixture weights (for each pixel). These have the following shapes: - - means.shape == inv_scales.shape == (batch..., k, h, w, c) - logit_weights.shape == (batch..., k, h, w) - - Args: - theta: outputs of PixelCNN++ neural net with shape - (batch..., h, w, (1 + 3 * c) * k) - img: an image with shape (batch..., h, w, c) - - Returns: - The tuple `(means, inverse_scales, logit_weights)`. - """ - *batch, h, w, c = img.shape - assert theta.shape[-1] % (3 * c + 1) == 0 - k = theta.shape[-1] // (3 * c + 1) - - logit_weights, theta = theta[..., :k], theta[..., k:] - assert theta.shape[-3:] == (h, w, 3 * c * k) - - # Each of m, s and t must have shape (batch..., k, h, w, c), we effectively - # spread the last dimension of theta out into c, k, 3, move the k dimension to - # after batch and split along the 3 dimension. - m, s, t = jnp.moveaxis( - jnp.reshape(theta, - tuple(batch) + (h, w, c, k, 3)), (-2, -1), (-4, 0)) - assert m.shape[-4:] == (k, h, w, c) - t = jnp.tanh(t) - - # Add a mixture dimension to images - img = jnp.expand_dims(img, -4) - - # Ensure inv_scales cannot be zero (zeros cause nans in sampling) - inv_scales = jnp.maximum(nn.softplus(s), 1e-7) - - # now condition the means for the last 2 channels (assuming c == 3) - mean_red = m[..., 0] - mean_green = m[..., 1] + t[..., 0] * img[..., 0] - mean_blue = m[..., 2] + t[..., 1] * img[..., 0] + t[..., 2] * img[..., 1] - means = jnp.stack((mean_red, mean_green, mean_blue), axis=-1) - return means, inv_scales, jnp.moveaxis(logit_weights, -1, -3) - - -def logprob_from_conditional_params(images, means, inv_scales, logit_weights): - """Compute log-likelihoods. - - Computes the log-likelihoods of images given the conditional logistic mixture - parameters produced by `conditional_params_from_outputs`. The 8-bit pixel - values are assumed to be scaled so that they are in the discrete set - - {-1, -1 + 1/127.5, -1 + 2/127.5, ..., 1 - 1/127.5, 1} - """ - # Add a 'mixture' dimension to images. - images = jnp.expand_dims(images, -4) - - # Calculate log probabilities under all mixture components. - all_logprobs = discretized_logistic_logpmf(images, means, inv_scales) - - # Sum over the channel dimension because mixture components are shared - # across channels. - logprobs = jnp.sum(all_logprobs, -1) - - # Normalize the mixture weights. - log_mix_coeffs = logit_weights - logsumexp(logit_weights, -3, keepdims=True) - - # Finally marginalize out mixture components and sum over pixels. - return jnp.sum(logsumexp(log_mix_coeffs + logprobs, -3), (-2, -1)) - - -def discretized_logistic_logpmf(images, means, inv_scales): - """Compute log-probabilities for each mixture component, pixel and channel.""" - # Compute the difference between the logistic cdf half a level above and half - # a level below the image value. - centered = images - means - - # Where images == 1 we use log(1 - cdf(images - 1 / 255)) - top = -jnp.logaddexp(0, (centered - 1 / 255) * inv_scales) - - # Where images == -1 we use log(cdf(images + 1 / 255)) - bottom = -jnp.logaddexp(0, -(centered + 1 / 255) * inv_scales) - - # Elsewhere we use log(cdf(images + 1 / 255) - cdf(images - 1 / 255)) - mid = log1mexp(inv_scales / 127.5) + top + bottom - - return jnp.where(images == 1, top, jnp.where(images == -1, bottom, mid)) - - -@jax.custom_jvp -def log1mexp(x): - """Accurate computation of log(1 - exp(-x)) for x > 0.""" - - # Method from - # https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf - return jnp.where(x > jnp.log(2), jnp.log1p(-jnp.exp(-x)), - jnp.log(-jnp.expm1(-x))) - - -# log1mexp produces NAN gradients for small inputs because the derivative of the -# log1p(-exp(-eps)) branch has a zero divisor (1 + -jnp.exp(-eps)), and NANs in -# the derivative of one branch of a where cause NANs in the where's vjp, even -# when the NAN branch is not taken. See -# https://github.com/google/jax/issues/1052. We work around this by defining a -# custom jvp. -log1mexp.defjvps(lambda t, _, x: t / jnp.expm1(x)) diff --git a/examples/pixelcnn/requirements.txt b/examples/pixelcnn/requirements.txt deleted file mode 100644 index c199b03f2..000000000 --- a/examples/pixelcnn/requirements.txt +++ /dev/null @@ -1,11 +0,0 @@ -absl-py==1.0.0 -clu==0.0.6 -flax==0.3.6 -jax==0.2.21 ---find-links https://storage.googleapis.com/jax-releases/jax_releases.html -jaxlib==0.1.70+cuda110 # Make sure CUDA version matches the base image. -ml-collections==0.1.0 -numpy==1.21.4 -Pillow==9.0.0 -tensorflow==2.7.0 -tensorflow-datasets==4.4.0 diff --git a/examples/pixelcnn/sample.png b/examples/pixelcnn/sample.png deleted file mode 100644 index 1fe879b32..000000000 Binary files a/examples/pixelcnn/sample.png and /dev/null differ diff --git a/examples/pixelcnn/sample.py b/examples/pixelcnn/sample.py deleted file mode 100644 index 9288875ea..000000000 --- a/examples/pixelcnn/sample.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Lint as: python3 -"""Sampling from PixelCNN++ using fixed-point iteration. -""" - -import functools - -from flax import optim -import jax -from jax import random -import jax.numpy as jnp -import ml_collections -import numpy as np -from PIL import Image - -import pixelcnn -import train - - -def generate_sample(config: ml_collections.ConfigDict, workdir: str): - """Loads the latest model in `workdir` and samples a batch of images.""" - batch_size = config.sample_batch_size - rng = random.PRNGKey(config.sample_rng_seed) - rng, model_rng = random.split(rng) - rng, dropout_rng = random.split(rng) - - # Create a model with dummy parameters and a dummy optimizer. - init_batch = jnp.zeros((1, 32, 32, 3)) - - params = train.model(config).init( - { - 'params': model_rng, - 'dropout': dropout_rng - }, init_batch)['params'] - optimizer_def = optim.Adam( - learning_rate=config.learning_rate, beta1=0.95, beta2=0.9995) - optimizer = optimizer_def.create(params) - - _, params = train.restore_checkpoint(workdir, optimizer, params) - - # Initialize batch of images - device_count = jax.local_device_count() - assert not batch_size % device_count, ( - 'Sampling batch size must be a multiple of the device count, got ' - 'sample_batch_size={}, device_count={}.'.format(batch_size, - device_count)) - sample_prev = jnp.zeros((device_count, batch_size // device_count, 32, 32, 3)) - - # and batch of rng keys - sample_rng = random.split(rng, device_count) - - # Generate sample using fixed-point iteration - sample = sample_iteration(config, sample_rng, params, sample_prev) - while jnp.any(sample != sample_prev): - sample_prev, sample = sample, sample_iteration(config, sample_rng, params, - sample) - return jnp.reshape(sample, (batch_size, 32, 32, 3)) - - -def _categorical_onehot(rng, logit_probs): - """Sample from a categorical distribution and one-hot encode the sample.""" - nr_mix = logit_probs.shape[-3] - idxs = random.categorical(rng, logit_probs, axis=-3) - return jnp.moveaxis(idxs[..., jnp.newaxis] == jnp.arange(nr_mix), -1, -3) - - -def conditional_params_to_sample(rng, conditional_params): - means, inv_scales, logit_probs = conditional_params - rng_mix, rng_logistic = random.split(rng) - # Add channel dimension to one-hot mixture indicator - mix_indicator = _categorical_onehot(rng_mix, logit_probs)[..., jnp.newaxis] - # Use the mixture indicator to select the mean and inverse scale - mean = jnp.sum(means * mix_indicator, -4) - inv_scale = jnp.sum(inv_scales * mix_indicator, -4) - sample = mean + random.logistic(rng_logistic, mean.shape) / inv_scale - return snap_to_grid(sample) - - -@functools.partial(jax.pmap, static_broadcasted_argnums=2) -def sample_iteration(config, rng, params, sample): - """PixelCNN++ sampling expressed as a fixed-point iteration.""" - rng, dropout_rng = random.split(rng) - out = train.model(config).apply({'params': params}, - sample, - rngs={'dropout': dropout_rng}) - c_params = pixelcnn.conditional_params_from_outputs(out, sample) - return conditional_params_to_sample(rng, c_params) - - -def snap_to_grid(sample): - return jnp.clip(jnp.round((sample + 1) * 127.5) / 127.5 - 1, -1., 1.) - - -def save_images(batch, fname): - n_rows = batch.shape[0] // 16 - batch = np.uint8(jnp.round((batch + 1) * 127.5)) - out = np.full((1 + 33 * n_rows, 1 + 33 * 16, 3), 255, 'uint8') - for i, im in enumerate(batch): - top = 1 + 33 * (i // 16) - left = 1 + 33 * (i % 16) - out[top:top + 32, left:left + 32] = im - Image.fromarray(out).save(fname) diff --git a/examples/pixelcnn/train.py b/examples/pixelcnn/train.py deleted file mode 100644 index 1ea678ceb..000000000 --- a/examples/pixelcnn/train.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""PixelCNN++ example.""" - -# See issue #620. -# pytype: disable=wrong-keyword-args - -import functools -import datetime - -from absl import logging -from flax import jax_utils -from flax import optim -from flax.metrics import tensorboard -from flax.training import checkpoints -from flax.training import common_utils -import jax -import jax.numpy as jnp -import ml_collections -import numpy as np -import tensorflow as tf - -import input_pipeline -import pixelcnn - - -def get_summary_writers(workdir): - current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S') - log_dir = workdir + '/log/' + current_time - train_log_dir = log_dir + '/train' - eval_log_dir = log_dir + '/eval' - train_summary_writer = tensorboard.SummaryWriter(train_log_dir) - eval_summary_writer = tensorboard.SummaryWriter(eval_log_dir) - return train_summary_writer, eval_summary_writer - - -def model(config: ml_collections.ConfigDict, **kwargs): - return pixelcnn.PixelCNNPP( - depth=config.n_resnet, - features=config.n_feature, - logistic_components=config.n_logistic_mix, - **kwargs) - - -def neg_log_likelihood_loss(nn_out, images): - # The log-likelihood in bits per pixel-channel - means, inv_scales, logit_weights = ( - pixelcnn.conditional_params_from_outputs(nn_out, images)) - log_likelihoods = pixelcnn.logprob_from_conditional_params( - images, means, inv_scales, logit_weights) - return -jnp.mean(log_likelihoods) / (jnp.log(2) * np.prod(images.shape[-3:])) - - -def train_step(config: ml_collections.ConfigDict, learning_rate_fn, optimizer, - ema, batch, dropout_rng): - """Perform a single training step.""" - - def loss_fn(params): - """loss function used for training.""" - pcnn_out = model( - config, - dropout_p=config.dropout_rate).apply({'params': params}, - batch['image'], - rngs={'dropout': dropout_rng}, - train=True) - return neg_log_likelihood_loss(pcnn_out, batch['image']) - - lr = learning_rate_fn(optimizer.state.step) - grad_fn = jax.value_and_grad(loss_fn) - loss, grad = grad_fn(optimizer.target) - grad = jax.lax.pmean(grad, 'batch') - optimizer = optimizer.apply_gradient(grad, learning_rate=lr) - - # Compute exponential moving average (aka Polyak decay) - ema_decay = config.polyak_decay - ema = jax.tree_multimap(lambda ema, p: ema * ema_decay + (1 - ema_decay) * p, - ema, optimizer.target) - - metrics = {'loss': jax.lax.pmean(loss, 'batch'), 'learning_rate': lr} - return optimizer, ema, metrics - - -def eval_step(config, params, batch): - images = batch['image'] - pcnn_out = model(config).apply({'params': params}, images, train=False) - return { - 'loss': jax.lax.pmean(neg_log_likelihood_loss(pcnn_out, images), 'batch') - } - - -def load_and_shard_tf_batch(xs): - local_device_count = jax.local_device_count() - - def _prepare(x): - # Use _numpy() for zero-copy conversion between TF and NumPy. - x = x._numpy() # pylint: disable=protected-access - return x.reshape((local_device_count, -1) + x.shape[1:]) - - return jax.tree_map(_prepare, xs) - - -def restore_checkpoint(workdir: str, optimizer, ema): - return checkpoints.restore_checkpoint(workdir, (optimizer, ema)) - - -def save_checkpoint(workdir: str, optimizer, ema, step): - optimizer, ema = jax_utils.unreplicate((optimizer, ema)) - checkpoints.save_checkpoint(workdir, (optimizer, ema), step, keep=3) - - -def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): - """Runs a training and evaluation loop. - - Args: - config: Configuration to use. - workdir: Working directory for checkpoints and TF summaries. If this - contains checkpoint training will be resumed from the latest checkpoint. - """ - tf.io.gfile.makedirs(workdir) - - batch_size = config.batch_size - n_devices = jax.device_count() - if jax.process_count() > 1: - raise ValueError('PixelCNN++ example should not be run on more than 1 host' - ' (for now)') - if batch_size % n_devices > 0: - raise ValueError('Batch size must be divisible by the number of devices') - - train_summary_writer, eval_summary_writer = get_summary_writers(workdir) - # Load dataset - data_source = input_pipeline.DataSource(config) - train_ds = data_source.train_ds - eval_ds = data_source.eval_ds - steps_per_epoch = data_source.ds_info.splits[ - 'train'].num_examples // config.batch_size - # Create dataset batch iterators - train_iter = iter(train_ds) - num_train_steps = train_ds.cardinality().numpy() - steps_per_checkpoint = 1000 - - # Create the model using data-dependent initialization. Don't shard the init - # batch. - assert config.init_batch_size <= batch_size - init_batch = next(train_iter)['image']._numpy()[:config.init_batch_size] - - rng = jax.random.PRNGKey(config.seed) - rng, init_rng, dropout_rng = jax.random.split(rng, 3) - - initial_variables = model(config).init( - { - 'params': init_rng, - 'dropout': dropout_rng - }, init_batch, train=False)['params'] - optimizer_def = optim.Adam(beta1=0.95, beta2=0.9995) - optimizer = optimizer_def.create(initial_variables) - - optimizer, ema = restore_checkpoint(workdir, optimizer, initial_variables) - ema = initial_variables - step_offset = int(optimizer.state.step) - - optimizer, ema = jax_utils.replicate((optimizer, ema)) - - # Learning rate schedule - learning_rate_fn = lambda step: config.learning_rate * config.lr_decay**step - - # pmap the train and eval functions - p_train_step = jax.pmap( - functools.partial(train_step, config, learning_rate_fn), - axis_name='batch') - p_eval_step = jax.pmap( - functools.partial(eval_step, config), axis_name='batch') - - # Gather metrics - train_metrics = [] - - for step, batch in zip(range(step_offset, num_train_steps), train_iter): - # Load and shard the TF batch - batch = load_and_shard_tf_batch(batch) - - # Generate a PRNG key that will be rolled into the batch. - rng, step_rng = jax.random.split(rng) - sharded_rngs = common_utils.shard_prng_key(step_rng) - - # Train step - optimizer, ema, metrics = p_train_step(optimizer, ema, batch, sharded_rngs) - train_metrics.append(metrics) - - # Quick indication that training is happening. - logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step) - - if (step + 1) % steps_per_epoch == 0: - epoch = step // steps_per_epoch - # We've finished an epoch - train_metrics = common_utils.get_metrics(train_metrics) - # Get training epoch summary for logging - train_summary = jax.tree_map(lambda x: x.mean(), train_metrics) - # Send stats to Tensorboard - for key, vals in train_metrics.items(): - for i, val in enumerate(vals): - train_summary_writer.scalar(key, val, step - len(vals) + i + 1) - # Reset train metrics - train_metrics = [] - - # Evaluation - eval_metrics = [] - for eval_batch in eval_ds: - # Load and shard the TF batch - eval_batch = load_and_shard_tf_batch(eval_batch) - # Step - metrics = p_eval_step(ema, eval_batch) - eval_metrics.append(metrics) - eval_metrics = common_utils.get_metrics(eval_metrics) - # Get eval epoch summary for logging - eval_summary = jax.tree_map(lambda x: x.mean(), eval_metrics) - - # Log epoch summary - logging.info('Epoch %d: TRAIN loss=%.6f, EVAL loss=%.6f', epoch, - train_summary['loss'], eval_summary['loss']) - - eval_summary_writer.scalar('loss', eval_summary['loss'], step) - train_summary_writer.flush() - eval_summary_writer.flush() - - if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_train_steps: - save_checkpoint(workdir, optimizer, ema, step) diff --git a/examples/pixelcnn/train_test.py b/examples/pixelcnn/train_test.py deleted file mode 100644 index 2d08a77f8..000000000 --- a/examples/pixelcnn/train_test.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pathlib -import tempfile - -from absl import logging -from absl.testing import absltest -import tensorflow as tf -import tensorflow_datasets as tfds - -from configs import default -import train - - -def get_test_config(): - config = default.get_config() - config.init_batch_size = 8 - config.batch_size = 8 - config.num_epochs = 1 - config.n_resent = 1 - config.n_feature = 8 - return config - - -class TrainTest(absltest.TestCase): - """Test cases for PixelCNN library.""" - - def setUp(self): - super().setUp() - tf.config.experimental.set_visible_devices([], 'GPU') - - def test_train_and_evaluate(self): - config = get_test_config() - workdir = tempfile.mkdtemp() - - # Go two directories up to the root of the flax directory. - flax_root_dir = pathlib.Path(__file__).parents[2] - data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable - - with tfds.testing.mock_data(num_examples=8, data_dir=data_dir): - train.train_and_evaluate(config, workdir) - logging.info('workdir content: %s', tf.io.gfile.listdir(workdir)) - - -if __name__ == '__main__': - absltest.main() diff --git a/flax/__init__.py b/flax/__init__.py index 5924d055e..687f3e405 100644 --- a/flax/__init__.py +++ b/flax/__init__.py @@ -19,7 +19,7 @@ from . import core from . import linen from . import optim -from .deprecated import nn +# DO NOT REMOVE - Marker for internal deprecated API. # DO NOT REMOVE - Marker for internal logging. from .version import __version__ diff --git a/flax/core/flax_functional_engine.ipynb b/flax/core/flax_functional_engine.ipynb index b35bd6b53..96732c79f 100644 --- a/flax/core/flax_functional_engine.ipynb +++ b/flax/core/flax_functional_engine.ipynb @@ -24,7 +24,7 @@ "colab": {} }, "source": [ - "from flax import nn, struct" + "from flax import linen as nn, struct" ], "execution_count": 2, "outputs": [] diff --git a/flax/core/nn/__init__.py b/flax/core/nn/__init__.py index 0f53e7f7c..f8ff50ae3 100644 --- a/flax/core/nn/__init__.py +++ b/flax/core/nn/__init__.py @@ -17,13 +17,13 @@ # pylint: disable=g-multiple-import # re-export commonly used modules and functions from .attention import (dot_product_attention, multi_head_dot_product_attention) -from flax.deprecated.nn import activation -from flax.deprecated.nn import initializers -from flax.deprecated.nn.activation import (celu, elu, gelu, glu, leaky_relu, - log_sigmoid, log_softmax, relu, - sigmoid, silu, soft_sign, softmax, - softplus, swish, tanh) -from flax.deprecated.nn.pooling import avg_pool, max_pool +from flax.linen import activation +from flax.linen import initializers +from flax.linen.activation import (celu, elu, gelu, glu, leaky_relu, + log_sigmoid, log_softmax, relu, sigmoid, + silu, soft_sign, softmax, softplus, swish, + tanh) +from flax.linen.pooling import avg_pool, max_pool from .linear import Embedding, conv, conv_transpose, dense, dense_general, embedding from .normalization import batch_norm, group_norm, layer_norm from .stochastic import dropout diff --git a/flax/core/nn/attention.py b/flax/core/nn/attention.py index 014b16be3..2ec9d0f5a 100644 --- a/flax/core/nn/attention.py +++ b/flax/core/nn/attention.py @@ -23,7 +23,7 @@ from flax import jax_utils from flax import struct from flax.core import Scope -from flax.deprecated.nn import initializers +from flax.linen import initializers import jax from jax import lax from jax import random diff --git a/flax/core/nn/linear.py b/flax/core/nn/linear.py index 3835723c3..2184bd3fc 100644 --- a/flax/core/nn/linear.py +++ b/flax/core/nn/linear.py @@ -17,7 +17,7 @@ from collections.abc import Iterable # pylint: disable=g-importing-member from flax import struct from flax.core import Scope -from flax.deprecated.nn import initializers +from flax.linen import initializers from jax import lax import jax.numpy as jnp diff --git a/flax/core/nn/normalization.py b/flax/core/nn/normalization.py index 7d73dfb60..e49b097d4 100644 --- a/flax/core/nn/normalization.py +++ b/flax/core/nn/normalization.py @@ -15,7 +15,7 @@ """Normalization modules for Flax.""" from flax.core import Scope -from flax.deprecated.nn import initializers +from flax.linen import initializers from jax import lax import jax.numpy as jnp diff --git a/flax/core/scope.py b/flax/core/scope.py index 1f3920ed8..c2f38703d 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -18,7 +18,10 @@ import functools import hashlib import dataclasses -from typing import Any, Callable, Container, Dict, Generic, Iterable, Mapping, Optional, Sequence, Set, Tuple, TypeVar, Union + +import typing +from typing import (Any, Callable, Dict, Generic, Iterable, Mapping, Optional, + Sequence, Set, Tuple, TypeVar, Union) from . import tracers from flax import errors @@ -44,7 +47,7 @@ RNGSequences = Dict[str, PRNGKey] -Filter = Union[bool, str, Container[str], 'DenyList'] +Filter = Union[bool, str, typing.Collection[str], 'DenyList'] @dataclasses.dataclass(frozen=True, eq=True) class DenyList: @@ -113,7 +116,7 @@ def _legacy_rng_fold_in(rng: PRNGKey, data: Iterable[PRNGFoldable]) -> PRNGKey: return rng -def _fold_in_static(rng: PRNGKey, data: Iterable[PRNGFoldable]) -> PRNGKey: +def _fold_in_static(rng: PRNGKey, data: typing.Collection[PRNGFoldable]) -> PRNGKey: """Folds static data (strings & ints) into a jax.random.PRNGKey using its SHA-1 hash. This is faster than splitting an PRNGKey because it allows generating new PRNG @@ -144,7 +147,7 @@ def _fold_in_static(rng: PRNGKey, data: Iterable[PRNGFoldable]) -> PRNGKey: def is_filter_empty(filter_like: Filter) -> bool: if isinstance(filter_like, str): return False - if isinstance(filter_like, Container): + if isinstance(filter_like, typing.Collection): return len(filter_like) == 0 if isinstance(filter_like, bool): return not filter_like @@ -172,7 +175,7 @@ def in_filter(filter_like: Filter, col: str) -> bool: """ if isinstance(filter_like, str): return col == filter_like - if isinstance(filter_like, Container): + if isinstance(filter_like, typing.Collection): return col in filter_like if isinstance(filter_like, bool): return filter_like @@ -195,7 +198,7 @@ def filter_to_set(x: Filter) -> Set[str]: return set() if isinstance(x, str): return set([x]) - if isinstance(x, Iterable): + if isinstance(x, typing.Collection): return set(x) raise errors.InvalidFilterError(x) diff --git a/flax/deprecated/__init__.py b/flax/deprecated/__init__.py deleted file mode 100644 index 9143e1822..000000000 --- a/flax/deprecated/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/flax/deprecated/nn/__init__.py b/flax/deprecated/nn/__init__.py deleted file mode 100644 index b51860ca7..000000000 --- a/flax/deprecated/nn/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Flax Neural Network api.""" - -# pylint: disable=g-multiple-import -# re-export commonly used modules and functions -from .activation import (celu, elu, gelu, glu, leaky_relu, log_sigmoid, - log_softmax, relu, sigmoid, soft_sign, softmax, - softplus, swish, silu, tanh) -from .attention import (dot_product_attention, MultiHeadDotProductAttention, - SelfAttention) -from .base import (Module, Model, Collection, capture_module_outputs, - module, stateful, get_state, module_method) -from .linear import Dense, DenseGeneral, Conv, ConvTranspose, Embed -from .normalization import BatchNorm, LayerNorm, GroupNorm -from .pooling import max_pool, avg_pool -from .recurrent import LSTMCell, GRUCell, ConvLSTM, OptimizedLSTMCell -from .stochastic import make_rng, stochastic, dropout, is_stochastic -# pylint: enable=g-multiple-import -import warnings -# Makes sure the user sees the warning, as deprecation warnings are silent by default -warnings.filterwarnings("default", category=DeprecationWarning, module=__name__) diff --git a/flax/deprecated/nn/activation.py b/flax/deprecated/nn/activation.py deleted file mode 100644 index 16d7aa6a8..000000000 --- a/flax/deprecated/nn/activation.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Activation functions. -""" - -# pylint: disable=unused-import -# re-export activation functions from jax.nn -from jax.nn import celu -from jax.nn import elu -from jax.nn import gelu -from jax.nn import glu -from jax.nn import leaky_relu -from jax.nn import log_sigmoid -from jax.nn import log_softmax -from jax.nn import normalize -from jax.nn import relu -from jax.nn import sigmoid -from jax.nn import soft_sign -from jax.nn import softmax -from jax.nn import softplus -from jax.nn import swish -from jax.nn import silu -from jax.nn import selu -from jax.nn import hard_tanh -from jax.nn import relu6 -from jax.nn import hard_sigmoid -from jax.nn import hard_swish - -from jax.numpy import tanh -# pylint: enable=unused-import diff --git a/flax/deprecated/nn/attention.py b/flax/deprecated/nn/attention.py deleted file mode 100644 index 5b5d409fa..000000000 --- a/flax/deprecated/nn/attention.py +++ /dev/null @@ -1,551 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Attention core modules for Flax.""" - -from collections.abc import Iterable # pylint: disable=g-importing-member - -import warnings - -from flax import jax_utils -from flax import struct -from flax.deprecated.nn.activation import softmax -from flax.deprecated.nn.base import Collection, Module, collection_from_iterable, iterate_collection -from flax.deprecated.nn.initializers import zeros -from flax.deprecated.nn.linear import DenseGeneral, default_kernel_init -from flax.deprecated.nn.stochastic import make_rng -import jax -from jax import lax -from jax import random -import jax.numpy as jnp - -import numpy as np - - -def dot_product_attention(query, - key, - value, - dtype=jnp.float32, - bias=None, - axis=None, - broadcast_dropout=True, - dropout_rng=None, - dropout_rate=0., - deterministic=False, - precision=None): - """DEPRECATION WARNING: - "The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Computes dot-product attention given query, key, and value. - - This is the core function for applying attention based on - https://arxiv.org/abs/1706.03762. It calculates the attention weights given - query and key and combines the values using the attention weights. This - function supports multi-dimensional inputs. - - - Args: - query: queries for calculating attention with shape of `[batch_size, dim1, - dim2, ..., dimN, num_heads, mem_channels]`. - key: keys for calculating attention with shape of `[batch_size, dim1, dim2, - ..., dimN, num_heads, mem_channels]`. - value: values to be used in attention with shape of `[batch_size, dim1, - dim2,..., dimN, num_heads, value_channels]`. - dtype: the dtype of the computation (default: float32) - bias: bias for the attention weights. This can be used for incorporating - autoregressive mask, padding mask, proximity bias. - axis: axises over which the attention is applied. - broadcast_dropout: bool: use a broadcasted dropout along batch dims. - dropout_rng: JAX PRNGKey: to be used for dropout - dropout_rate: dropout rate - deterministic: bool, deterministic or not (to apply dropout) - precision: numerical precision of the computation see `jax.lax.Precision` - for details. - - Returns: - Output of shape `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`. - """ - assert key.shape[:-1] == value.shape[:-1] - assert (query.shape[0:1] == key.shape[0:1] and - query.shape[-1] == key.shape[-1]) - - if axis is None: - axis = tuple(range(1, key.ndim - 2)) - if not isinstance(axis, Iterable): - axis = (axis,) - assert key.ndim == query.ndim - assert key.ndim == value.ndim - for ax in axis: - if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2): - raise ValueError('Attention axis must be between the batch ' - 'axis and the last-two axes.') - depth = query.shape[-1] - n = key.ndim - # batch_dims is , num_heads> - batch_dims = tuple(np.delete(range(n), axis + (n - 1,))) - # q & k -> (bs, , num_heads, , channels) - qk_perm = batch_dims + axis + (n - 1,) - key = key.transpose(qk_perm) - query = query.transpose(qk_perm) - # v -> (bs, , num_heads, channels, ) - v_perm = batch_dims + (n - 1,) + axis - value = value.transpose(v_perm) - - query = query / jnp.sqrt(depth).astype(dtype) - batch_dims_t = tuple(range(len(batch_dims))) - attn_weights = lax.dot_general( - query, - key, (((n - 1,), (n - 1,)), (batch_dims_t, batch_dims_t)), - precision=precision) - - # apply attention bias: masking, droput, proximity bias, ect. - if bias is not None: - attn_weights = attn_weights + bias - - # normalize the attention weights - norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim)) - attn_weights = softmax(attn_weights, axis=norm_dims) - attn_weights = attn_weights.astype(dtype) - - # apply dropout - if not deterministic and dropout_rate > 0.: - if dropout_rng is None: - dropout_rng = make_rng() - keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate) - if broadcast_dropout: - # dropout is broadcast across the batch+head+non-attention dimension - dropout_dims = attn_weights.shape[-(2 * len(axis)):] - dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims) - keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) - else: - keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) - multiplier = (keep.astype(attn_weights.dtype) / - jnp.asarray(keep_prob, dtype=dtype)) - attn_weights = attn_weights * multiplier - - # compute the new values given the attention weights - wv_contracting_dims = (norm_dims, range(value.ndim - len(axis), value.ndim)) - y = lax.dot_general( - attn_weights, - value, (wv_contracting_dims, (batch_dims_t, batch_dims_t)), - precision=precision) - - # back to (bs, dim1, dim2, ..., dimN, num_heads, channels) - perm_inv = _invert_perm(qk_perm) - y = y.transpose(perm_inv) - return y - - -def _invert_perm(perm): - perm_inv = [0] * len(perm) - for i, j in enumerate(perm): - perm_inv[j] = i - return tuple(perm_inv) - - -@struct.dataclass -class _CacheEntry: - key: np.ndarray - value: np.ndarray - i: np.ndarray - - -def scan_in_dim(*args, **kwargs): - warnings.warn('scan_in_dim moved to flax.jax_utils', - DeprecationWarning) - return jax_utils.scan_in_dim(*args, **kwargs) - - -class Cache(Collection): - """The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Collect intermediate activations for efficient autoregressive decoding.""" - - def initialize_cache(self, shape, dtype=None): - """Initialize the cache for the given input shape. - - Args: - shape: the shape of the batch and attention dimensions. - dtype: the dtype of the autoregressive cache. - Returns: - the initialized cache - """ - if dtype is None: - dtype = jnp.float32 - def _init(shape_data): - ndim = int(shape_data[0]) - tail_shape = tuple(shape_data[1:]) - full_shape = shape + tail_shape - if len(full_shape) != ndim: - raise ValueError('Shape should be a tuple with the shape of the batch' - 'and attention dims.') - return _CacheEntry(key=jnp.zeros(full_shape, dtype=dtype), - value=jnp.zeros(full_shape, dtype=dtype), - i=jnp.zeros((), jnp.uint32)) - return Cache(jax.tree_map(_init, self.state)) - - -jax.tree_util.register_pytree_node( - Cache, iterate_collection, collection_from_iterable) - - -class MultiHeadDotProductAttention(Module): - """The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Multi-head dot-product attention.""" - - def apply(self, - inputs_q, - inputs_kv, - num_heads, - dtype=jnp.float32, - qkv_features=None, - out_features=None, - attention_axis=None, - causal_mask=False, - padding_mask=None, - key_padding_mask=None, - segmentation=None, - key_segmentation=None, - cache=None, - broadcast_dropout=True, - dropout_rng=None, - dropout_rate=0., - deterministic=False, - precision=None, - kernel_init=default_kernel_init, - bias_init=zeros, - bias=True, - attention_fn=dot_product_attention): - """Applies multi-head dot product attention on the input data. - - Projects the inputs into multi-headed query, key, and value vectors, - applies dot-product attention and project the results to an output vector. - - This can be used for encoder-decoder attention by specifying both `inputs_q` - and `inputs_kv` orfor self-attention by only specifying `inputs_q` and - setting `inputs_kv` to None. - - Args: - inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`. - inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]` - or None for self-attention, inn which case key/values will be derived - from inputs_q. - num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - dtype: the dtype of the computation (default: float32) - qkv_features: dimension of the key, query, and value. - out_features: dimension of the last projection - attention_axis: axes over which the attention is applied ( 'None' means - attention over all axes, but batch, heads, and features). - causal_mask: boolean specifying whether to apply a causal mask on the - attention weights. If True, the output at timestep `t` will not depend - on inputs at timesteps strictly greater than `t`. - padding_mask: boolean specifying query tokens that are pad token w/ False. - key_padding_mask: boolean specifying key-value tokens that are pad token - w/ False. - segmentation: segment indices for packed inputs_q data. - key_segmentation: segment indices for packed inputs_kv data. - cache: an instance of `flax.nn.attention.Cache` used for efficient - autoregressive decoding. - broadcast_dropout: bool: use a broadcasted dropout along batch dims. - dropout_rng: JAX PRNGKey: to be used for dropout - dropout_rate: dropout rate - deterministic: bool, deterministic or not (to apply dropout) - precision: numerical precision of the computation see `jax.lax.Precision` - for details. - kernel_init: initializer for the kernel of the Dense layers. - bias_init: initializer for the bias of the Dense layers. - bias: bool: whether pointwise QKVO dense transforms use bias. - attention_fn: dot_product_attention or compatible function. Accepts - query, key, value, and returns output of shape - `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` - - Returns: - output of shape `[bs, dim1, dim2, ..., dimN, features]`. - """ - - assert causal_mask or not cache, ( - 'Caching is only support for causal attention.') - - if inputs_kv is None: - inputs_kv = inputs_q - - is_self_attention = inputs_kv is inputs_q - - if attention_axis is None: - attention_axis = tuple(range(1, inputs_q.ndim - 1)) - - features = out_features or inputs_q.shape[-1] - qkv_features = qkv_features or inputs_q.shape[-1] - - assert qkv_features % num_heads == 0, ( - 'Memory dimension must be divisible by number of heads.') - head_dim = qkv_features // num_heads - - dense = DenseGeneral.partial( - axis=-1, - features=(num_heads, head_dim), - kernel_init=kernel_init, - bias_init=bias_init, - bias=bias, - precision=precision) - # project inputs_q to multi-headed q/k/v - # dimensions are then [bs, dims..., n_heads, n_features_per_head] - query, key, value = (dense(inputs_q, dtype=dtype, name='query'), - dense(inputs_kv, dtype=dtype, name='key'), - dense(inputs_kv, dtype=dtype, name='value')) - - if cache: - assert isinstance(cache, Cache), 'cache must be an instance of Cache' - if self.is_initializing(): - cache.store(np.array((key.ndim,) + key.shape[-2:], dtype=np.int32)) - else: - cache_entry = cache.retrieve(None) - expected_shape = list(cache_entry.key.shape[:-2]) - for attn_dim in attention_axis: - expected_shape[attn_dim] = 1 - expected_shape = tuple(expected_shape) + inputs_q.shape[-1:] - if expected_shape != inputs_q.shape: - raise ValueError('Invalid shape provided, ' - 'expected shape %s instead got %s.' % - (expected_shape, inputs_q.shape)) - - if not isinstance(cache_entry, _CacheEntry): - raise ValueError('Cache is not initialized.') - - cshape = cache_entry.key.shape - indices = [0] * len(cshape) - i = cache_entry.i - attn_size = np.prod(np.take(cshape, attention_axis)) - for attn_dim in attention_axis: - attn_size //= cshape[attn_dim] - indices[attn_dim] = i // attn_size - i = i % attn_size - - key = lax.dynamic_update_slice(cache_entry.key, key, indices) - value = lax.dynamic_update_slice(cache_entry.value, value, indices) - one = jnp.array(1, jnp.uint32) - cache_entry = cache_entry.replace(i=cache_entry.i + one, - key=key, - value=value) - cache.store(cache_entry) - - # create attention masks - mask_components = [] - - if causal_mask: - if cache and not self.is_initializing(): - bias_pre_shape = (1,) * (key.ndim - 1) - attn_shape = tuple(np.take(key.shape, attention_axis)) - attn_size = np.prod(attn_shape) - ii = jnp.arange(attn_size, dtype=jnp.uint32) - mask = ii < cache_entry.i - mask_components.append(mask.reshape(bias_pre_shape + attn_shape)) - else: - mask_components.append(_make_causal_mask(key, attention_axis)) - - if (padding_mask is not None or key_padding_mask is not None) and not cache: - if key_padding_mask is None: - if is_self_attention: - key_padding_mask = padding_mask - else: - key_padding_shape = [inputs_kv.shape[dim] for dim in attention_axis] - key_padding_mask = jnp.full(key_padding_shape, True) - if padding_mask is None: - if is_self_attention: - padding_mask = key_padding_mask - else: - padding_shape = [inputs_q.shape[dim] for dim in attention_axis] - padding_mask = jnp.full(padding_shape, True) - - padding_mask = make_padding_mask( - padding_mask_query=padding_mask, - padding_mask_key=key_padding_mask, - query_shape=query.shape, - key_shape=key.shape, - attention_axis=attention_axis) - mask_components.append(padding_mask) - - if segmentation is not None: - if key_segmentation is None: - assert is_self_attention - key_segmentation = segmentation - segmentation_mask = make_padding_mask( - padding_mask_query=segmentation, - padding_mask_key=key_segmentation, - query_shape=query.shape, - key_shape=key.shape, - attention_axis=attention_axis, - segmentation_mask=True) - mask_components.append(segmentation_mask) - - if mask_components: - attention_mask = mask_components[0] - for component in mask_components[1:]: - attention_mask = jnp.logical_and(attention_mask, component) - - # attention mask in the form of attention bias - attention_bias = lax.select( - attention_mask > 0, jnp.full(attention_mask.shape, 0.).astype(dtype), - jnp.full(attention_mask.shape, -1e10).astype(dtype)) - else: - attention_bias = None - - # apply attention - x = attention_fn( - query, - key, - value, - dtype=dtype, - axis=attention_axis, - bias=attention_bias, - precision=precision, - dropout_rng=dropout_rng, - dropout_rate=dropout_rate, - broadcast_dropout=broadcast_dropout, - deterministic=deterministic) - - # back to the original inputs dimensions - out = DenseGeneral( - x, - features=features, - axis=(-2, -1), - kernel_init=kernel_init, - bias_init=bias_init, - bias=bias, - dtype=dtype, - precision=precision, - name='out') - - return out - - -# TODO(flax-dev): Consider refactoring MultiHeadDotProductAttention and moving -# causal_mask and cache support into this class instead. -SelfAttention = MultiHeadDotProductAttention.partial(inputs_kv=None) - - -def make_padding_mask(padding_mask_query, - padding_mask_key, - query_shape, - key_shape, - attention_axis=None, - segmentation_mask=False): - """The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Makes padding mask for attention weights. - - In case of 1d inputs (i.e., `[bs, len, features]`, the attention weights will - be `[bs, len, len]` and this function makes a square matrix [len, len]. - - Args: - padding_mask_query: padding mask of query - padding_mask_key: padding mask of query - query_shape: shape of the query - key_shape: shape of the key, which is equal to the shape of value. - attention_axis: axis over which attention is applied. - segmentation_mask: bool: if true use equality on cartesian product rather - than outer product for constructing segmentation masks. - Returns: - The padding mask for attention weights. - """ - assert query_shape[0] == key_shape[0] - assert len(query_shape) == len(key_shape) - - ndim = len(key_shape) - if attention_axis is None: - attention_axis = tuple(range(1, ndim - 2)) - assert isinstance(attention_axis, tuple) - for ax in attention_axis: - if not (ndim >= 3 and 1 <= ax < ndim - 2): - raise ValueError( - 'Attention axis must be between the batch axis and the last-two axes.' - ) - - mask_shape_final = (query_shape[0], 1) # batch_size, 1 (for all heads)s - for ax in attention_axis: - mask_shape_final += (query_shape[ax],) - for ax in attention_axis: - mask_shape_final += (key_shape[ax],) - - padding_mask_query = padding_mask_query[..., None] - padding_mask_key = padding_mask_key[..., None] - perm = (0,) + tuple(np.flip(np.arange(padding_mask_key.ndim)))[:-1] - if segmentation_mask: - mask = jnp.equal(padding_mask_query, padding_mask_key.transpose(perm)) - else: - mask = jnp.multiply(padding_mask_query, padding_mask_key.transpose(perm)) - - mask = mask.reshape(mask_shape_final) - mask = jax.lax.convert_element_type(mask, jnp.float32) - return mask - - -def _make_causal_mask(key, attention_axis=None, self_mask=False): - """The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Makes a causal mask, to be used for masking out the future for attention. - - In case of 1d inputs (i.e., `[bs, len, features]`, the attention weights will - be `[bs, len, len]` and this function makes a square matrix [len, len] with - zeros in upper triangle and ones in lower triangle. - - Args: - key: shape of the key, which is equal to the shape of value and is - assumed to be equal to the shape of the query (since this is used in - self-attention when decoding). - attention_axis: axis over which attention is applied. - self_mask: if mask out the diagonal or not. - - Returns: - A causal mask to be used to mask out future positions. - """ - if attention_axis is None: - attention_axis = tuple(range(1, key.ndim - 2)) - assert isinstance(attention_axis, tuple) - for ax in attention_axis: - if not (key.ndim >= 3 and 1 <= ax < key.ndim - 2): - raise ValueError( - 'Attention axis must be between the batch axis and the last-two axes.' - ) - - mask_shape = tuple([1] * (key.ndim - len(attention_axis) - 1)) - mask_shape_final = mask_shape - for _ in range(2): - flatten_dim = 1 - for ax in attention_axis: - mask_shape_final += (key.shape[ax],) - flatten_dim *= key.shape[ax] - mask_shape += (flatten_dim,) - - def tri(n, m, k=0): - # Tie in the key to avoid the mask becoming a constant. - # This way XLA can construct the mask during computation and fuse it - # with the attention ops. - x = lax.tie_in(key, jnp.arange(n, dtype=jnp.int32)) - y = lax.tie_in(key, jnp.arange(m, dtype=jnp.int32)) - mask = lax.ge( - (lax.broadcast_in_dim(x, shape=(n, m), broadcast_dimensions=(0,))) + k, - lax.broadcast(y, [n])) - return mask - - k = -1 if self_mask else 0 - mask = tri(*mask_shape[-2:], k=k).reshape(mask_shape_final) - return mask diff --git a/flax/deprecated/nn/base.py b/flax/deprecated/nn/base.py deleted file mode 100644 index bd72a0971..000000000 --- a/flax/deprecated/nn/base.py +++ /dev/null @@ -1,1196 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - NN base modules for JAX.""" - -from typing import Type - -import abc -import contextlib -import functools -import hashlib -import inspect -from typing import Any -import warnings - -from . import utils -from . import stochastic -from flax import jax_utils -from flax import serialization -from flax import struct - -import jax -from jax import random - - -_module_stack = utils.CallStack() -_module_output_trackers = utils.CallStack() -_state_stack = utils.CallStack() - - -def _track_outputs(x): - for module_output_tracker in _module_output_trackers: - xs = module_output_tracker.retrieve(default=[]) - xs.append(x) - module_output_tracker.store(xs) - - -class _ModuleFrame: - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - A ModuleFrame the context needed to initialize (init) or apply a Module. - - In particular, `self.params` is a dictionary where parameters are - stored (during module init) and read from (during module application). - - When `module.init()` is first called, a new ModuleFrame is created with - an empty `params` dictionary. When `self.param` is called within that - module, a new key is added to track that parameter, with the computed - parameter's initial value. - - When a module calls into a submodule, a new key is added, with a value - being an empty dictionary. Then that new dictionary is passed in as `params` - on a new sub-ModuleFrame. That new sub-ModuleFrame keeps track of its parent - with the `parent` attribute. - - When the whole init process is complete, the top-level ModuleFrame' - `params` are returned, which contain a nested dictionary of parameters. - - During module application, a similar process happens but this time - the parameters are only read from. - - Additional attributes on ModuleFrame track context needed to assist error - handling, shared parameters and transparent modules that are wrapped without - creating additional sub-parameters. TODO: Consider elaborating on this - last paragraph. - """ - - def __init__(self, name, - parent=None, params=None, rng=None, - transparent=False): - if params is None: - params = {} - self.parent = parent - self.rng = rng - self.params = params - self.shared = {} - self.shared_names = set() - self.name = name - self.transparent = transparent - - self._name_counter = 0 - - @property - def is_init(self): - return self.rng is not None - - @property - def path(self): - """Returns the path of the Module scope. - - Paths are similar to Unix file names (e.g. '/module/nested/dense') - - Returns: - The path of this Module scope. - """ - if self.parent is None: - if self.name is None: - return '/' - else: - return '/' + self.name - - path = self.parent.path - if not self.parent.transparent: - if path[-1] != '/': - path += '/' - path += self.name - return path - - def is_descendent_of(self, frame): - """Checks whether this frame is a descendent of the given frame.""" - if frame is self.parent: - return True - elif self.parent: - return self.parent.is_descendent_of(frame) - else: - return False - - def create_name(self): - name = str(self._name_counter) - self._name_counter += 1 - return name - - -def module_method(fn): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Decorates a function as a module method. - - The `module_method` allows modules to have multiple methods that make use of - the modules parameters. - - Example:: - - class MyLinearModule(nn.Module): - def apply(self, x, features, kernel_init): - kernel = self.param('kernel', (x.shape[-1], features), kernel_init) - return jnp.dot(x, kernel) - - @nn.module_method - def apply_transpose(self, x, **kwargs): - kernel = self.get_param('kernel') - return jnp.dot(x, kernel.transpose((1, 0))) - - A module method can be called on A Model instance directly:: - - y, initial_params = MyLinearModule.init(rng, x) - model = nn.Model(MyLinearModule, initial_params) - z = model.apply_transpose(y) - - Module methods can also be called on shared modules:: - - class AutoEncoder(nn.module): - def apply(self, x, features): - linear_fn = MyLinearModule.shared(features=features) - h = linear_fn(x) - y = linear_fn.apply_transpose(h) - return y - - - Args: - fn: the function to be decorated - Returns: - the decorated function - """ - - cache = {} - - # Module methods are just Module class instances. - # But we want it to inherit from the class such that we can call other methods - # of the module. We need a class property to find out which class the method - # is defined on. - def wrapper(cls): - if cls not in cache: - class ModuleMethod(cls): - apply = fn - ModuleMethod.__name__ = fn.__name__ - ModuleMethod.__qualname__ = f'{cls.__qualname__}.{fn.__name__}' - cache[cls] = ModuleMethod - return cache[cls] - - return utils.classproperty(wrapper) - - -def _fn_parameters(fn): - return tuple(inspect.signature(fn).parameters.values()) - - -MODULE_CLASSMETHODS = [ - 'create', 'create_by_shape', 'init', 'init_by_shape', 'call', 'partial' -] - - -class _ModuleMeta(abc.ABCMeta): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - A meta class for automatically setting the doc of Modules.""" - - def __init__(cls, name, bases, attrs): - super(_ModuleMeta, cls).__init__(name, bases, attrs) - apply_fn = cls.apply # pytype: disable=attribute-error - apply_doc = apply_fn.__doc__ - cls.__doc__ = apply_doc - apply_params = _fn_parameters(apply_fn) - cls.__signature__ = inspect.signature(cls).replace( - parameters=apply_params[1:]) - - if not bases: - return # skip method signature overrides for Module class. - - def wrap_special_method(name): - """Overrides the signature and docstring for one of Module's class methods.""" - orig_fn = getattr(Module, name) # pytype: disable=name-error - - @functools.wraps(orig_fn) - def wrapper(class_, *args, **kwargs): - super_fn = getattr(super(cls, class_), name) - return super_fn(*args, **kwargs) - wrapper.__doc__ = f'''{orig_fn.__doc__} - - Apply docstring: - - {apply_doc} - ''' - base_params = tuple(x for x in _fn_parameters(orig_fn) - if x.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD) - new_params = base_params + apply_params[1:] - wrapper.__signature__ = inspect.signature(orig_fn).replace( - parameters=new_params) - setattr(cls, name, classmethod(wrapper)) - - for name in MODULE_CLASSMETHODS: - wrap_special_method(name) - - -def _fold_in_str(rng, data): - """Folds a string into a jax.random.PRNGKey using its SHA-1 hash.""" - m = hashlib.sha1() - m.update(data.encode('utf-8')) - d = m.digest() - hash_int = int.from_bytes(d[:4], byteorder='big', signed=True) - return random.fold_in(rng, hash_int) - - -class Module(metaclass=_ModuleMeta): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Functional modules.""" - - def __new__(cls, *args, name=None, **kwargs): - warnings.warn("The `flax.nn` module is Deprecated, use `flax.linen` instead. Learn more and find an upgrade guide at https://github.com/google/flax/blob/main/flax/linen/README.md", DeprecationWarning) - # DO NOT REMOVE - Marker for internal logging. - if not _module_stack: - raise ValueError('A Module should only be instantiated directly inside' - ' another module.') - parent = cls._get_construction_frame() - apply_kwargs = cls._extend_kwargs(kwargs) - if name is None: - name = cls._default_name() - elif cls._is_shared(): - raise ValueError('Cannot override the name of a shared module') - if name is None: # also no default name - name = cls.__name__ + '_' + parent.create_name() - cls._check_name(name, parent) - if parent.is_init and name not in parent.params: - with jax.core.eval_context(): - rng = _fold_in_str(parent.rng, name) - params = {} - parent.params[name] = params - else: # apply - if name not in parent.params: - raise ValueError(f'No module named {name} was created during' - ' initialization.') - params = parent.params[name] - rng = None - frame = _ModuleFrame(name, parent=parent, rng=rng, params=params, - transparent=cls._is_transparent()) - with cls._with_instance(frame) as instance: - y = instance.apply(*args, **apply_kwargs) - _track_outputs(y) - return y - - @abc.abstractmethod - def apply(self, *args, **kwargs): - pass - - @classmethod - def shared(class_, *, name=None, **kwargs): - """Partially applies a module and shared parameters for each call. - - Args: - name: name of this module. - **kwargs: keyword arguments that should be partially applied. - Returns: - A subclass of Module that shares parameters when called multiple times. - """ - if not _module_stack: - raise ValueError( - 'The shared module should be used during Module application') - - parent = _module_stack[-1] - if name is None: - name = parent.create_name() - if name in parent.shared_names: - raise ValueError(f'Shared module named "{name}" already exists.') - parent.shared_names.add(name) - - partial_module = class_.partial(**kwargs) - - class SharedModule(partial_module): - """Wraps a module to enable shared parameters.""" - - @classmethod - def _default_name(cls): - return name - - @classmethod - def _is_shared(cls): - return True - - @classmethod - def _get_construction_frame(cls): - return parent - - SharedModule.__name__ = class_.__name__ - SharedModule.__qualname__ = class_.__qualname__ - - return SharedModule - - @classmethod - def _get_construction_frame(cls): - """Returns the ModuleFrame where this module was constructed. - - Modules can be shared across different parts of a parameter tree. - We need to ensure that the parameter object is the same in every instance - of the same shared module. We resolve this by deciding on a canonical - ModuleFrame (corresponding to a particular part of the top-level parameter - tree) where parameters are stored. Concretely, it is the - "construction frame" -- that is, the frame in which the module is first - defined. For non-shared modules, that's where it's called. For shared - modules, it's where `submodule.shared(...)` is called (which may or may - not be the frame in which it is used.) - - Returns: - The ModuleFrame instance where this module was constructed. - """ - return _module_stack[-1] - - @classmethod - def partial(class_, *, name=None, **kwargs): - """Partially applies a module with the given arguments. - - Unlike `functools.partial` this will return a subclass of Module. - - Args: - name: the name used the module - **kwargs: the argument to be applied. - Returns: - A subclass of Module which partially applies the given keyword arguments. - """ - - class PartialModule(class_): - """Wraps a module with partial application.""" - - @classmethod - def _default_name(cls): - if name is not None: - return name - else: - return super()._default_name() - - @classmethod - def _extend_kwargs(cls, invoke_kwargs): - extended_kwargs = kwargs.copy() - extended_kwargs.update(invoke_kwargs) - return super()._extend_kwargs(extended_kwargs) - # __doc__ is handled by the Module meta class - PartialModule.__name__ = class_.__name__ - PartialModule.__qualname__ = class_.__qualname__ - - return PartialModule - - @classmethod - def create(cls, _rng, *args, name=None, **kwargs): - """Creates a module instance by evaluating the model. - - DEPRECATION WARNING: - `create()` is deprecated use `init()` to initialize parameters and - then explicitly create a `nn.Model` given the module and initialized - parameters. - - Use create_by_shape instead to initialize without doing computation. - Initializer functions can depend both on the shape and the value of inputs. - - Args: - _rng: the random number generator used to initialize parameters. - *args: arguments passed to the module's apply function - name: name of this module - **kwargs: keyword arguments passed to the module's apply function - Returns: - A pair consisting of the model output and an instance of Model - """ - warnings.warn("`create()` will be removed soon." - " Use `init()` to initialize parameters and then explicitly" - " create a `nn.Model` given the module and initialized" - " parameters.", - DeprecationWarning) - y, params = cls.init(_rng, *args, name=name, **kwargs) - model = Model(cls, params) - return y, model - - @classmethod - def create_by_shape(cls, _rng, input_specs, *args, name=None, **kwargs): - """Creates a module instance using only shape and dtype information. - - DEPRECATION WARNING: - `create_by_shape()` is deprecated use `init_by_shape()` to initialize - parameters and then explicitly create a `nn.Model` given the module and - initialized parameters. - - - This method will initialize the model without computation. - Initializer functions can depend on the shape but not the value of inputs. - - Args: - _rng: the random number generator used to initialize parameters. - input_specs: an iterable of (shape, dtype) pairs specifying the inputs - *args: other arguments passed to the module's apply function - name: name of this module. - **kwargs: keyword arguments passed to the module's apply function - Returns: - A pair consisting of the model output and an instance of Model - """ - warnings.warn("`create_by_shape()` will be removed soon." - " Use `init_by_shape()` to initialize parameters and then" - " explicitly create a `nn.Model` given the module and " - " initialized parameters.", - DeprecationWarning) - - y, params = cls.init_by_shape(_rng, input_specs, *args, name=name, **kwargs) - model = Model(cls, params) - return y, model - - @classmethod - def init(cls, _rng, *args, name=None, **kwargs): - """Initializes the module parameters. - - Args: - _rng: the random number generator used to initialize parameters. - *args: arguments passed to the module's apply function - name: name of this module. - **kwargs: keyword arguments passed to the module's apply function - Returns: - A pair consisting of the model output and the initialized parameters - """ - kwargs = cls._extend_kwargs(kwargs) - if _module_stack: - parent = _module_stack[-1] - else: - parent = None - if name is None: - name = cls._default_name() - - frame = _ModuleFrame(name, rng=_rng, parent=parent, - transparent=cls._is_transparent()) - with cls._with_instance(frame) as instance: - y = instance.apply(*args, **kwargs) - _track_outputs(y) - return y, cls._post_process_params(frame.params) - - @classmethod - def init_by_shape(cls, _rng, input_specs, *args, name=None, **kwargs): - """Initialize the module parameters. - - This method will initialize the module parameters without computation. - Initializer functions can depend on the shape but not the value of inputs. - - Example:: - - input_shape = (batch_size, image_size, image_size, 3) - model_output, initial_params = model.init_by_shape(jax.random.PRNGKey(0), - input_specs=[(input_shape, jnp.float32)]) - - Args: - _rng: the random number generator used to initialize parameters. - input_specs: an iterable of (shape, dtype) pairs specifying the inputs - *args: arguments passed to the module's apply function - name: name of this module. - **kwargs: keyword arguments passed to the module's apply function - Returns: - A pair consisting of the model output and the initialized parameters - """ - stochastic_rng = None - try: - stochastic_rng = stochastic.make_rng() - except ValueError: - # Either there is no stochastic scope or the current - # scope is invalid due to another jax transformation. - # In both cases we should not try to lift the stochastic - # scope into the lazy evaluation - pass - - def lazy_init(*inputs): - def init_fn(): - return cls.init(_rng, *(inputs + args), name=name, **kwargs) - if stochastic_rng is not None: - # Create a new stochastic scope inside the lazy evaluation - # this way we can use a stochastic scope in combination - # with init_by_shape. - with stochastic.stochastic(stochastic_rng): - return init_fn() - else: - return init_fn() - return jax_utils.partial_eval_by_shape(lazy_init, input_specs) - - @classmethod - def call(cls, params, *args, name=None, **kwargs): - """Evaluate the module with the given parameters. - - Args: - params: the parameters of the module. Typically, initial parameter values - are constructed using `Module.init` or `Module.init_by_shape`. - *args: arguments passed to the module's apply function - name: name of this module. - **kwargs: keyword arguments passed to the module's apply function - Returns: - The output of the module's apply function. - """ - params = cls._pre_process_params(params) - kwargs = cls._extend_kwargs(kwargs) - if _module_stack: - parent = _module_stack[-1] - else: - parent = None - if name is None: - name = cls._default_name() - frame = _ModuleFrame(name, params=params, parent=parent, - transparent=cls._is_transparent()) - with cls._with_instance(frame) as instance: - y = instance.apply(*args, **kwargs) - _track_outputs(y) - return y - - def param(self, name, shape, initializer): - """Defines a parameter within the module's apply function. - - Args: - name: The name of the parameter. - shape: The shape of the parameter. If None the param be any type. - initializer: An initializer function - taking an RNG and the shape as arguments. - Returns: - The value of the parameter. - """ - frame = self._frame - if frame.is_init: - if name in frame.params: - raise ValueError( - "Name '%s' was already used for another parameter." % name) - with jax.core.eval_context(): - key = _fold_in_str(frame.rng, name) - frame.params[name] = initializer(key, shape) - if name not in frame.params: - raise ValueError("Parameter with name '%s' does not exist." % name) - param = frame.params[name] - if shape is not None and param.shape != shape: - raise ValueError( - 'Existing shape {} differs from requested shape {}'.format( - param.shape, shape)) - return param - - def get_param(self, name): - """Retrieves a parameter within the module's apply function. - - Args: - name: The name of the parameter. - Returns: - The value of the parameter. - """ - frame = self._frame - if name not in frame.params: - raise ValueError("Parameter with name '%s' does not exist." % name) - return frame.params[name] - - def state(self, name, shape=None, initializer=None, collection=None): - """Declare a state variable within the module's apply function. - - A state variable has an attribute value which can be updated by simply - assigning a value to it. For example:: - - class Example(nn.Module): - def apply(self, inputs, decay=0.9): - ema = self.state('ema', inputs.shape, initializers.zeros) - ema.value = decay * ema.value + (1 - decay) * inputs - return inputs - - By default, Modules are stateless. See `flax.nn.stateful` to enable stateful - computations. - - Args: - name: the name of the state variable. - shape: optional shape passed to the initializer (default: None) - initializer: optional initializer function - taking an RNG and the shape as arguments. - collection: optional `flax.nn.Collection` used to store the state. - By default the state collection passed to the `nn.stateful` context is - used. - Returns: - An instance of ModuleState. - """ - _top_frame('state') - if collection is None: - collection = get_state() - state = ModuleState(collection, name) - # find the frames that are in init mode - init_frames = [f for f in _module_stack if f.is_init] - if initializer is not None and init_frames: - # use the closest frame that is initializing to get an rng - init_frame = init_frames[-1] - with jax.core.eval_context(): - init_frame.rng, key = random.split(init_frame.rng) - init_value = initializer(key, shape) - state.value = init_value - return state - - def is_stateful(self): - return is_stateful() - - def is_initializing(self): - _top_frame('is_initializing') - return self._frame.is_init - - @classmethod - @contextlib.contextmanager - def _with_instance(cls, frame): - """A private constructor for Module. - - A module instance is constructed using a scope and is tied to a _ModuleFrame - This way the methods on the Module instance can rely on the _ModuleFrame - being available. - - Args: - frame: an instance of _ModuleFrame - Yields: - An instance of Module - """ - instance = object.__new__(cls) - instance._frame = frame # pylint: disable=protected-access - with _module_stack.frame(frame): - yield instance - - @classmethod - def _check_name(cls, name, parent): - """Check whether the name of the module is valid within the parent scope.""" - if name is not None: - if not isinstance(name, str): - raise ValueError('Name must be a string.') - if '/' in name or ':' in name: - raise ValueError('Name should not contain slashes or colons.') - shared = cls._is_shared() - if name in parent.shared: - # a module with this name already exists. Check validity of sharing - if shared != parent.shared[name]: - raise ValueError(f'The name "{name}" is used for both a shared' - ' and unshared module.') - if not parent.shared[name]: - raise ValueError(f'A module with named "{name}" already exists.') - parent.shared[name] = shared - - @classmethod - def _extend_kwargs(cls, kwargs): - return kwargs - - @classmethod - def _pre_process_params(cls, params): - return params - - @classmethod - def _post_process_params(cls, params): - return params - - @classmethod - def _is_transparent(cls): - return False - - @classmethod - def _is_shared(cls): - return False - - @classmethod - def _default_name(cls): - return None - - -def module(fun): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Convert a function into the apply method of a new Module. - - This is convenient shortcut for writing higher level modules that don't need - access to `self` for creating parameters directly. - - Example usage:: - - @nn.module - def DenseLayer(x, features): - x = flax.nn.Dense(x, features) - x = flax.nn.relu(x) - return x - - This is exactly equivalent to defining the following `nn.Module` subclass:: - - class DenseLayer(nn.Module): - def apply(self, x, features): - x = flax.nn.Dense(x, features) - x = flax.nn.relu(x) - return x - - Args: - fun: the function to convert. - Returns: - New Module subclass. - """ - @functools.wraps(fun) - def apply(self, *args, **kwargs): - del self # unused - return fun(*args, **kwargs) - return type(fun.__name__, (Module,), dict(apply=apply)) - - -# TODO(flax-dev) consider removing this... -class TransparentModule(Module): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - A transparent module. - - A transparent module can only have one parameter named '0'. - """ - - @classmethod - def _pre_process_params(cls, params): - return {'0': params} - - @classmethod - def _post_process_params(cls, params): - entries = list(params.items()) - if len(entries) != 1: - raise ValueError('Transparent modules should have exactly one child.') - key, value = entries[0] - if key != '0': - raise ValueError('Transparent modules should contain an unnamed child.') - return value - - @classmethod - def _is_transparent(cls): - return True - - -class TruncatedModule(TransparentModule): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Wraps a Module and returns the requested intermediate outputs instead. - - Check `Model.truncate_at` for a simple API to get the intermediate outputs of - an existing Model. - """ - - def apply(self, *args, wrapped_module=None, truncate_path=None, **kwargs): - """Applies the wrapped module and return some of its intermediate outputs. - - Args: - *args: the positional arguments for the wrapped module. - wrapped_module: The module class to be wrapped. - truncate_path: the full name of the module (eg. '/module/sub_module'). - A list or dict of module paths can be provided to obtain the - intermediate outputs of multiple modules. - **kwargs: the keyword arguments for the wrapped module. - Returns: - The intermediate outputs specified by truncate_path. - """ - if wrapped_module is None or truncate_path is None: - raise ValueError( - '`wrapped_module` and `truncate_path` are required keyword arguments') - with capture_module_outputs() as module_outputs: - wrapped_module(*args, **kwargs, name='0') - - def lookup_output(path): - return module_outputs[path] - return jax.tree_map(lookup_output, truncate_path) - - -@contextlib.contextmanager -def capture_module_outputs(): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - A context manager that captures all model outputs. - - Yields: - A `flax.nn.Collection` containing all module outputs. - """ - with Collection().mutate() as module_outputs: - with _module_output_trackers.frame(module_outputs): - yield module_outputs - - -class ModuleState(): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Tracks a state variable. - - ModuleState instances should not be created directly. See `Module.state` on - how to create state variables inside modules. - """ - - def __init__(self, collection, name): - self._collection = collection - self._name = name - - def _get_state_dict(self): - state_dict = self._collection.retrieve(default={}) - assert isinstance(state_dict, dict) - return state_dict - - @property - def name(self): - return self._name - - @property - def value(self): - state_dict = self._get_state_dict() - if self._name not in state_dict: - raise ValueError(f'No state variable named `{self._name}` exists.') - return state_dict[self._name] - - @value.setter - def value(self, v): - state_dict = self._get_state_dict() - state_dict[self._name] = v - self._collection.store(state_dict) - - -@contextlib.contextmanager -def stateful(state=None, mutable=True): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - A context manager for stateful computations. - - Module's that use the `Module.state` by default store state inside the - `Collection` specified by the (innermost) `nn.stateful` context manager. - - Typically stateful is used in 3 different modes: - - 1. During init no existing state is available and the stateful context creates - a new state collection. - 2. During training the state is passed to `nn.stateful` and the new state - is returned which will contain the updated state. - 3. During evaluation the state is passed with `mutable=False` such that the - model can retrieve the state but is not allowed to mutate it. - - Example:: - - class MyModel(nn.Module): - def apply(self, x): - x = nn.Dense(x, 12) - x = nn.BatchNorm(x) - return x - - with nn.stateful() as state: - _, initial_params = MyModel.init(rng, x) - model = nn.Model(MyModel, initial_params) - - with nn.stateful(state) as new_state: - model(x2) - - with nn.stateful(new_state, mutable=False): - evaluate_model(model) - - Args: - state: a `flax.nn.Collection` containing the current state. - By default a new collection will be created. - mutable: If true the state will be mutable otherwise it will be frozen. - Yields: - A `flax.nn.Collection` containing the new state. - """ - if state is None: - state = Collection() - if mutable: - with state.mutate() as new_state: - with _state_stack.frame(new_state): - yield new_state - else: - with _state_stack.frame(state): - yield state - - -def is_stateful(): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Returns true if a stateful scope is currently active (see `flax.nn.stateful`).""" - return bool(_state_stack) - - -def get_state(): - if not _state_stack: - raise ValueError('Use the flax.nn.stateful context manager to enable' - ' stateful computations.') - return _state_stack[-1] - - -def _top_frame(call_name): - if not _module_stack: - raise ValueError('%s should only be used inside a ' - 'module\'s apply function.' % call_name) - return _module_stack[-1] - - -@struct.dataclass -class Model: - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md - - A Model contains the model parameters, state and definition.""" - - module: Type[Module] = struct.field(pytree_node=False) - params: Any = struct.field(pytree_node=True) - - def __call__(self, *args, **kwargs): - return self.module.call(self.params, *args, **kwargs) - - def truncate_at(self, module_path): - """Truncates the model by returning the outputs of the given sub-module. - - Args: - module_path: the full name of the module (e.g. '/module/sub_module'). - A list or dict of module paths can be provided to obtain the - intermediate outputs of multiple modules. - Returns: - A new model with the truncated outputs. If module_path is a pytree of - paths the outputs will be have the same structure where each path is - replaced by the corresponding intermediate output. - """ - truncated_module_cls = TruncatedModule.partial( - wrapped_module=self.module, truncate_path=module_path) - return self.replace(module=truncated_module_cls) - - def __getattr__(self, name): - value = getattr(self.module, name) - if inspect.isclass(value) and issubclass(value, Module): - def wrapper(*args, **kwargs): - return value.call(self.params, *args, **kwargs) - return wrapper - raise AttributeError(f'No attribute named "{name}".') - - def __hash__(self): - # Jax will call hash when the model is passed to a function transform. - # The compiled function should not be shared among model instances because - # it closes over the specific parameters of this model instance. - return id(self) - - -class Collection: - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - A collection of tensors useful for tracking state. - - A Collection can be used to associate data with the application of a Module. - For example, a collection can be used to collect activations across modules. - Another common use case for collections is to track internal state. - For example, the running averages in BatchNorm can be stored in a collection. - - Attributes: - state: the initial state by default an empty collection is created. - """ - - def __init__(self, state=None): - if state is None: - state = {} - self.state = state - # The anchor is used to determine the prefix of the collection. - # This way we can create/nest collections inside modules. - self._anchor = _module_stack[-1] if _module_stack else None - - self._mutable = False - self._main_level = None - self._root = None - - def as_dict(self): - """Returns a dictionary with module paths as keys and the stored values. - - Returns: - The stored values as a dictionary. - """ - return self.state.copy() - - def __getitem__(self, key): - return self.state[key] - - @contextlib.contextmanager - def mutate(self): - # pylint: disable=protected-access - new_col = jax.tree_map(lambda x: x, self) # clone the collection - new_col._mutable = True - new_col._main_level = utils._trace_level(utils._current_trace()) - try: - yield new_col - finally: - new_col._mutable = False - - def retrieve(self, default=None): - """Retrieves a value from the Collection. - - This function should only be called with the apply function of a module. - - Args: - default: The default returned when nothing is stored (default: None) - Returns: - The value previously stored in the collection. - """ - _top_frame('retrieve') - path = self._current_path() - return self.state.get(path, default) - - def store(self, value): - """Stores a value in the Collection. - - This function should only be called with the apply function of a module. - - Args: - value: The value to be stored in the collection - Returns: - The previous value stored in the collection or None. - """ - frame = _top_frame('store') - if not self._mutable: - raise ValueError('Collection is not mutable. Use the `mutate` method to' - ' create a mutable copy.') - # Use the Jax TraceMain to determine if a Collection is modified from - # inside a nested jax transformation. - # In this case, we throw an error because transforming a stateful function - # is ill-defined (eg. what does vmap of BatchNorm do?). - # TODO(jheek): Add doc guide on combining jax transforms and state. - # TODO(jheek): Should some transformations be exempt from this error? - value_level = utils._level_of_value(value) - if value_level > self._main_level: - raise ValueError('Stateful operations are not allowed when the Collection' - ' is created outside of the current Jax transformation') - - # The root of a Collection is the first module scope that gets created - # inside the mutate scope of the Collection. By allowing only one unique - # root scope, we guarantee that state is not accidentally shared - # between different models. When a user specifies an explicit name we can - # distinguish models and a collection can have multiple roots. - if frame == self._anchor: - # Example: - # with nn.Collection.mutate() as coll: - # coll.store(1) - raise ValueError('State should be stored from within a module.' - ' Consider using the value directly instead of' - ' storing it in a Collection.') - if not frame.is_descendent_of(self._anchor): - # edge case where the Collection cannot capture the scope of a shared Module - # See test_collection_store_fails_if_out_of_scope in nn_test.py - raise ValueError('Trying to capture state outside the scope of this Collection.' - ' Most likely due to passing around a shared Module.') - root = self._find_root(frame) - if self._root is None: - self._root = root - elif self._root != root: - if self._root.name is None or root.name is None: - # In the following examples, should the two calls to `StatefulModule` share state or not? - # Because it's ambiguous, we throw an error and require the user to explicitly separate state - # by giving each instance a separate name, or to explicitly pass the same name - # in order to share state. - # with nn.statefull(state) as new_state: - # StatefulModule.call(params) - # StatefulModule.call(params2) - raise ValueError('When multiple top-level module calls use a Collection' - ' each top-level module should have a name.') - path = self._current_path() - old_value = self.state.get(path, None) - self.state[path] = value - return old_value - - def _find_root(self, frame): - """Finds the root frame with respect to the anchor. - - The root frame is defined as the child of anchor - that is an ancestor of a frame. - The root is used to verify that a Collection does not - have multiple unnamed roots. - - Args: - - frame: the frame of which we want to know the root - Returns: - The root of the given frame. - """ - assert frame.is_descendent_of(self._anchor) - root = frame - while root.parent != self._anchor: - root = root.parent - return root - - def _current_path(self): - """"The relative path from the currently active module scope to the root of the collection. - - For example: If a collection is created in the path '/module/nested' and - something is stored by a module with the path '/module/nested/block/conv' - the key in the collection dict will be '/block/conv'. - - Returns: - the relative path of the active module scope. - """ - frame = _module_stack[-1] - assert frame.is_descendent_of(self._anchor) - path = _module_stack[-1].path - if self._anchor is not None and self._anchor.path != '/': - prefix = self._anchor.path - assert prefix == path[:len(prefix)] - return path[len(prefix):] - else: - return path - -def iterate_collection(collection): - # jax iterates through pytrees for each argument/return value of a functional - # transformations. When the collection is mutable we throw an error this way - # we avoid silent errors due to impurity of a traced function. - if collection._mutable: # pylint: disable=protected-access - raise ValueError('A mutable collection should not be transformed by Jax.') - meta = (type(collection), collection._anchor) # pylint: disable=protected-access - return (collection.state,), meta - - -def collection_from_iterable(meta, state): - ty, anchor = meta - coll = ty(state[0]) - coll._anchor = anchor # pylint: disable=protected-access - return coll - -# make sure a collection is traced. -jax.tree_util.register_pytree_node(Collection, - iterate_collection, - collection_from_iterable) - - -def _collection_state_dict(collection): - return serialization._dict_state_dict(collection.as_dict()) # pylint: disable=protected-access - - -def _collection_from_state_dict(xs, state): - restored_state = serialization._restore_dict(xs.as_dict(), state) # pylint: disable=protected-access - - return Collection(restored_state) - - -serialization.register_serialization_state( - Collection, _collection_state_dict, _collection_from_state_dict) diff --git a/flax/deprecated/nn/initializers.py b/flax/deprecated/nn/initializers.py deleted file mode 100644 index 6cb872c17..000000000 --- a/flax/deprecated/nn/initializers.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Initializers for Flax. -""" - -# pylint: disable=unused-import -# re-export initializer functions from jax.nn -from jax.nn.initializers import kaiming_normal -from jax.nn.initializers import kaiming_uniform -from jax.nn.initializers import lecun_normal -from jax.nn.initializers import lecun_uniform -from jax.nn.initializers import normal -from jax.nn.initializers import ones -from jax.nn.initializers import orthogonal -from jax.nn.initializers import delta_orthogonal -from jax.nn.initializers import uniform -from jax.nn.initializers import variance_scaling -from jax.nn.initializers import xavier_normal -from jax.nn.initializers import xavier_uniform -from jax.nn.initializers import zeros -# pylint: enable=unused-import diff --git a/flax/deprecated/nn/linear.py b/flax/deprecated/nn/linear.py deleted file mode 100644 index fffd058db..000000000 --- a/flax/deprecated/nn/linear.py +++ /dev/null @@ -1,411 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Linear modules.""" - -from collections.abc import Iterable # pylint: disable=g-importing-member - -from . import base -from . import initializers - -from jax import lax - -import jax.numpy as jnp -import numpy as np - - -default_kernel_init = initializers.lecun_normal() - - -def _normalize_axes(axes, ndim): - # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. - return tuple([ax if ax >= 0 else ndim + ax for ax in axes]) - - -class DenseGeneral(base.Module): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - A linear transformation with flexible axes.""" - - def apply(self, - inputs, - features, - axis=-1, - batch_dims=(), - bias=True, - dtype=jnp.float32, - kernel_init=default_kernel_init, - bias_init=initializers.zeros, - precision=None): - """Applies a linear transformation to the inputs along multiple dimensions. - - Args: - inputs: The nd-array to be transformed. - features: tuple with numbers of output features. - axis: tuple with axes to apply the transformation on. - batch_dims: tuple with batch axes. - bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). - kernel_init: initializer function for the weight matrix. - bias_init: initializer function for the bias. - precision: numerical precision of the computation see `jax.lax.Precision` - for details. - Returns: - The transformed input. - """ - inputs = jnp.asarray(inputs, dtype) - - if not isinstance(features, Iterable): - features = (features,) - if not isinstance(axis, Iterable): - axis = (axis,) - if not isinstance(batch_dims, Iterable): - batch_dims = (batch_dims,) - features, axis, batch_dims = tuple(features), tuple(axis), tuple(batch_dims) - - if batch_dims: - max_dim = np.max(batch_dims) - if set(batch_dims) != set(range(max_dim + 1)): - raise ValueError('batch_dims %s must be consecutive leading ' - 'dimensions starting from 0.' % str(batch_dims)) - - ndim = inputs.ndim - n_batch_dims = len(batch_dims) - axis = _normalize_axes(axis, ndim) - batch_dims = _normalize_axes(batch_dims, ndim) - n_axis, n_features = len(axis), len(features) - - def kernel_init_wrap(rng, shape, dtype=jnp.float32): - size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32) - flat_shape = (np.prod(shape[n_batch_dims:n_axis + n_batch_dims]), - np.prod(shape[-n_features:]),) - kernel = jnp.concatenate([kernel_init(rng, flat_shape, dtype) - for _ in range(size_batch_dims)], axis=0) - return jnp.reshape(kernel, shape) - - batch_shape = tuple([inputs.shape[ax] for ax in batch_dims]) - kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features - kernel = self.param('kernel', batch_shape + kernel_shape, kernel_init_wrap) - kernel = jnp.asarray(kernel, dtype) - - batch_ind = tuple(range(n_batch_dims)) - contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims)) - out = lax.dot_general(inputs, - kernel, - ((axis, contract_ind), (batch_dims, batch_ind)), - precision=precision) - if bias: - def bias_init_wrap(rng, shape, dtype=jnp.float32): - size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32) - flat_shape = (np.prod(shape[-n_features:]),) - bias = jnp.concatenate([bias_init(rng, flat_shape, dtype) - for _ in range(size_batch_dims)], axis=0) - return jnp.reshape(bias, shape) - - bias = self.param('bias', batch_shape + features, bias_init_wrap) - - # Reshape bias for broadcast. - expand_dims = sorted( - set(range(inputs.ndim)) - set(axis) - set(batch_dims)) - for ax in expand_dims: - bias = jnp.expand_dims(bias, ax) - bias = jnp.asarray(bias, dtype) - out = out + bias - return out - - -class Dense(base.Module): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - A linear transformation applied over the last dimension of the input.""" - - def apply(self, - inputs, - features, - bias=True, - dtype=jnp.float32, - precision=None, - kernel_init=default_kernel_init, - bias_init=initializers.zeros): - """Applies a linear transformation to the inputs along the last dimension. - - Args: - inputs: The nd-array to be transformed. - features: the number of output features. - bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). - precision: numerical precision of the computation see `jax.lax.Precision` - for details. - kernel_init: initializer function for the weight matrix. - bias_init: initializer function for the bias. - Returns: - The transformed input. - """ - inputs = jnp.asarray(inputs, dtype) - kernel = self.param('kernel', (inputs.shape[-1], features), kernel_init) - kernel = jnp.asarray(kernel, dtype) - y = lax.dot_general(inputs, kernel, - (((inputs.ndim - 1,), (0,)), ((), ())), - precision=precision) - if bias: - bias = self.param('bias', (features,), bias_init) - bias = jnp.asarray(bias, dtype) - y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) - return y - - -def _conv_dimension_numbers(input_shape): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Computes the dimension numbers based on the input shape.""" - ndim = len(input_shape) - lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1)) - rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2)) - out_spec = lhs_spec - return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) - - -class Conv(base.Module): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Convolution Module wrapping lax.conv_general_dilated.""" - - def apply(self, - inputs, - features, - kernel_size, - strides=None, - padding='SAME', - input_dilation=None, - kernel_dilation=None, - feature_group_count=1, - bias=True, - dtype=jnp.float32, - precision=None, - kernel_init=default_kernel_init, - bias_init=initializers.zeros): - """Applies a convolution to the inputs. - - Args: - inputs: input data with dimensions (batch, spatial_dims..., features). - features: number of convolution filters. - kernel_size: shape of the convolutional kernel. For 1D convolution, - the kernel size can be passed as an integer. For all other cases, it must - be a sequence of integers. - strides: a sequence of `n` integers, representing the inter-window - strides. - padding: either the string `'SAME'`, the string `'VALID'`, or a sequence - of `n` `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - input_dilation: `None`, or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `inputs`. - Convolution with input dilation `d` is equivalent to transposed - convolution with stride `d`. - kernel_dilation: `None`, or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of the convolution - kernel. Convolution with kernel dilation is also known as 'atrous - convolution'. - feature_group_count: integer, default 1. If specified divides the input - features into groups. - bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). - precision: numerical precision of the computation see `jax.lax.Precision` - for details. - kernel_init: initializer for the convolutional kernel. - bias_init: initializer for the bias. - Returns: - The convolved data. - """ - - inputs = jnp.asarray(inputs, dtype) - if isinstance(kernel_size, int): - kernel_size = (kernel_size,) - - is_single_input = False - if inputs.ndim == len(kernel_size) + 1: - is_single_input = True - inputs = jnp.expand_dims(inputs, axis=0) - - if strides is None: - strides = (1,) * (inputs.ndim - 2) - - in_features = inputs.shape[-1] - assert in_features % feature_group_count == 0 - kernel_shape = kernel_size + (in_features // feature_group_count, features) - kernel = self.param('kernel', kernel_shape, kernel_init) - kernel = jnp.asarray(kernel, dtype) - - dimension_numbers = _conv_dimension_numbers(inputs.shape) - y = lax.conv_general_dilated( - inputs, - kernel, - strides, - padding, - lhs_dilation=input_dilation, - rhs_dilation=kernel_dilation, - dimension_numbers=dimension_numbers, - feature_group_count=feature_group_count, - precision=precision) - - if is_single_input: - y = jnp.squeeze(y, axis=0) - if bias: - bias = self.param('bias', (features,), bias_init) - bias = jnp.asarray(bias, dtype) - y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) - return y - - -class ConvTranspose(base.Module): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Transposed convolution Module wrapping lax.conv_transpose.""" - - def apply(self, - inputs, - features, - kernel_size, - strides=None, - padding='SAME', - kernel_dilation=None, - bias=True, - dtype=jnp.float32, - precision=None, - kernel_init=default_kernel_init, - bias_init=initializers.zeros): - """Applies a transposed convolution to the inputs. Behaviour mirrors that of - `jax.lax.conv_transpose`. - - Args: - inputs: input data with dimensions (batch, spatial_dims..., features). - features: number of convolution filters. - kernel_size: shape of the convolutional kernel. For 1D convolution, - the kernel size can be passed as an integer. For all other cases, it must - be a sequence of integers. - strides: a sequence of `n` integers, representing the inter-window - strides. - padding: either the string `'SAME'`, the string `'VALID'`, or a sequence - of `n` `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - kernel_dilation: `None`, or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of the convolution - kernel. Convolution with kernel dilation is also known as 'atrous - convolution'. - bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). - precision: numerical precision of the computation see `jax.lax.Precision` - for details. - kernel_init: initializer for the convolutional kernel. - bias_init: initializer for the bias. - Returns: - The convolved data. - """ - inputs = jnp.asarray(inputs, dtype) - if isinstance(kernel_size, int): - kernel_size = (kernel_size,) - - is_single_input = False - if inputs.ndim == len(kernel_size) + 1: - is_single_input = True - inputs = jnp.expand_dims(inputs, axis=0) - - strides = strides or (1,) * (inputs.ndim - 2) - - in_features = inputs.shape[-1] - kernel_shape = kernel_size + (in_features, features) - kernel = self.param('kernel', kernel_shape, kernel_init) - kernel = jnp.asarray(kernel, dtype) - - y = lax.conv_transpose(inputs, kernel, strides, padding, - rhs_dilation=kernel_dilation, precision=precision) - - if is_single_input: - y = jnp.squeeze(y, axis=0) - if bias: - bias = self.param('bias', (features,), bias_init) - bias = jnp.asarray(bias, dtype) - y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) - return y - - -default_embed_init = initializers.variance_scaling(1.0, 'fan_in', 'normal', - out_axis=0) - - -class Embed(base.Module): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Embedding Module. - - A parameterized function from integers [0, n) to d-dimensional vectors. - """ - - def apply(self, - inputs, - num_embeddings, - features, - embedding_init=default_embed_init): - """Embeds the inputs along the last dimension. - - Args: - inputs: input data, all dimensions are considered batch dimensions. - num_embeddings: number of embeddings. - features: Number of feature dimensions for each embedding. - embedding_init: embedding initializer. - - Returns: - Output which is embedded input data. The output shape follows the input, - with an additional `features` dimension appended. - """ - if not jnp.issubdtype(inputs.dtype, jnp.integer): - raise ValueError('Input type must be an integer or unsigned integer.') - embedding = self.param('embedding', (num_embeddings, features), - embedding_init) - return embedding[inputs] - - @base.module_method - def attend(self, query, **unused_kwargs): - """Attend over the embedding using a query array. - - Args: - query: array with last dimension equal the feature depth `features` of the - embedding. - **unused_kwargs: unused arguments passed from the apply method. - - Returns: - An array with final dim `num_embeddings` corresponding to the batched - inner-product of the array of query vectors against each embedding. - Commonly used for weight-sharing between embeddings and logit transform - in NLP models. - """ - del unused_kwargs - embedding = self.get_param('embedding') - return lax.dot_general( - query, embedding, (((query.ndim - 1,), (1,)), ((), ()))) diff --git a/flax/deprecated/nn/normalization.py b/flax/deprecated/nn/normalization.py deleted file mode 100644 index 24b45af99..000000000 --- a/flax/deprecated/nn/normalization.py +++ /dev/null @@ -1,266 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Normalization modules for Flax.""" - -from . import base - -from jax import lax -from jax.nn import initializers -import jax.numpy as jnp - - -_no_init = lambda rng, shape: () - - -def _absolute_dims(rank, dims): - return tuple([rank + dim if dim < 0 else dim for dim in dims]) - - -class BatchNorm(base.Module): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - BatchNorm Module.""" - - def apply(self, - x, - batch_stats=None, - use_running_average=False, - axis=-1, - momentum=0.99, - epsilon=1e-5, - dtype=jnp.float32, - bias=True, - scale=True, - bias_init=initializers.zeros, - scale_init=initializers.ones, - axis_name=None, - axis_index_groups=None): - """Normalizes the input using batch statistics. - - Args: - x: the input to be normalized. - batch_stats: a `flax.nn.Collection` used to store an exponential moving - average of the batch statistics (default: None). - use_running_average: if true, the statistics stored in batch_stats - will be used instead of computing the batch statistics on the input. - axis: the feature or non-batch axis of the input. - momentum: decay rate for the exponential moving average of - the batch statistics. - epsilon: a small float added to variance to avoid dividing by zero. - dtype: the dtype of the computation (default: float32). - bias: if True, bias (beta) is added. - scale: if True, multiply by scale (gamma). - When the next layer is linear (also e.g. nn.relu), this can be disabled - since the scaling will be done by the next layer. - bias_init: initializer for bias, by default, zero. - scale_init: initializer for scale, by default, one. - axis_name: the axis name used to combine batch statistics from multiple - devices. See `jax.pmap` for a description of axis names (default: None). - axis_index_groups: groups of axis indices within that named axis - representing subsets of devices to reduce over (default: None). For example, - `[[0, 1], [2, 3]]` would independently batch-normalize over the examples - on the first two and last two devices. See `jax.lax.psum` for more details. - - Returns: - Normalized inputs (the same shape as inputs). - """ - x = jnp.asarray(x, jnp.float32) - axis = axis if isinstance(axis, tuple) else (axis,) - axis = _absolute_dims(x.ndim, axis) - feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape)) - reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis) - reduction_axis = tuple(i for i in range(x.ndim) if i not in axis) - if self.is_stateful() or batch_stats: - ra_mean = self.state('mean', reduced_feature_shape, - initializers.zeros, collection=batch_stats) - ra_var = self.state('var', reduced_feature_shape, - initializers.ones, collection=batch_stats) - else: - ra_mean = None - ra_var = None - - if use_running_average: - if ra_mean is None: - raise ValueError('when use_running_averages is True ' - 'either use a stateful context or provide batch_stats') - mean, var = ra_mean.value, ra_var.value - else: - mean = jnp.mean(x, axis=reduction_axis, keepdims=False) - mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False) - if axis_name is not None and not self.is_initializing(): - concatenated_mean = jnp.concatenate([mean, mean2]) - mean, mean2 = jnp.split( - lax.pmean( - concatenated_mean, - axis_name=axis_name, - axis_index_groups=axis_index_groups), 2) - var = mean2 - lax.square(mean) - - if ra_mean and not self.is_initializing(): - ra_mean.value = momentum * ra_mean.value + (1 - momentum) * mean - ra_var.value = momentum * ra_var.value + (1 - momentum) * var - - y = x - mean.reshape(feature_shape) - mul = lax.rsqrt(var + epsilon).reshape(feature_shape) - if scale: - mul = mul * self.param( - 'scale', reduced_feature_shape, scale_init).reshape(feature_shape) - y = y * mul - if bias: - y = y + self.param( - 'bias', reduced_feature_shape, bias_init).reshape(feature_shape) - return jnp.asarray(y, dtype) - - -class LayerNorm(base.Module): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Layer normalization (https://arxiv.org/abs/1607.06450). - - Operates on the last axis of the input data. - """ - - def apply(self, - x, - epsilon=1e-6, - dtype=jnp.float32, - bias=True, - scale=True, - bias_init=initializers.zeros, - scale_init=initializers.ones): - """Applies layer normalization on the input. - - It normalizes the activations of the layer for each given example in a - batch independently, rather than across a batch like Batch Normalization. - i.e. applies a transformation that maintains the mean activation within - each example close to 0 and the activation standard deviation close to 1. - - Args: - x: the inputs - epsilon: A small float added to variance to avoid dividing by zero. - dtype: the dtype of the computation (default: float32). - bias: If True, bias (beta) is added. - scale: If True, multiply by scale (gamma). When the next layer is linear - (also e.g. nn.relu), this can be disabled since the scaling will be done - by the next layer. - bias_init: Initializer for bias, by default, zero. - scale_init: Initializer for scale, by default, one. - - Returns: - Normalized inputs (the same shape as inputs). - - """ - x = jnp.asarray(x, jnp.float32) - features = x.shape[-1] - mean = jnp.mean(x, axis=-1, keepdims=True) - mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) - var = mean2 - lax.square(mean) - mul = lax.rsqrt(var + epsilon) - if scale: - mul = mul * jnp.asarray(self.param('scale', (features,), scale_init), - dtype) - y = (x - mean) * mul - if bias: - y = y + jnp.asarray(self.param('bias', (features,), bias_init), dtype) - return jnp.asarray(y, dtype) - - -class GroupNorm(base.Module): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Group normalization (arxiv.org/abs/1803.08494).""" - - def apply(self, - x, - num_groups=32, - group_size=None, - epsilon=1e-6, - dtype=jnp.float32, - bias=True, - scale=True, - bias_init=initializers.zeros, - scale_init=initializers.ones): - """Applies group normalization to the input (arxiv.org/abs/1803.08494). - - This op is similar to batch normalization, but statistics are shared across - equally-sized groups of channels and not shared across batch dimension. - Thus, group normalization does not depend on the batch composition and does - not require maintaining internal state for storing statistics. - - The user should either specify the total number of channel groups or the - number of channels per group. - - Args: - x: the input of shape N...C, where N is a batch dimension and C is a - channels dimensions. `...` represents an arbitrary number of extra - dimensions that are used to accumulate statistics over. - num_groups: the total number of channel groups. The default value of 32 is - proposed by the original group normalization paper. - group_size: the number of channels in a group. - epsilon: A small float added to variance to avoid dividing by zero. - dtype: the dtype of the computation (default: float32). - bias: If True, bias (beta) is added. - scale: If True, multiply by scale (gamma). When the next layer is linear - (also e.g. nn.relu), this can be disabled since the scaling will be done - by the next layer. - bias_init: Initializer for bias, by default, zero. - scale_init: Initializer for scale, by default, one. - - Returns: - Normalized inputs (the same shape as inputs). - - """ - x = jnp.asarray(x, jnp.float32) - if ((num_groups is None and group_size is None) or - (num_groups is not None and group_size is not None)): - raise ValueError('Either `num_groups` or `group_size` should be ' - 'specified, but not both of them.') - - if group_size is not None: - channels = x.shape[-1] - if channels % group_size != 0: - raise ValueError('Number of channels ({}) is not multiple of the ' - 'group size ({}).'.format(channels, group_size)) - num_groups = channels // group_size - - input_shape = x.shape - group_shape = x.shape[:-1] + (num_groups, x.shape[-1] // num_groups) - - x = x.reshape(group_shape) - - reduction_axis = [d for d in range(1, x.ndim - 2)] + [x.ndim - 1] - - mean = jnp.mean(x, axis=reduction_axis, keepdims=True) - mean_of_squares = jnp.mean(jnp.square(x), axis=reduction_axis, - keepdims=True) - var = mean_of_squares - jnp.square(mean) - - x = (x - mean) * lax.rsqrt(var + epsilon) - - x = x.reshape(input_shape) - - feature_shape = tuple([1 for d in input_shape[:-1]] + [input_shape[-1]]) - if scale: - x = x * self.param('scale', feature_shape, scale_init) - if bias: - x = x + self.param('bias', feature_shape, bias_init) - - return x.astype(dtype) diff --git a/flax/deprecated/nn/pooling.py b/flax/deprecated/nn/pooling.py deleted file mode 100644 index 45b332d49..000000000 --- a/flax/deprecated/nn/pooling.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Pooling modules.""" - -from jax import lax -import jax.numpy as jnp - -import numpy as np - - -def pool(inputs, init, reduce_fn, window_shape, strides, padding): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Helper function to define pooling functions. - - Pooling functions are implemented using the ReduceWindow XLA op. - NOTE: Be aware that pooling is not generally differentiable. - That means providing a reduce_fn that is differentiable does not imply - that pool is differentiable. - - Args: - inputs: input data with dimensions (batch, window dims..., features). - init: the initial value for the reduction - reduce_fn: a reduce function of the form `(T, T) -> T`. - window_shape: a shape tuple defining the window to reduce over. - strides: a sequence of `n` integers, representing the inter-window - strides. - padding: either the string `'SAME'`, the string `'VALID'`, or a sequence - of `n` `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - Returns: - The output of the reduction for each window slice. - """ - strides = strides or (1,) * len(window_shape) - strides = (1,) + strides + (1,) - dims = (1,) + window_shape + (1,) - if not isinstance(padding, str): - padding = tuple(map(tuple, padding)) - assert(len(padding) == len(window_shape)), ( - f"padding {padding} must specify pads for same number of dims as " - f"window_shape {window_shape}") - assert(all([len(x) == 2 for x in padding])), ( - f"each entry in padding {padding} must be length 2") - padding = ((0,0),) + padding + ((0,0),) - return lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) - - -def avg_pool(inputs, window_shape, strides=None, padding="VALID"): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Pools the input by taking the average over a window. - - Args: - inputs: input data with dimensions (batch, window dims..., features). - window_shape: a shape tuple defining the window to reduce over. - strides: a sequence of `n` integers, representing the inter-window - strides (default: `(1, ..., 1)`). - padding: either the string `'SAME'`, the string `'VALID'`, or a sequence - of `n` `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension (default: `'VALID'`). - Returns: - The average for each window slice. - """ - y = pool(inputs, 0., lax.add, window_shape, strides, padding) - y = y / np.prod(window_shape) - return y - - -def max_pool(inputs, window_shape, strides=None, padding="VALID"): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Pools the input by taking the maximum of a window slice. - - Args: - inputs: input data with dimensions (batch, window dims..., features). - window_shape: a shape tuple defining the window to reduce over. - strides: a sequence of `n` integers, representing the inter-window - strides (default: `(1, ..., 1)`). - padding: either the string `'SAME'`, the string `'VALID'`, or a sequence - of `n` `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension (default: `'VALID'`). - Returns: - The maximum for each window slice. - """ - y = pool(inputs, -jnp.inf, lax.max, window_shape, strides, padding) - return y diff --git a/flax/deprecated/nn/recurrent.py b/flax/deprecated/nn/recurrent.py deleted file mode 100644 index cd1079f17..000000000 --- a/flax/deprecated/nn/recurrent.py +++ /dev/null @@ -1,440 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Recurrent neural network modules. - -THe RNNCell modules are designed to fit in with the scan function in JAX:: - - _, initial_params = LSTMCell.init(rng_1, time_series[0]) - model = nn.Model(LSTMCell, initial_params) - carry = LSTMCell.initialize_carry(rng_2, (batch_size,), memory_size) - carry, y = jax.lax.scan(model, carry, time_series) - -""" - -import abc - -from . import activation -from . import base -from . import initializers -from . import linear - -from jax import numpy as jnp -from jax import random -from jax import lax -import numpy as np - - -class RNNCellBase(base.Module): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - RNN cell base class.""" - - @staticmethod - @abc.abstractmethod - def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros): - """initialize the RNN cell carry. - - Args: - rng: random number generator passed to the init_fn. - batch_dims: a tuple providing the shape of the batch dimensions. - size: the size or number of features of the memory. - init_fn: initializer function for the carry. - Returns: - An initialized carry for the given RNN cell. - """ - pass - - -class LSTMCell(RNNCellBase): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - LSTM cell.""" - - def apply(self, carry, inputs, - gate_fn=activation.sigmoid, activation_fn=activation.tanh, - kernel_init=linear.default_kernel_init, - recurrent_kernel_init=initializers.orthogonal(), - bias_init=initializers.zeros): - r"""A long short-term memory (LSTM) cell. - - the mathematical definition of the cell is as follows - .. math:: - \begin{array}{ll} - i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ - f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ - g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ - o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ - c' = f * c + i * g \\ - h' = o * \tanh(c') \\ - \end{array} - where x is the input, h is the output of the previous time step, and c is - the memory. - - Args: - carry: the hidden state of the LSTM cell, - initialized using `LSTMCell.initialize_carry`. - inputs: an ndarray with the input for the current time step. - All dimensions except the final are considered batch dimensions. - gate_fn: activation function used for gates (default: sigmoid) - activation_fn: activation function used for output and memory update - (default: tanh). - kernel_init: initializer function for the kernels that transform - the input (default: lecun_normal). - recurrent_kernel_init: initializer function for the kernels that transform - the hidden state (default: orthogonal). - bias_init: initializer for the bias parameters (default: zeros) - Returns: - A tuple with the new carry and the output. - """ - c, h = carry - hidden_features = h.shape[-1] - # input and recurrent layers are summed so only one needs a bias. - dense_h = linear.Dense.partial( - inputs=h, features=hidden_features, bias=True, - kernel_init=recurrent_kernel_init, bias_init=bias_init) - dense_i = linear.Dense.partial( - inputs=inputs, features=hidden_features, bias=False, - kernel_init=kernel_init) - i = gate_fn(dense_i(name='ii') + dense_h(name='hi')) - f = gate_fn(dense_i(name='if') + dense_h(name='hf')) - g = activation_fn(dense_i(name='ig') + dense_h(name='hg')) - o = gate_fn(dense_i(name='io') + dense_h(name='ho')) - new_c = f * c + i * g - new_h = o * activation_fn(new_c) - return (new_c, new_h), new_h - - @staticmethod - def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros): - """initialize the RNN cell carry. - - Args: - rng: random number generator passed to the init_fn. - batch_dims: a tuple providing the shape of the batch dimensions. - size: the size or number of features of the memory. - init_fn: initializer function for the carry. - Returns: - An initialized carry for the given RNN cell. - """ - key1, key2 = random.split(rng) - mem_shape = batch_dims + (size,) - return init_fn(key1, mem_shape), init_fn(key2, mem_shape) - - -class OptimizedLSTMCell(RNNCellBase): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - More efficient LSTM Cell that concatenates state components before matmul. - - Parameters are compatible with `flax.nn.LSTMCell`. - """ - - class DummyDense(base.Module): - """Dummy module for creating parameters matching `flax.nn.Dense`.""" - - def apply(self, - inputs, - features, - kernel_init, - bias_init, - bias=True): - k = self.param('kernel', (inputs.shape[-1], features), kernel_init) - b = (self.param('bias', (features,), bias_init) - if bias else jnp.zeros((features,))) - return k, b - - def apply(self, - carry, - inputs, - gate_fn=activation.sigmoid, - activation_fn=activation.tanh, - kernel_init=linear.default_kernel_init, - recurrent_kernel_init=initializers.orthogonal(), - bias_init=initializers.zeros): - r"""A long short-term memory (LSTM) cell. - - the mathematical definition of the cell is as follows - .. math:: - \begin{array}{ll} - i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ - f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ - g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ - o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ - c' = f * c + i * g \\ - h' = o * \tanh(c') \\ - \end{array} - where x is the input, h is the output of the previous time step, and c is - the memory. - - Args: - carry: the hidden state of the LSTM cell, initialized using - `LSTMCell.initialize_carry`. - inputs: an ndarray with the input for the current time step. All - dimensions except the final are considered batch dimensions. - gate_fn: activation function used for gates (default: sigmoid) - activation_fn: activation function used for output and memory update - (default: tanh). - kernel_init: initializer function for the kernels that transform - the input (default: lecun_normal). - recurrent_kernel_init: initializer function for the kernels that transform - the hidden state (default: orthogonal). - bias_init: initializer for the bias parameters (default: zeros) - - Returns: - A tuple with the new carry and the output. - """ - c, h = carry - hidden_features = h.shape[-1] - - def _concat_dense(inputs, params, use_bias=True): - kernels, biases = zip(*params.values()) - kernel = jnp.asarray(jnp.concatenate(kernels, axis=-1), jnp.float32) - - y = jnp.dot(inputs, kernel) - if use_bias: - bias = jnp.asarray(jnp.concatenate(biases, axis=-1), jnp.float32) - y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) - - # Split the result back into individual (i, f, g, o) outputs. - split_indices = np.cumsum([b.shape[0] for b in biases[:-1]]) - ys = jnp.split(y, split_indices, axis=-1) - return dict(zip(params.keys(), ys)) - - # Create the params in the same order as LSTMCell for initialization - # compatibility. - dense_params_h = {} - dense_params_i = {} - for component in ['i', 'f', 'g', 'o']: - dense_params_i[component] = OptimizedLSTMCell.DummyDense( - inputs=inputs, features=hidden_features, bias=False, - kernel_init=kernel_init, bias_init=bias_init, - name=f'i{component}') - dense_params_h[component] = OptimizedLSTMCell.DummyDense( - inputs=h, features=hidden_features, bias=True, - kernel_init=recurrent_kernel_init, bias_init=bias_init, - name=f'h{component}') - dense_h = _concat_dense(h, dense_params_h, use_bias=True) - dense_i = _concat_dense(inputs, dense_params_i, use_bias=False) - - i = gate_fn(dense_h['i'] + dense_i['i']) - f = gate_fn(dense_h['f'] + dense_i['f']) - g = activation_fn(dense_h['g'] + dense_i['g']) - o = gate_fn(dense_h['o'] + dense_i['o']) - - new_c = f * c + i * g - new_h = o * activation_fn(new_c) - return (new_c, new_h), new_h - - @staticmethod - def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros): - """initialize the RNN cell carry. - - Args: - rng: random number generator passed to the init_fn. - batch_dims: a tuple providing the shape of the batch dimensions. - size: the size or number of features of the memory. - init_fn: initializer function for the carry. - - Returns: - An initialized carry for the given RNN cell. - """ - key1, key2 = random.split(rng) - mem_shape = batch_dims + (size,) - return init_fn(key1, mem_shape), init_fn(key2, mem_shape) - - -class GRUCell(RNNCellBase): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - GRU cell.""" - - def apply(self, carry, inputs, - gate_fn=activation.sigmoid, activation_fn=activation.tanh, - kernel_init=linear.default_kernel_init, - recurrent_kernel_init=initializers.orthogonal(), - bias_init=initializers.zeros): - r"""Gated recurrent unit (GRU) cell. - - the mathematical definition of the cell is as follows - .. math:: - \begin{array}{ll} - r = \sigma(W_{ir} x + W_{hr} h + b_{hr}) \\ - z = \sigma(W_{iz} x + W_{hz} h + b_{hz}) \\ - n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ - h' = (1 - z) * n + z * h - \end{array} - where x is the input and h, is the output of the previous time step. - - Args: - carry: the hidden state of the LSTM cell, - initialized using `GRUCell.initialize_carry`. - inputs: an ndarray with the input for the current time step. - All dimensions except the final are considered batch dimensions. - gate_fn: activation function used for gates (default: sigmoid) - activation_fn: activation function used for output and memory update - (default: tanh). - kernel_init: initializer function for the kernels that transform - the input (default: lecun_normal). - recurrent_kernel_init: initializer function for the kernels that transform - the hidden state (default: orthogonal). - bias_init: initializer for the bias parameters (default: zeros) - Returns: - A tuple with the new carry and the output. - """ - h = carry - hidden_features = h.shape[-1] - # input and recurrent layers are summed so only one needs a bias. - dense_h = linear.Dense.partial( - inputs=h, features=hidden_features, bias=False, - kernel_init=recurrent_kernel_init, bias_init=bias_init) - dense_i = linear.Dense.partial( - inputs=inputs, features=hidden_features, bias=True, - kernel_init=kernel_init, bias_init=bias_init) - r = gate_fn(dense_i(name='ir') + dense_h(name='hr')) - z = gate_fn(dense_i(name='iz') + dense_h(name='hz')) - # add bias because the linear transformations aren't directly summed. - n = activation_fn(dense_i(name='in') + r * dense_h(name='hn', bias=True)) - new_h = (1. - z) * n + z * h - return new_h, new_h - - @staticmethod - def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros): - """initialize the RNN cell carry. - - Args: - rng: random number generator passed to the init_fn. - batch_dims: a tuple providing the shape of the batch dimensions. - size: the size or number of features of the memory. - init_fn: initializer function for the carry. - Returns: - An initialized carry for the given RNN cell. - """ - mem_shape = batch_dims + (size,) - return init_fn(rng, mem_shape) - - -class ConvLSTM(RNNCellBase): - r"""DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - A convolutional LSTM cell. - - The implementation is based on xingjian2015convolutional. - Given x_t and the previous state (h_{t-1}, c_{t-1}) - the core computes - - .. math:: - - \begin{array}{ll} - i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ - f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ - g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ - o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ - c_t = f_t c_{t-1} + i_t g_t \\ - h_t = o_t \tanh(c_t) - \end{array} - - where * denotes the convolution operator; - i_t, f_t, o_t are input, forget and output gate activations, - and g_t is a vector of cell updates. - - Notes: - Forget gate initialization: - Following jozefowicz2015empirical we add 1.0 to b_f - after initialization in order to reduce the scale of forgetting in - the beginning of the training. - """ - - def apply(self, - carry, - inputs, - features, - kernel_size, - strides=None, - padding='SAME', - bias=True, - dtype=jnp.float32): - """Constructs a convolutional LSTM. - - Args: - carry: the hidden state of the Conv2DLSTM cell, - initialized using `Conv2DLSTM.initialize_carry`. - inputs: input data with dimensions (batch, spatial_dims..., features). - features: number of convolution filters. - kernel_size: shape of the convolutional kernel. - strides: a sequence of `n` integers, representing the inter-window - strides. - padding: either the string `'SAME'`, the string `'VALID'`, or a sequence - of `n` `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - bias: whether to add a bias to the output (default: True). - dtype: the dtype of the computation (default: float32). - Returns: - A tuple with the new carry and the output. - """ - c, h = carry - input_to_hidden = linear.Conv.partial( - features=4*features, - kernel_size=kernel_size, - strides=strides, - padding=padding, - bias=bias, - dtype=dtype, - name="ih") - - hidden_to_hidden = linear.Conv.partial( - features=4*features, - kernel_size=kernel_size, - strides=strides, - padding=padding, - bias=bias, - dtype=dtype, - name="hh") - - gates = input_to_hidden(inputs) + hidden_to_hidden(h) - i, g, f, o = jnp.split(gates, indices_or_sections=4, axis=-1) - - f = activation.sigmoid(f + 1) - new_c = f * c + activation.sigmoid(i) * jnp.tanh(g) - new_h = activation.sigmoid(o) * jnp.tanh(new_c) - return (new_c, new_h), new_h - - @staticmethod - def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros): - """initialize the RNN cell carry. - - Args: - rng: random number generator passed to the init_fn. - batch_dims: a tuple providing the shape of the batch dimensions. - size: the input_shape + (features,). - init_fn: initializer function for the carry. - Returns: - An initialized carry for the given RNN cell. - """ - key1, key2 = random.split(rng) - mem_shape = batch_dims + size - return init_fn(key1, mem_shape), init_fn(key2, mem_shape) diff --git a/flax/deprecated/nn/stochastic.py b/flax/deprecated/nn/stochastic.py deleted file mode 100644 index e42f21a33..000000000 --- a/flax/deprecated/nn/stochastic.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Stochastic modules. -""" - -import contextlib - -from . import utils -from jax import lax -from jax import random -import jax.numpy as jnp - - -_prng_stack = utils.CallStack() - - -class _PRNGFrame: - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Random Number generator scope responsible for generation prngs in a stochastic context.""" - - def __init__(self, rng): - self.base_rng = rng - self.counter = 0 - self.level = utils._trace_level(utils._current_trace()) - - def make_rng(self): - # when calling make_rng within a jax transformations - # the rng could be implicitly reused (eg. in jit, vmap, scan, ...). - # We raise an error to avoid silent errors. - level = utils._trace_level(utils._current_trace()) - if level > self.level: - raise ValueError('stochastic operations are not allowed when the' - ' stochastic context is created outside of the' - ' current Jax transformation') - self.counter += 1 - return random.fold_in(self.base_rng, self.counter) - - -@contextlib.contextmanager -def stochastic(rng): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - A context manager for stochastic computations. - - Args: - rng: the random number generator used as a seed for the stochastic context. - Yields: - A scope in which unique rngs can be created using `nn.make_rng()`. - """ - with _prng_stack.frame(_PRNGFrame(rng)): - yield - - -def is_stochastic(): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Returns true if a stochastic scope is currently active.""" - return bool(_prng_stack) - - -def make_rng(): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Create a new unique random number generator in a stochastic scope. - - In combination with `nn.stochastic()` this function is used to generate random - keys without manually passing around and splitting a random number generator:: - - with nn.stochastic(rng): - x = random.normal(nn.make_rng(), shape) - x_drop = nn.dropout(x, 0.5) - - - Returns: - A unique jax.random.PRNGKey. - """ - if not _prng_stack: - raise ValueError('Use the `nn.stochastic()` context manager to enable' - ' stochastic computations.') - rng_frame = _prng_stack[-1] - return rng_frame.make_rng() - - -def dropout(inputs, rate, deterministic=False, rng=None): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Applies a random dropout mask to the input. - - Args: - inputs: the inputs that should be randomly masked. - rate: the probablity of masking out a value. - deterministic: if false the inputs are scaled by `1 / (1 - rate)` and - masked, whereas if true, no mask is applied and the inputs are returned as - is. - rng: an optional `jax.random.PRNGKey`. By default `nn.make_rng()` will - be used. - Returns: - The masked inputs. - """ - if rate == 0.: - return inputs - keep_prob = 1. - rate - - if deterministic: - return inputs - else: - if rng is None: - rng = make_rng() - mask = random.bernoulli(rng, p=keep_prob, shape=inputs.shape) - return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) diff --git a/flax/deprecated/nn/utils.py b/flax/deprecated/nn/utils.py deleted file mode 100644 index ce254b57d..000000000 --- a/flax/deprecated/nn/utils.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - NN base modules for JAX.""" - -import contextlib -import threading -import jax - - -class CallStack(object): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Utility for tracking data across a call stack.""" - - def __init__(self): - self._stack = threading.local() - - @property - def _frames(self): - if not hasattr(self._stack, 'frames'): - self._stack.frames = [] - return self._stack.frames - - @contextlib.contextmanager - def frame(self, data=None): - if data is None: - data = {} - self._frames.append(data) - try: - yield data - finally: - self._frames.pop(-1) - - def __iter__(self): - return iter(self._frames) - - def __len__(self): - return len(self._frames) - - def __getitem__(self, key): - return self._frames.__getitem__(key) - - -def classproperty(f): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - decorator that registers a function as a read-only property of the class.""" - - class _ClassProperty: - - def __get__(self, _, cls): - # python will call the __get__ magic function whenever the property is - # read from the class. - return f(cls) - - return _ClassProperty() - - -def _mains(): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Returns a list of currently active Jax tracers.""" - # TODO(jheek): consider re-introducing the tracer check - # for now we pretent there are never any tracers - return () - - -def _trace_level(main): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Returns the level of the trace of -infinity if it is None.""" - if main: - return main.level - return float('-inf') - - -def _current_trace(): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Returns the innermost Jax tracer.""" - tracers = _mains() - if tracers: - return tracers[-1] - return None - - -def _level_of_value(xs): - """DEPRECATION WARNING: - The `flax.nn` module is Deprecated, use `flax.linen` instead. - Learn more and find an upgrade guide at - https://github.com/google/flax/blob/main/flax/linen/README.md" - Returns the tracer level associated with a value if any.""" - xs = jax.tree_leaves(xs) - max_level = float('-inf') - # TODO(jheek): consider re-introducing the tracer check - # for x in xs: - # if hasattr(x, '_trace'): - # level = _trace_level(x._trace.main) - # max_level = max(level, max_level) - return max_level diff --git a/flax/linen/README.md b/flax/linen/README.md index 8a064f501..eacf3d4db 100644 --- a/flax/linen/README.md +++ b/flax/linen/README.md @@ -1,7 +1,7 @@ # Linen: A comfortable evolution of Flax -Linen is a rewrite of Flax Modules based on learning from our users and the broader JAX community. Linen improves on much of the former `flax.nn` API, such as submodule sharing and better support for non-trainable variables. -Moreover, Linen builds on a new "functional core", enabling direct usage of JAX transformations such as `vmap`, `remat` or `scan` inside your modules. +Linen is a neural network API developed based on learning from our users and the broader JAX community. Linen improves on much of the former `flax.nn` API (removed since v0.4.0), such as submodule sharing and better support for non-trainable variables. +Moreover, Linen builds on a "functional core", enabling direct usage of JAX transformations such as `vmap`, `remat` or `scan` inside your modules. In Linen, Modules behave much closer to vanilla Python objects, while still letting you opt-in to the concise single-method pattern many of our users love. diff --git a/flax/linen/linear.py b/flax/linen/linear.py index 80a1ac2bd..81b97a15d 100644 --- a/flax/linen/linear.py +++ b/flax/linen/linear.py @@ -209,9 +209,10 @@ class Conv(Module): be a sequence of integers. strides: an integer or a sequence of `n` integers, representing the inter-window strides (default: 1). - padding: either the string `'SAME'`, the string `'VALID'`, the string 'CIRCULAR'` (periodic boundary conditions), - or a sequence of `n` `(low, high)` integer pairs that give the padding to apply - before and after each spatial dimension. + padding: either the string `'SAME'`, the string `'VALID'`, the string + `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. input_dilation: an integer or a sequence of `n` integers, giving the dilation factor to apply in each spatial dimension of `inputs` (default: 1). Convolution with input dilation `d` is equivalent to transposed @@ -329,11 +330,11 @@ class ConvTranspose(Module): kernel_size: shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer. For all other cases, it must be a sequence of integers. - strides: a sequence of `n` integers, representing the inter-window - strides. - padding: either the string `'SAME'`, the string `'VALID'`, the string 'CIRCULAR'` (periodic boundary conditions), - or a sequence of `n` `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. + strides: a sequence of `n` integers, representing the inter-window strides. + padding: either the string `'SAME'`, the string `'VALID'`, the string + `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. kernel_dilation: `None`, or a sequence of `n` integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Convolution with kernel dilation is also known as 'atrous @@ -411,7 +412,7 @@ def __call__(self, inputs: Array) -> Array: # dimension. Padding should be done in such a way that the start of the # original input data inside the padded array is located at integer # number of periods - otherwise the result would be circularly shifted. - + # Compute period along each spatial dimension - it's input size scaled # by the stride. scaled_x_dims = [ diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index e6d575080..fb5256970 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -37,6 +37,13 @@ def _canonicalize_axes(rank: int, axes: Axes) -> Iterable[int]: axes = (axes,) return tuple(set([rank + axis if axis < 0 else axis for axis in axes])) +def _abs_sq(x): + """Computes the elementwise square of the absolute value |x|^2.""" + if jnp.iscomplexobj(x): + return lax.square(lax.real(x)) + lax.square(lax.imag(x)) + else: + return lax.square(x) + def _compute_stats(x: Array, axes: Axes, axis_name: Optional[str] = None, @@ -46,15 +53,17 @@ def _compute_stats(x: Array, axes: Axes, This implementation takes care of a few important details: - Computes in float32 precision for half precision inputs - mean and variance is computable in a single XLA fusion, - by using Var = E[x^2] - E[x]^2 instead of Var = E[(x - E[x])^2]). + by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2]). - Clips negative variances to zero which can happen due to roundoff errors. This avoids downstream NaNs. - Supports averaging across a parallel axis and subgroups of a parallel axis with a single `lax.pmean` call to avoid latency. """ - x = jnp.asarray(x, jnp.float32) + # promote x to at least float32, this avoids half precision computation + # but preserves double or complex floating points + x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x))) mean = jnp.mean(x, axes) - mean2 = jnp.mean(lax.square(x), axes) + mean2 = jnp.mean(_abs_sq(x), axes) if axis_name is not None: concatenated_mean = jnp.concatenate([mean, mean2]) mean, mean2 = jnp.split( @@ -62,9 +71,9 @@ def _compute_stats(x: Array, axes: Axes, concatenated_mean, axis_name=axis_name, axis_index_groups=axis_index_groups), 2) - # mean2 - lax.square(mean) is not guaranteed to be non-negative due + # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due # to floating point round-off errors. - var = jnp.maximum(0., mean2 - lax.square(mean)) + var = jnp.maximum(0., mean2 - _abs_sq(mean)) return mean, var diff --git a/flax/linen/partitioning.py b/flax/linen/partitioning.py index 75d89f236..4ad8f2172 100644 --- a/flax/linen/partitioning.py +++ b/flax/linen/partitioning.py @@ -438,8 +438,29 @@ def get_axis_names(axes_metadata): def _tree_map_axes(fn, tree): """Only map over AxisMetadata leaves in pytree - identity for other leaves.""" safe_fn = lambda x: fn(x) if isinstance(x, AxisMetadata) else x - return jax.tree_map(safe_fn, tree, - is_leaf=lambda x: isinstance(x, AxisMetadata)) + return jax.tree_map( + safe_fn, tree, is_leaf=lambda x: isinstance(x, AxisMetadata)) + + +def _is_mutable(axis_col: str) -> bool: + """Determines whether a collection is mutable. + + For example, when a module is called with `module.apply(..., mutable=['z'])`, + this function will return True for `axis_col='z'` and False otherwise. + + If there is no module in scope, this function will return True. + + Args: + axis_col: Name of the collection in question. + + Returns: + Whether it is currently mutable. + """ + last = nn.module._context.module_stack[-1] + if last: + return last.is_mutable_collection(axis_col) + else: + return True # uses this variable_transform to change 'params_axes' pytree as it bubbles @@ -449,14 +470,16 @@ def _add_axis_to_metadata(fn, axis_pos, axis_name, axis_col='params_axes'): # Handle In() / Out() scan axis marker types. if hasattr(axis_pos, 'axis'): axis_pos = axis_pos.axis + def insert_fn(x): names = list(x.names) names.insert(axis_pos, axis_name) return x.replace(names=tuple(names)) + return nn.transforms.map_variables( fn, axis_col, - mutable=True, + mutable=_is_mutable(axis_col), trans_out_fn=lambda tree: _tree_map_axes(insert_fn, tree)) diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index 51326e904..5c9a6c66d 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -186,10 +186,6 @@ def set_scopes_inner(x): for f in dataclasses.fields(module) if f.name != 'parent' and f.init} new_attrs = jax.tree_map(set_scopes_inner, attrs) new_module = module.clone(parent=scopes[idx], **new_attrs) - if module.name == 'Dense_0': - print('--compare--') - print(id(module)) - print(id(new_module)) idx += 1 return new_module new_module = set_scopes(module) @@ -247,27 +243,13 @@ def wrapped_fn(self, *args, **kwargs): # make a scope-function to transform def core_fn(scopes, *args, **kwargs): # make a clone of self using its arguments - m1 = self.m1 - m2 = self.m2 - print('--- before ---') - print(id(m1)) - print(id(m2)) attrs = {f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.name != 'parent' and f.init} # we reference module_class, not self.__class__ to avoid infinite loop cloned = module_class(parent=None, **attrs) cloned, args, kwargs = set_module_scopes(cloned, args, kwargs, scopes) - print('--- after ---') - print(id(cloned.m1)) - print(id(cloned.m2)) object.__setattr__(cloned, '_state', self._state.export()) # pylint: disable=protected-access - print('>args', *args) - print('>cloned', cloned) - print(cloned.m1.name) - print(cloned.m2.name) - print(fn) res = fn(cloned, *args, **kwargs) - return None self._state.reimport(cloned._state) # pylint: disable=protected-access _test_transformed_return_values(res, fn_name) return res diff --git a/flax/traverse_util.py b/flax/traverse_util.py index 183b537ca..19dc5703a 100644 --- a/flax/traverse_util.py +++ b/flax/traverse_util.py @@ -369,9 +369,7 @@ def iterate(self, inputs): def _get_params_dict(inputs): - if isinstance(inputs, flax.nn.Model): - return inputs.params - elif isinstance(inputs, (dict, flax.core.FrozenDict)): + if isinstance(inputs, (dict, flax.core.FrozenDict)): return flax.core.unfreeze(inputs) else: raise ValueError( @@ -393,10 +391,6 @@ class ModelParamTraversal(Traversal): See :class:`flax.optim.MultiOptimizer` for an example of how to use :class:`ModelParamTraversal` to update subsets of the parameter tree with a specific optimizer. - - Backward compatibility: - When using the old api the parameters can be encapsulated in a - :class:`flax.nn.Model` instance. """ def __init__(self, filter_fn): @@ -431,9 +425,7 @@ def update(self, fn, inputs): value = fn(value) new_dict[key] = value new_params = unflatten_dict(new_dict) - if isinstance(inputs, flax.nn.base.Model): - return inputs.replace(params=new_params) - elif isinstance(inputs, flax.core.FrozenDict): + if isinstance(inputs, flax.core.FrozenDict): return flax.core.FrozenDict(new_params) else: return new_params diff --git a/tests/checkpoints_test.py b/tests/checkpoints_test.py deleted file mode 100644 index 37efe7b25..000000000 --- a/tests/checkpoints_test.py +++ /dev/null @@ -1,341 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for flax.training.checkpoints.""" - -import copy -import os -import pathlib -from typing import Any - -from absl.testing import absltest -from absl.testing import parameterized -import flax -from flax import core -from flax import errors -from flax.training import checkpoints -import jax -from jax import numpy as jnp -from jax import test_util as jtu -import numpy as np -from tensorflow.io import gfile - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - - -def shuffle(l): - """Functional shuffle.""" - l = copy.copy(l) - np.random.shuffle(l) - return l - - -class InnerPreLinen(flax.nn.Module): - """Inner class based on pre-Linen flax.nn.""" - - def apply(self, x): - x = flax.nn.Conv(x, 10, (2, 2)) - x = flax.nn.normalization.BatchNorm(x, use_running_average=True) - return x - - -class ModelPreLinen(flax.nn.Module): - """Simple model based on pre-Linen flax.nn.""" - - def apply(self, inputs): - x = flax.nn.Conv(inputs, 10, (2, 2)) - x = InnerPreLinen(x, name='Inner_1') - x = x.reshape([x.shape[0], -1]) - x = flax.nn.normalization.BatchNorm(x, use_running_average=True) - x = flax.nn.Dense(x, 10) - x = flax.nn.log_softmax(x) - return x - - -class Inner(flax.linen.Module): - """Inner class based on flax.linen.""" - - @flax.linen.compact - def __call__(self, x): - x = flax.linen.Conv(10, (2, 2))(x) - x = flax.linen.normalization.BatchNorm(True)(x) - return x - - -class Model(flax.linen.Module): - """Simple model based on flax.linen.""" - - @flax.linen.compact - def __call__(self, inputs): - x = flax.linen.Conv(10, (2, 2))(inputs) - x = Inner()(x) - x = x.reshape([x.shape[0], -1]) - x = flax.linen.normalization.BatchNorm(True)(x) - x = flax.linen.Dense(10)(x) - x = flax.linen.log_softmax(x) - return x - - -@flax.struct.dataclass -class TrainState: - """Simple container that captures training state.""" - optimizer: flax.optim.Optimizer - model_state: Any - - -class CheckpointsTest(parameterized.TestCase): - - def test_naturalsort(self): - np.random.seed(0) - tests = [ - ['file_1', 'file_2', 'file_10', 'file_11', 'file_21'], - ['file_0.001', 'file_0.01', 'file_0.1', 'file_1'], - ['file_-3.0', 'file_-2', 'file_-1', 'file_0.0'], - ['file_1e1', 'file_1.0e2', 'file_1e3', 'file_1.0e4'], - ['file_1', 'file_2', 'file_9', 'file_1.0e1', 'file_11'], - ] - for test in tests: - self.assertEqual(test, checkpoints.natural_sort(shuffle(test))) - - def test_safe_normpath(self): - tests = ['./a/b/c', '/a//b/c', '/a/../b/c', 'a/b/./c', 'gs://a//b/c'] - expected = ['a/b/c', '/a/b/c', '/b/c', 'a/b/c', 'gs://a/b/c'] - for test, expect in zip(tests, expected): - self.assertEqual(expect, checkpoints.safe_normpath(test)) - - def test_save_restore_checkpoints(self): - tmp_dir = pathlib.Path(self.create_tempdir().full_path) - test_object0 = {'a': np.array([0, 0, 0], np.int32), - 'b': np.array([0, 0, 0], np.int32)} - test_object1 = {'a': np.array([1, 2, 3], np.int32), - 'b': np.array([1, 1, 1], np.int32)} - test_object2 = {'a': np.array([4, 5, 6], np.int32), - 'b': np.array([2, 2, 2], np.int32)} - new_object = checkpoints.restore_checkpoint( - tmp_dir, test_object0, prefix='test_') - jtu.check_eq(new_object, test_object0) - # Create leftover temporary checkpoint, which should be ignored. - gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w') - checkpoints.save_checkpoint( - tmp_dir, test_object1, 0, prefix='test_', keep=1) - self.assertIn('test_0', os.listdir(tmp_dir)) - new_object = checkpoints.restore_checkpoint( - tmp_dir, test_object0, prefix='test_') - jtu.check_eq(new_object, test_object1) - checkpoints.save_checkpoint( - tmp_dir, test_object1, 1, prefix='test_', keep=1) - checkpoints.save_checkpoint( - tmp_dir, test_object2, 2, prefix='test_', keep=1) - new_object = checkpoints.restore_checkpoint( - tmp_dir, test_object0, prefix='test_') - jtu.check_eq(new_object, test_object2) - checkpoints.save_checkpoint( - tmp_dir, test_object2, 3, prefix='test_', keep=2) - checkpoints.save_checkpoint( - tmp_dir, test_object1, 4, prefix='test_', keep=2) - new_object = checkpoints.restore_checkpoint( - tmp_dir, test_object0, prefix='test_') - jtu.check_eq(new_object, test_object1) - new_object = checkpoints.restore_checkpoint( - tmp_dir, test_object0, step=3, prefix='test_') - jtu.check_eq(new_object, test_object2) - # Restore a specific path. - new_object = checkpoints.restore_checkpoint( - os.path.join(tmp_dir, 'test_3'), test_object0) - jtu.check_eq(new_object, test_object2) - # If a specific path is specified, but it does not exist, the same behavior - # as when a directory is empty should apply: the target is returned - # unchanged. - new_object = checkpoints.restore_checkpoint( - os.path.join(tmp_dir, 'test_not_there'), test_object0) - jtu.check_eq(new_object, test_object0) - with self.assertRaises(ValueError): - checkpoints.restore_checkpoint( - tmp_dir, test_object0, step=5, prefix='test_') - - def test_overwrite_checkpoints(self): - tmp_dir = self.create_tempdir().full_path - test_object0 = {'a': np.array([0, 0, 0], np.int32)} - test_object = {'a': np.array([1, 2, 3], np.int32)} - - checkpoints.save_checkpoint( - tmp_dir, test_object0, 0, keep=1) - with self.assertRaises(errors.InvalidCheckpointError): - checkpoints.save_checkpoint( - tmp_dir, test_object, 0, keep=1) - checkpoints.save_checkpoint( - tmp_dir, test_object, 0, keep=1, overwrite=True) - new_object = checkpoints.restore_checkpoint(tmp_dir, test_object0) - jtu.check_eq(new_object, test_object) - checkpoints.save_checkpoint( - tmp_dir, test_object0, 2, keep=1, overwrite=True) - new_object = checkpoints.restore_checkpoint(tmp_dir, test_object) - jtu.check_eq(new_object, test_object0) - with self.assertRaises(errors.InvalidCheckpointError): - checkpoints.save_checkpoint( - tmp_dir, test_object, 1, keep=1) - checkpoints.save_checkpoint( - tmp_dir, test_object, 1, keep=1, overwrite=True) - new_object = checkpoints.restore_checkpoint(tmp_dir, test_object0) - jtu.check_eq(new_object, test_object) - os.chdir(os.path.dirname(tmp_dir)) - rel_tmp_dir = './' + os.path.basename(tmp_dir) - checkpoints.save_checkpoint(rel_tmp_dir, test_object, 3, keep=1) - new_object = checkpoints.restore_checkpoint(rel_tmp_dir, test_object0) - jtu.check_eq(new_object, test_object) - non_norm_dir_path = tmp_dir + '//' - checkpoints.save_checkpoint(non_norm_dir_path, test_object, 4, keep=1) - new_object = checkpoints.restore_checkpoint(non_norm_dir_path, test_object0) - jtu.check_eq(new_object, test_object) - - @parameterized.parameters({'keep_every_n_steps': None}, - {'keep_every_n_steps': 7}) - def test_keep(self, keep_every_n_steps): - tmp_dir = self.create_tempdir().full_path - test_object = {'a': np.array([1, 2, 3], np.int32)} - steps_start = 17 - steps_end = 37 - keep = 3 - increment = 5 - - for step in range(steps_start, steps_end, increment): - checkpoints.save_checkpoint(tmp_dir, - test_object, - step=step, - keep=keep, - keep_every_n_steps=keep_every_n_steps) - - last_checkpoint = -float('inf') - for step in range(steps_start, steps_end, increment): - if ((steps_end - step) / increment <= keep) or (keep_every_n_steps and ( - step - last_checkpoint) >= keep_every_n_steps): - restored = checkpoints.restore_checkpoint( - tmp_dir, target=None, step=step) - jtu.check_eq(restored, test_object) - last_checkpoint = step - else: - with self.assertRaises(ValueError): - checkpoints.restore_checkpoint(tmp_dir, target=None, step=step) - - def test_save_restore_checkpoints_w_float_steps(self): - tmp_dir = self.create_tempdir().full_path - test_object0 = {'a': np.array([0, 0, 0], np.int32), - 'b': np.array([0, 0, 0], np.int32)} - test_object1 = {'a': np.array([1, 2, 3], np.int32), - 'b': np.array([1, 1, 1], np.int32)} - test_object2 = {'a': np.array([4, 5, 6], np.int32), - 'b': np.array([2, 2, 2], np.int32)} - # Create leftover temporary checkpoint, which should be ignored. - gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w') - checkpoints.save_checkpoint( - tmp_dir, test_object1, 0.0, prefix='test_', keep=1) - self.assertIn('test_0.0', os.listdir(tmp_dir)) - new_object = checkpoints.restore_checkpoint( - tmp_dir, test_object0, prefix='test_') - jtu.check_eq(new_object, test_object1) - checkpoints.save_checkpoint( - tmp_dir, test_object1, 2.0, prefix='test_', keep=1) - with self.assertRaises(errors.InvalidCheckpointError): - checkpoints.save_checkpoint( - tmp_dir, test_object2, 1.0, prefix='test_', keep=1) - checkpoints.save_checkpoint( - tmp_dir, test_object2, 3.0, prefix='test_', keep=2) - self.assertIn('test_3.0', os.listdir(tmp_dir)) - self.assertIn('test_2.0', os.listdir(tmp_dir)) - jtu.check_eq(new_object, test_object1) - - def test_save_restore_checkpoints_target_none(self): - tmp_dir = self.create_tempdir().full_path - test_object0 = {'a': np.array([0, 0, 0], np.int32), - 'b': np.array([0, 0, 0], np.int32)} - # Target pytree is a dictionary, so it's equal to a restored state_dict. - checkpoints.save_checkpoint(tmp_dir, test_object0, 0) - new_object = checkpoints.restore_checkpoint(tmp_dir, target=None) - jtu.check_eq(new_object, test_object0) - # Target pytree it's a tuple, check the expected state_dict is recovered. - test_object1 = (np.array([0, 0, 0], np.int32), - np.array([1, 1, 1], np.int32)) - checkpoints.save_checkpoint(tmp_dir, test_object1, 1) - new_object = checkpoints.restore_checkpoint(tmp_dir, target=None) - expected_new_object = {str(k): v for k, v in enumerate(test_object1)} - jtu.check_eq(new_object, expected_new_object) - - def test_convert_pre_linen(self): - params = checkpoints.convert_pre_linen({ - 'mod_0': { - 'submod1_0': {}, - 'submod2_1': {}, - 'submod1_2': {}, - }, - 'mod2_2': { - 'submod2_2_0': {} - }, - 'mod2_11': { - 'submod2_11_0': {} - }, - 'mod2_1': { - 'submod2_1_0': {} - }, - }) - self.assertDictEqual( - core.unfreeze(params), { - 'mod_0': { - 'submod1_0': {}, - 'submod1_1': {}, - 'submod2_0': {}, - }, - 'mod2_0': { - 'submod2_1_0': {} - }, - 'mod2_1': { - 'submod2_2_0': {} - }, - 'mod2_2': { - 'submod2_11_0': {} - }, - }) - - def test_convert_checkpoint(self): - inputs = jnp.ones([2, 5, 5, 1]) - rng = jax.random.PRNGKey(0) - # pre-Linen. - with flax.nn.stateful() as model_state: - y, params = ModelPreLinen.init(rng, inputs) - pre_linen_optimizer = flax.optim.GradientDescent(0.1).create(params) - train_state = TrainState( - optimizer=pre_linen_optimizer, model_state=model_state) - state_dict = flax.serialization.to_state_dict(train_state) - # Linen. - model = Model() - variables = model.init(rng, inputs) - optimizer = flax.optim.GradientDescent(0.1).create(variables['params']) - optimizer = optimizer.restore_state( - flax.core.unfreeze( - checkpoints.convert_pre_linen(state_dict['optimizer']))) - optimizer = optimizer.apply_gradient(variables['params']) - batch_stats = checkpoints.convert_pre_linen( - flax.traverse_util.unflatten_dict({ - tuple(k.split('/')[1:]): v - for k, v in model_state.as_dict().items() - })) - y, updated_state = model.apply( - dict(params=optimizer.target, batch_stats=batch_stats), - inputs, - mutable=['batch_stats']) - del y, updated_state # not used. - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 7e890b805..892974d50 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -105,6 +105,27 @@ def test_batch_norm(self): np.testing.assert_allclose( ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False), rtol=1e-4) + def test_batch_norm_complex(self): + rng = random.PRNGKey(0) + key1, key2 = random.split(rng) + x = random.normal(key1, (4, 3, 2), dtype=jnp.complex64) + model_cls = nn.BatchNorm(momentum=0.9, use_running_average=False, dtype=jnp.complex64) + y, initial_params = model_cls.init_with_output(key2, x) + + mean = y.mean((0, 1)) + var = y.var((0, 1)) + np.testing.assert_allclose(mean, np.array([0., 0.]), atol=1e-4) + np.testing.assert_allclose(var, np.array([1., 1.]), rtol=1e-4) + self.assertEqual(mean.dtype, jnp.complex64) + + y, vars_out = model_cls.apply(initial_params, x, mutable=['batch_stats']) + + ema = vars_out['batch_stats'] + np.testing.assert_allclose( + ema['mean'], 0.1 * x.mean((0, 1), keepdims=False), atol=1e-4) + np.testing.assert_allclose( + ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False), rtol=1e-4) + def test_layer_norm(self): rng = random.PRNGKey(0) key1, key2 = random.split(rng) @@ -164,6 +185,7 @@ def __call__(self, x): (y1, y2), variables = model.init_with_output(key, x) np.testing.assert_allclose(y1, y2, rtol=1e-4) + class StochasticTest(absltest.TestCase): def test_dropout(self): diff --git a/tests/linen/partitioning_test.py b/tests/linen/partitioning_test.py index dfe0eceae..b29970690 100644 --- a/tests/linen/partitioning_test.py +++ b/tests/linen/partitioning_test.py @@ -240,6 +240,9 @@ def __call__(self, x): p_rules = (('emb', 'data'), ('mlp', 'model'), ('batch', 'data')) with partitioning.axis_rules(p_rules): variables = Scanned(L, E).init(k, x) + + # Ensure that the module can be called when 'params_axes' is not mutable. + Scanned(L, E).apply(variables, x) self.assertIn('params', variables) self.assertIn('params_axes', variables) self.assertIn('stats', variables) diff --git a/tests/nn_attention_test.py b/tests/nn_attention_test.py deleted file mode 100644 index d32c94aab..000000000 --- a/tests/nn_attention_test.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for flax.nn.attention.""" - -from absl.testing import absltest -from absl.testing import parameterized -from flax import jax_utils -from flax.deprecated import nn -import jax -from jax import lax -from jax import random -from jax.nn import initializers -import jax.numpy as jnp - -import numpy as np - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - - -class AttentionTest(parameterized.TestCase): - - def test_multihead_self_attention(self): - rng = random.PRNGKey(0) - x = jnp.ones((4, 2, 3, 5)) - sa_module = nn.SelfAttention.partial( - num_heads=8, - attention_axis=(1, 2), - qkv_features=16, - kernel_init=initializers.ones, - bias_init=initializers.zeros, - ) - y, initial_params = sa_module.init(rng, x) - self.assertEqual(y.shape, x.shape) - - def test_multihead_encoder_decoder_attention(self): - rng = random.PRNGKey(0) - q = jnp.ones((4, 2, 3, 5)) - kv = jnp.ones((4, 2, 3, 5)) - sa_module = nn.MultiHeadDotProductAttention.partial( - num_heads=8, - attention_axis=(1, 2), - qkv_features=16, - kernel_init=initializers.ones, - bias_init=initializers.zeros, - ) - y, _ = sa_module.init(rng, q, kv) - self.assertEqual(y.shape, q.shape) - - def test_multihead_self_attention_w_dropout(self): - rng = random.PRNGKey(0) - x = jnp.ones((4, 2, 3, 5)) - sa_module = nn.SelfAttention.partial( - num_heads=8, - attention_axis=(1, 2), - qkv_features=16, - kernel_init=initializers.ones, - bias_init=initializers.zeros, - dropout_rate=0.1, - ) - rng1, rng2 = random.split(rng) - with nn.stochastic(rng1): - y, initial_params = sa_module.init(rng2, x) - self.assertEqual(y.shape, x.shape) - - def test_causal_mask_1d(self): - """Tests autoregresive masking for 1d attention.""" - key = jnp.ones((4, 5, 2, 16)) # (bs, dim1, dim2, heads, channel) - att_axis = (1,) - mask_1d = nn.attention._make_causal_mask( - key, attention_axis=att_axis, self_mask=False) - - ts = np.arange(key.shape[1]) - mask_1d_simple = (ts[:, None] >= ts[None, :])[None, None, :, :] - np.testing.assert_allclose(mask_1d, mask_1d_simple,) - - def test_causal_mask_2d(self): - """Tests autoregresive masking for 2d attention.""" - key = jnp.ones((4, 5, 5, 2, 16)) # (bs, dim1, dim2, heads, channel) - - # masking when dealing with nd attention weights - # w_nd_shape = (4, 5, 5, 5, 5, 2) - att_axis = (1, 2) - mask_nd = nn.attention._make_causal_mask( - key, attention_axis=att_axis, self_mask=False) - - # masking when dealing with 1d attention weights - # w_1d_shape = (4, 5*5, 5*5, 2) - ts = np.arange(25) - mask_1d = (ts[:, None] >= ts[None, :])[None, None, :, :] - - np.testing.assert_allclose(mask_nd.reshape(mask_1d.shape), mask_1d, - atol=1e-9) - - @parameterized.parameters([((5,), (1,)), - ((5, 6), (1,)), - ((5, 6), (2,)), - ((5, 6), (1, 2)),]) - def test_decoding(self, spatial_shape, attn_dims): - bs = 2 - num_heads = 3 - num_features = 4 - rng = random.PRNGKey(0) - key1, key2 = random.split(rng) - inputs = random.normal( - key1, (bs,) + spatial_shape + (num_heads * num_features,)) - module = nn.SelfAttention.partial( - num_heads=num_heads, - qkv_features=num_heads * num_features, - attention_axis=attn_dims, - causal_mask=True, - precision=lax.Precision.HIGHEST) - - with nn.attention.Cache().mutate() as cache_def: - _, initial_params = module.init_by_shape( - key2, [(inputs.shape, inputs.dtype)], cache=cache_def) - model = nn.Model(module, initial_params) - y_ref = jax.jit(lambda f, x: f(x))(model, inputs) - - # feed the inputs sequentially to simulate decoding - cache0 = cache_def.initialize_cache((bs,) + spatial_shape) - def body_fn(cache, x): - with cache.mutate() as new_cache: - y = model(x, cache=new_cache) - return new_cache, y - # scan_in_dim supports scanning multiple dims - _, y = jax_utils.scan_in_dim(body_fn, cache0, inputs, - axis=attn_dims, keepdims=True) - - np.testing.assert_allclose(y_ref, y, atol=1e-5) - - def test_autoregresive_receptive_field_1d(self): - """Tests the autoregresive self-attention receptive field.""" - rng = random.PRNGKey(0) - rng1, rng2 = random.split(rng, num=2) - - def model_loss(inputs, pos): - out = model(inputs) - assert out.shape == input_shape - assert len(out.shape) == 3 - return out[0, pos, :].sum() - - grad_fn = jax.jit(jax.grad(model_loss)) - - def get_receptive_field_1d(pos): - g = grad_fn(inputs, pos)[0, :, :] - return jnp.any((jnp.abs(g) > 1e-5).astype(jnp.uint32), axis=-1) - - length = 10 - dim = 1 - num_heads = 1 - input_shape = (1, length, dim) - inputs = random.normal(rng2, input_shape) - - module = nn.attention.SelfAttention.partial( - num_heads=num_heads, - causal_mask=True, - kernel_init=jax.nn.initializers.ones) - _, initial_params = module.init_by_shape( - rng1, [((1,) + (length, dim), jnp.float32)]) - model = nn.Model(module, initial_params) - - - - for i in range(length): - deps = get_receptive_field_1d(i) - assert (deps[:i] == 1).all(), ('Receptive Field Error: Some of the ' - 'previous postions are not reachable ' - 'in autoregressive self-attention.') - if i != length - 1: - k = i + 1 - assert (deps[k:] == 0).all(), ('Receptive Field Error: Some of the ' - 'future postions are reachable in ' - 'autoregressive self-attention.') - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/nn_linear_test.py b/tests/nn_linear_test.py deleted file mode 100644 index 181dde4f1..000000000 --- a/tests/nn_linear_test.py +++ /dev/null @@ -1,285 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for flax.nn.linear.""" - -import functools - -from absl.testing import absltest -from absl.testing import parameterized - -from flax.deprecated import nn - -import jax -from jax import random -from jax.nn import initializers -import jax.numpy as jnp - -import numpy as np - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - - -class LinearTest(parameterized.TestCase): - - def test_dense(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 3)) - dense_module = nn.Dense.partial( - features=4, - kernel_init=initializers.ones, - bias_init=initializers.ones, - ) - y, _ = dense_module.init(rng, x) - self.assertEqual(y.shape, (1, 4)) - np.testing.assert_allclose(y, np.full((1, 4), 4.)) - - def test_dense_extra_batch_dims(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 2, 3)) - dense_module = nn.Dense.partial( - features=4, - kernel_init=initializers.ones, - bias_init=initializers.ones, - ) - y, _ = dense_module.init(rng, x) - np.testing.assert_allclose(y, np.full((1, 2, 4), 4.)) - - def test_dense_no_bias(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 3)) - dense_module = nn.Dense.partial( - features=4, - bias=False, - kernel_init=initializers.ones, - ) - y, _ = dense_module.init(rng, x) - np.testing.assert_allclose(y, np.full((1, 4), 3.)) - - def test_dense_is_dense_general(self): - x = jax.random.normal(random.PRNGKey(0), (5, 3)) - dense_module = nn.Dense.partial( - features=4, - bias=True, - bias_init=initializers.normal(), - ) - y1, _ = dense_module.init(random.PRNGKey(1), x) - dg_module = nn.DenseGeneral.partial( - features=4, - bias=True, - bias_init=initializers.normal(), - ) - y2, _ = dg_module.init(random.PRNGKey(1), x) - - np.testing.assert_allclose(y1, y2) - - def test_dense_general_batch_dim_raises(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 3, 2, 5)) - with self.assertRaises(ValueError): - dg_module = nn.DenseGeneral.partial( - features=4, - batch_dims=(0, 2), - kernel_init=initializers.ones, - bias_init=initializers.ones, - ) - dg_module.init(rng, x) - - def test_dense_general_two_out(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 3)) - dg_module = nn.DenseGeneral.partial( - features=(2, 2), - kernel_init=initializers.ones, - bias_init=initializers.ones, - ) - y, _ = dg_module.init(rng, x) - np.testing.assert_allclose(y, np.full((1, 2, 2), 4.)) - - def test_dense_general_two_in(self): - rng = random.PRNGKey(0) - x = jnp.ones((1, 2, 2)) - dg_module = nn.DenseGeneral.partial( - features=3, - axis=(-2, 2), - kernel_init=initializers.ones, - bias_init=initializers.ones, - ) - y, _ = dg_module.init(rng, x) - np.testing.assert_allclose(y, np.full((1, 3), 5.)) - - def test_dense_general_batch_dim(self): - rng = random.PRNGKey(0) - x = jnp.ones((2, 1, 3, 5)) - - state = {'counter': 0.} - def _counter_init(rng, shape, dtype, state): - del rng, dtype - state['counter'] += 1. - return jnp.full(shape, state['counter']) - counter_init = functools.partial(_counter_init, state=state) - - dg_module = nn.DenseGeneral.partial( - features=7, - axis=(3, -2), - batch_dims=0, - bias_init=initializers.ones, - kernel_init=counter_init, - ) - y, _ = dg_module.init(rng, x) - target = np.concatenate( - [np.full((1, 1, 7), 16.), np.full((1, 1, 7), 31.)], axis=0) - np.testing.assert_allclose(y, target) - - @parameterized.parameters([((-2, 3), (), 'bijk,jklm->bilm'), - ((3, -2), (), 'bijk,kjlm->bilm'), - ((-2, 3), (0,), 'bijk,bjklm->bilm')]) - def test_dense_general_vs_numpy(self, axis, batch_dims, einsum_expr): - rng = random.PRNGKey(0) - x = jnp.ones((16, 8, 9, 10)) - - dg_module = nn.DenseGeneral.partial( - features=(11, 12), - axis=axis, - batch_dims=batch_dims, - bias_init=initializers.ones, - kernel_init=initializers.normal(), - ) - y, initial_params = dg_module.init(rng, x) - dg_module = nn.Model(dg_module, initial_params) - target = np.einsum(einsum_expr, x, dg_module.params['kernel']) + 1. - np.testing.assert_allclose(y, target, atol=1e-6) - - @parameterized.parameters([((3,),), (3,)]) - def test_conv(self, kernel_size): - rng = random.PRNGKey(0) - x = jnp.ones((1, 8, 3)) - conv_module = nn.Conv.partial( - features=4, - kernel_size=kernel_size, - padding='VALID', - kernel_init=initializers.ones, - bias_init=initializers.ones, - ) - y, initial_params = conv_module.init(rng, x) - model = nn.Model(conv_module, initial_params) - self.assertEqual(model.params['kernel'].shape, (3, 3, 4)) - np.testing.assert_allclose(y, np.full((1, 6, 4), 10.)) - - @parameterized.parameters([((3,),), (3,)]) - def test_single_input_conv(self, kernel_size): - rng = random.PRNGKey(0) - x = jnp.ones((8, 3)) - conv_module = nn.Conv.partial( - features=4, - kernel_size=kernel_size, - padding='VALID', - kernel_init=initializers.ones, - bias_init=initializers.ones, - ) - y, initial_params = conv_module.init(rng, x) - model = nn.Model(conv_module, initial_params) - self.assertEqual(model.params['kernel'].shape, (3, 3, 4)) - np.testing.assert_allclose(y, np.full((6, 4), 10.)) - - @parameterized.parameters([((3,),), (3,)]) - def test_group_conv(self, kernel_size): - rng = random.PRNGKey(0) - x = jnp.ones((1, 8, 4)) - conv_module = nn.Conv.partial( - features=4, - kernel_size=kernel_size, - feature_group_count=2, - padding='VALID', - kernel_init=initializers.ones, - bias_init=initializers.ones, - ) - y, initial_params = conv_module.init(rng, x) - model = nn.Model(conv_module, initial_params) - self.assertEqual(model.params['kernel'].shape, (3, 2, 4)) - np.testing.assert_allclose(y, np.full((1, 6, 4), 7.)) - - @parameterized.parameters([((3,),), (3,)]) - def test_conv_transpose(self, kernel_size): - rng = random.PRNGKey(0) - x = jnp.ones((1, 8, 3)) - conv_transpose_module = nn.ConvTranspose.partial( - features=4, - kernel_size=kernel_size, - padding='VALID', - kernel_init=initializers.ones, - bias_init=initializers.ones, - ) - y, initial_params = conv_transpose_module.init(rng, x) - model = nn.Model(conv_transpose_module, initial_params) - self.assertEqual(model.params['kernel'].shape, (3, 3, 4)) - correct_ans = np.array([[[ 4., 4., 4., 4.], - [ 7., 7., 7., 7.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [ 7., 7., 7., 7.], - [ 4., 4., 4., 4.]]]) - np.testing.assert_allclose(y, correct_ans) - - @parameterized.parameters([((3,),), (3,)]) - def test_single_input_conv_transpose(self, kernel_size): - rng = random.PRNGKey(0) - x = jnp.ones((8, 3)) - conv_transpose_module = nn.ConvTranspose.partial( - features=4, - kernel_size=kernel_size, - padding='VALID', - kernel_init=initializers.ones, - bias_init=initializers.ones, - ) - y, initial_params = conv_transpose_module.init(rng, x) - model = nn.Model(conv_transpose_module, initial_params) - self.assertEqual(model.params['kernel'].shape, (3, 3, 4)) - correct_ans = np.array([[ 4., 4., 4., 4.], - [ 7., 7., 7., 7.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [ 7., 7., 7., 7.], - [ 4., 4., 4., 4.]]) - np.testing.assert_allclose(y, correct_ans) - - def test_embed(self): - rng = random.PRNGKey(0) - x = jnp.arange(4)[None] - dummy_embedding = jnp.broadcast_to( - jnp.arange(4)[..., None], (4, 3)).astype(jnp.float32) - embed_module = nn.Embed.partial( - num_embeddings=4, - features=3, - embedding_init=lambda rng, shape: dummy_embedding, - ) - y, initial_params = embed_module.init(rng, x) - model = nn.Model(embed_module, initial_params) - np.testing.assert_allclose(y, dummy_embedding[None]) - - z = model.attend(jnp.ones((3,))) - np.testing.assert_allclose(z, 3. * jnp.arange(4)) - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/nn_test.py b/tests/nn_test.py deleted file mode 100644 index 9a0ceb76f..000000000 --- a/tests/nn_test.py +++ /dev/null @@ -1,749 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for flax.nn.""" - -import threading -from absl.testing import absltest - -from flax.deprecated import nn - -import jax -from jax import random -from jax import test_util as jtu -from jax.nn import initializers -import jax.numpy as jnp - - -import numpy as np - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - - -class DummyModule(nn.Module): - - def apply(self, x): - bias = self.param('bias', x.shape, initializers.ones) - return x + bias - - -class NestedModule(nn.Module): - - def apply(self, x): - x = DummyModule(x, name='dummy_0') - x = DummyModule(x, name='dummy_1') - return x - - -class NestedModel(nn.Module): - - def apply(self, x, model): - x = DummyModule(x, name='dummy_0') - x = model(x, name='inner_model') - return x - - -class DataDependentInitModule(nn.Module): - - def apply(self, x): - bias = self.param('bias', x.shape, lambda rng, shape: x + 1.) - return x + bias - - -class CollectionModule(nn.Module): - - def apply(self, x, activations=None): - bias = self.param('bias', x.shape, initializers.ones) - y = x + bias - if activations: - previous_activation = activations.retrieve() - activations.store(y) - return y, previous_activation - else: - return y, None - - -class LoopModule(nn.Module): - - def apply(self, x, activations=None): - module = CollectionModule.shared(activations=activations, name='dummy') - for _ in range(2): - x, _ = module(x) - return x - - -class ModuleTest(absltest.TestCase): - - def test_init_module(self): - rng = random.PRNGKey(0) - x = jnp.array([1.]) - y, params = DummyModule.init(rng, x) - y2 = DummyModule.call(params, x) - self.assertEqual(y, y2) - self.assertEqual(y, jnp.array([2.])) - self.assertEqual(params, {'bias': jnp.array([1.])}) - - def test_init_by_shape_module(self): - rng = random.PRNGKey(0) - x = jnp.array([1.]) - y, params = DummyModule.init_by_shape(rng, [(x.shape, x.dtype)]) - y2 = DummyModule.call(params, x) - self.assertEqual(y.shape, y2.shape) - self.assertEqual(y2, jnp.array([2.])) - self.assertEqual(params, {'bias': jnp.array([1.])}) - - def test_model(self): - rng = random.PRNGKey(0) - x = jnp.array([1.]) - _, params = DummyModule.init(rng, x) - model = nn.Model(DummyModule, params) - y = model(x) - self.assertEqual(y, jnp.array([2.])) - y2 = jax.jit(model)(x) - self.assertEqual(y2, jnp.array([2.])) - - def test_shared_module(self): - rng = random.PRNGKey(0) - x = jnp.array([1.]) - _, initial_params = LoopModule.init(rng, x) - model = nn.Model(LoopModule, initial_params) - y = model(x) - self.assertEqual(y, jnp.array([3.])) - self.assertEqual(model.params, {'dummy': {'bias': jnp.array([1.])}}) - - def test_name_collsion(self): - class FaultyModule(nn.Module): - - def apply(self, x): - for _ in range(2): - DummyModule(x, name='dummy') - - x = jnp.array([1.]) - with self.assertRaises(ValueError): - FaultyModule.init(random.PRNGKey(0), x) - - def test_sharing_name_collsion(self): - class FaultyModule(nn.Module): - - def apply(self, x): - for _ in range(2): - module = DummyModule.shared(name='dummy') - module(x) - - x = jnp.array([1.]) - with self.assertRaises(ValueError): - FaultyModule.init(random.PRNGKey(0), x) - - def test_sharing_name_on_apply(self): - class FaultyModule(nn.Module): - - def apply(self, x): - module = DummyModule.shared(name='dummy') - for _ in range(2): - module(x, name='dummy2') - - x = jnp.array([1.]) - with self.assertRaises(ValueError): - FaultyModule.init(random.PRNGKey(0), x) - - def test_shared_module_called_in_other_frame(self): - """Test that shared modules only appear once in parameters. - - Concretely, create a shared submodule, then pass it in to - a child module and apply it there. Test that the parameters - are only stored once, in the frame where the shared module - was created. - """ - - class SubModule(nn.Module): - - def apply(self): - self.param('params', (), initializers.zeros) - - class UseSharedModule(nn.Module): - - def apply(self, submodule): - submodule() - - class TopLevel(nn.Module): - - def apply(self): - submodule = SubModule.shared(name='shared') - submodule() - UseSharedModule(submodule, name='use_shared') - - _, params = TopLevel.init(random.PRNGKey(0)) - self.assertEqual({ - 'shared': {'params': jnp.zeros(())}, - 'use_shared': {}, - }, params) - - def test_module_decorator(self): - @nn.module - def MyModule(x): # pylint: disable=invalid-name - return DummyModule(x) - - self.assertEqual(MyModule.__name__, 'MyModule') - self.assertTrue(issubclass(MyModule, nn.Module)) - - rng = random.PRNGKey(0) - x = jnp.array([1.]) - y, params = MyModule.init(rng, x) - y2 = MyModule.call(params, x) - self.assertEqual(y, y2) - self.assertEqual(y, jnp.array([2.])) - - def test_partial_application(self): - rng = random.PRNGKey(0) - x = jnp.array([1.]) - dummy_module = DummyModule.partial(x=x) # partially apply the inputs - self.assertEqual(DummyModule.__name__, dummy_module.__name__) - self.assertEqual(DummyModule.__qualname__, dummy_module.__qualname__) - y, initial_params = dummy_module.init(rng) - model = nn.Model(dummy_module, initial_params) - y2 = model() - self.assertEqual(y.shape, y2.shape) - self.assertEqual(y2, jnp.array([2.])) - - def test_nested_model(self): - x = jnp.array([1.]) - _, inner_initial_params = DummyModule.init(random.PRNGKey(0), x) - inner_model = nn.Model(DummyModule, inner_initial_params) - _, initial_params = NestedModel.init(random.PRNGKey(1), x, inner_model) - model = nn.Model(NestedModel, initial_params) - y = model(x, inner_model) - self.assertEqual(y, jnp.array([3.])) - - def test_capture_module_outputs(self): - x = jnp.array([1.]) - _, initial_params = NestedModule.init(random.PRNGKey(0), x) - model = nn.Model(NestedModule, initial_params) - with nn.capture_module_outputs() as activations: - model(x) - expected_activations = { - '/': [x + 2], - '/dummy_0': [x + 1], - '/dummy_1': [x + 2], - } - self.assertEqual(activations.as_dict(), expected_activations) - - def test_nested_model_capture_outputs(self): - x = jnp.array([1.]) - _, inner_initial_params = DummyModule.init(random.PRNGKey(0), x) - inner_model = nn.Model(DummyModule, inner_initial_params) - _, initial_params = NestedModel.init(random.PRNGKey(1), x, inner_model) - model = nn.Model(NestedModel, initial_params) - with nn.capture_module_outputs() as activations: - model(x, inner_model) - expected_activations = { - '/': [x + 2], - '/dummy_0': [x + 1], - '/inner_model': [x + 2], - } - self.assertEqual(activations.as_dict(), expected_activations) - - def test_truncated_module(self): - x = jnp.array([1.]) - _, initial_params = NestedModule.init(random.PRNGKey(0), x) - model = nn.Model(NestedModule, initial_params) - model = model.truncate_at('/dummy_0') - y = model(x) - self.assertEqual(y, [x + 1]) - - def test_call_module_method(self): - class MultiMethod(nn.Module): - - def apply(self, x): - return x + self.param('bias', x.shape, initializers.ones) - - @nn.module_method - def l2(self): - return jnp.sum(self.get_param('bias') ** 2) - - class MultiMethodModel(nn.Module): - - def apply(self, x): - layer = MultiMethod.shared() - layer(x) # init - return layer.l2() - - self.assertEqual( - MultiMethod.l2.__qualname__, - MultiMethod.__qualname__ + '.l2') - - x = jnp.array([1., 2.]) - - _, params = MultiMethod.init(random.PRNGKey(0), x) - model = nn.Model(MultiMethod, params) - self.assertEqual(model.l2(), 2.) - - y, _ = MultiMethodModel.init(random.PRNGKey(0), x) - self.assertEqual(y, 2.) - - def test_module_state(self): - class StatefulModule(nn.Module): - - def apply(self, x, coll=None): - state = self.state('state', x.shape, nn.initializers.zeros, - collection=coll) - state.value += x - - x = jnp.array([1.,]) - # no collection should raise an error - with self.assertRaises(ValueError): - StatefulModule.call({}, x) - - # pass collection explicitly - with nn.Collection().mutate() as state: - self.assertEqual(state.as_dict(), {}) - StatefulModule.init(random.PRNGKey(0), x, state) - self.assertEqual(state.as_dict(), {'/': {'state': x}}) - self.assertEqual(state.as_dict(), {'/': {'state': x}}) - with state.mutate() as new_state: - # assert new_state is a clone of state - self.assertEqual(new_state.as_dict(), state.as_dict()) - StatefulModule.call({}, x, new_state) - self.assertEqual(new_state.as_dict(), {'/': {'state': x + x}}) - - # use stateful - with nn.stateful() as state: - self.assertEqual(state.as_dict(), {}) - StatefulModule.init(random.PRNGKey(0), x) - self.assertEqual(state.as_dict(), {'/': {'state': x}}) - with nn.stateful(state) as new_state: - # assert new_state is a clone of state - self.assertEqual(new_state.as_dict(), state.as_dict()) - StatefulModule.call({}, x) - self.assertEqual(new_state.as_dict(), {'/': {'state': x + x}}) - self.assertEqual(new_state.as_dict(), {'/': {'state': x + x}}) - - def test_parameter_rng(self): - @nn.module - def model(x): - return nn.Dense(x, features=2, name='dummy', - bias_init=nn.initializers.normal()) - rng = random.PRNGKey(0) - _, params = model.init(rng, jnp.ones((1, 1))) - dense_rng = nn.base._fold_in_str(rng, 'dummy') - kernel_rng = nn.base._fold_in_str(dense_rng, 'kernel') - bias_rng = nn.base._fold_in_str(dense_rng, 'bias') - kernel = nn.linear.default_kernel_init(kernel_rng, (1, 2)) - bias = nn.initializers.normal()(bias_rng, (2,)) - np.testing.assert_allclose(kernel, params['dummy']['kernel']) - np.testing.assert_allclose(bias, params['dummy']['bias']) - -class CollectionTest(absltest.TestCase): - - def test_collection_store_and_retrieve(self): - rng = random.PRNGKey(0) - x = jnp.array([1.]) - with nn.Collection().mutate() as activations: - (_, y), initial_params = CollectionModule.init(rng, x, activations) - model = nn.Model(CollectionModule, initial_params) - self.assertEqual(y, None) - with activations.mutate() as new_activations: - _, y2 = model(x, new_activations) - self.assertEqual(y2, jnp.array([2.])) - - def test_collection_multiple_calls(self): - rng = random.PRNGKey(0) - with nn.Collection().mutate() as activations: - x = jnp.array([1.]) - _, _ = LoopModule.init(rng, x, activations) - expected_state = { - '/dummy': jnp.array([3.]), - } - self.assertEqual(activations.state, expected_state) - - def test_collection_multiple_roots(self): - rng = random.PRNGKey(0) - with nn.Collection().mutate() as activations: - x = jnp.array([1.]) - LoopModule.init(rng, x, activations, name='a') - LoopModule.init(rng, x, activations, name='b') - expected_state = { - '/a/dummy': jnp.array([3.]), - '/b/dummy': jnp.array([3.]), - } - self.assertEqual(activations.state, expected_state) - with self.assertRaises(ValueError): - with nn.Collection().mutate() as activations: - x = jnp.array([1.]) - LoopModule.init(rng, x, activations) - LoopModule.init(rng, x, activations) - - def test_mutable_collection_cannot_be_passed_to_jax(self): - with nn.Collection().mutate() as collection: - def fn(col): - return col - with self.assertRaises(ValueError): - jax.jit(fn)(collection) - - def test_collection_lookup(self): - state = { - '/dummy/sub': 1, - } - collection = nn.Collection(state=state) - root = nn.base._ModuleFrame(None) - frame = nn.base._ModuleFrame('dummy', parent=root) - with nn.base._module_stack.frame(root): - with nn.base._module_stack.frame(frame): - self.assertEqual(collection['/dummy/sub'], 1) - - def test_collection_inside_module(self): - class NestedCollection(nn.Module): - - def apply(self, x): - with nn.Collection().mutate() as activations: - LoopModule(x, activations, name='a') - LoopModule(x, activations, name='b') - return activations - - rng = random.PRNGKey(0) - x = jnp.array([1.]) - activations, _ = NestedCollection.init(rng, x, name='nested') - expected_state = { - '/a/dummy': jnp.array([3.]), - '/b/dummy': jnp.array([3.]), - } - self.assertEqual(activations.as_dict(), expected_state) - - def test_collection_store_fails_if_not_in_module(self): - @nn.module - def test(): - with nn.Collection().mutate() as coll: - coll.store(1) - pattern = 'State should be stored from within a module' - with self.assertRaisesRegex(ValueError, pattern): - test.init(random.PRNGKey(0)) - - def test_collection_store_fails_if_out_of_scope(self): - @nn.module - def stateful_module(coll): - coll.store(1) - - @nn.module - def test_inner(f): - with nn.Collection().mutate() as coll: - # this should fail because f is a shared module defined - # in the parent. Therefore we cannot capture in the scope - # of this Module. - f(coll) - - @nn.module - def test(): - f = stateful_module.shared() - test_inner(f) - pattern = 'Trying to capture state outside the scope' - with self.assertRaisesRegex(ValueError, pattern): - test.init(random.PRNGKey(0)) - - # TODO(jheek): re-introduce this test when the tracer check is revived. - # def test_jax_transform_of_stateful_function(self): - # test = self - # class NestedTransform(nn.Module): - - # def apply(self, state, y): - # def inner_fn(x): - # # constants should be storable - # state.store(1.) - # # values in the same trace should be storable - # state.store({'a': y}) - # with test.assertRaises(ValueError): - # # values depending on the vmap should not be storable - # state.store({'a': y, 'b': x}) - # jax.vmap(inner_fn)(jnp.ones((2,))) - - # def outer_fn(x): - # with nn.Collection().mutate() as state: - # NestedTransform.init(random.PRNGKey(0), state, x) - - # outer_fn(1.) - # jax.jit(outer_fn)(1.) - - -class UtilsTest(absltest.TestCase): - - def test_call_stack_happy_path(self): - stack = nn.utils.CallStack() - self.assertFalse(stack) - with stack.frame({'id': 1}): - self.assertTrue(stack) - self.assertEqual(stack[-1], {'id': 1}) - with stack.frame({'id': 2}): - self.assertEqual(list(stack), [{'id': 1}, {'id': 2}]) - self.assertEqual(list(stack), [{'id': 1}]) - - def test_call_stack_multithreading(self): - stack = nn.utils.CallStack() - self.assertFalse(stack) - with stack.frame({'id': 1}): - self.assertEqual(stack[-1], {'id': 1}) - def _main(): - # Each thread should have its own stack. - self.assertFalse(stack) - with stack.frame({'id': 2}): - self.assertEqual(stack[-1], {'id': 2}) - thread = threading.Thread(target=_main) - thread.start() - thread.join() - - def test_call_stack_error_path(self): - stack = nn.utils.CallStack() - with stack.frame({'id': 1}): - with self.assertRaises(ValueError): - with stack.frame({'id': 2}): - raise ValueError('dummy') - self.assertEqual(list(stack), [{'id': 1}]) - - -class PoolTest(absltest.TestCase): - - def test_pool_custom_reduce(self): - x = jnp.full((1, 3, 3, 1), 2.) - mul_reduce = lambda x, y: x * y - y = nn.pooling.pool(x, 1., mul_reduce, (2, 2), (1, 1), 'VALID') - np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2. ** 4)) - - def test_avg_pool(self): - x = jnp.full((1, 3, 3, 1), 2.) - pool = lambda x: nn.avg_pool(x, (2, 2)) - y = pool(x) - np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.)) - y_grad = jax.grad(lambda x: pool(x).sum())(x) - expected_grad = jnp.array([ - [0.25, 0.5, 0.25], - [0.5, 1., 0.5], - [0.25, 0.5, 0.25], - ]).reshape((1, 3, 3, 1)) - np.testing.assert_allclose(y_grad, expected_grad) - - def test_max_pool(self): - x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32) - pool = lambda x: nn.max_pool(x, (2, 2)) - expected_y = jnp.array([ - [4., 5.], - [7., 8.], - ]).reshape((1, 2, 2, 1)) - y = pool(x) - np.testing.assert_allclose(y, expected_y) - y_grad = jax.grad(lambda x: pool(x).sum())(x) - expected_grad = jnp.array([ - [0., 0., 0.], - [0., 1., 1.], - [0., 1., 1.], - ]).reshape((1, 3, 3, 1)) - np.testing.assert_allclose(y_grad, expected_grad) - - def test_max_pool_explicit_pads(self): - x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32) - pool = lambda x: nn.max_pool(x, (2, 2), padding=((1,1),(1,1))) - expected_y = jnp.array([ - [0.,1.,2.,2.], - [3.,4.,5.,5.], - [6.,7.,8.,8.], - [6.,7.,8.,8.], - ]).reshape((1, 4, 4, 1)) - y = pool(x) - np.testing.assert_allclose(y, expected_y) - y_grad = jax.grad(lambda x: pool(x).sum())(x) - expected_grad = jnp.array([ - [1., 1., 2.], - [1., 1., 2.], - [2., 2., 4.], - ]).reshape((1, 3, 3, 1)) - np.testing.assert_allclose(y_grad, expected_grad) - -class NormalizationTest(absltest.TestCase): - - def test_batch_norm(self): - rng = random.PRNGKey(0) - key1, key2 = random.split(rng) - x = random.normal(key1, (4, 3, 2)) - model_cls = nn.BatchNorm.partial(momentum=0.9) - with nn.stateful() as state_0: - y, initial_params = model_cls.init(key2, x) - model = nn.Model(model_cls, initial_params) - mean = y.mean((0, 1)) - var = y.var((0, 1)) - np.testing.assert_allclose(mean, np.array([0., 0.]), atol=1e-4) - np.testing.assert_allclose(var, np.array([1., 1.]), rtol=1e-4) - with nn.stateful(state_0) as state: - y = model(x) - ema = state['/'] - np.testing.assert_allclose( - ema['mean'], 0.1 * x.mean((0, 1), keepdims=False), atol=1e-4) - np.testing.assert_allclose( - ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False), rtol=1e-4) - - def test_layer_norm(self): - rng = random.PRNGKey(0) - key1, key2 = random.split(rng) - e = 1e-5 - x = random.normal(key1, (2, 3, 4)) - y, _ = nn.LayerNorm.init(key2, x, bias=False, scale=False, epsilon=e) - assert x.shape == y.shape - input_type = type(x) - assert isinstance(y, input_type) - y_one_liner = ((x - x.mean(axis=-1, keepdims=True)) * - jax.lax.rsqrt(x.var(axis=-1, keepdims=True) + e)) - np.testing.assert_allclose(y_one_liner, y, atol=1e-4) - - def test_group_norm(self): - rng = random.PRNGKey(0) - key1, key2 = random.split(rng) - e = 1e-5 - x = random.normal(key1, (2, 5, 4, 4, 32)) - y, _ = nn.GroupNorm.init(key2, x, num_groups=2, - bias=True, scale=True, epsilon=e) - self.assertEqual(x.shape, y.shape) - self.assertIsInstance(y, type(x)) - - x_gr = x.reshape([2, 5, 4, 4, 2, 16]) - y_test = ((x_gr - x_gr.mean(axis=[1, 2, 3, 5], keepdims=True)) * - jax.lax.rsqrt(x_gr.var(axis=[1, 2, 3, 5], keepdims=True) + e)) - y_test = y_test.reshape([2, 5, 4, 4, 32]) - - np.testing.assert_allclose(y_test, y, atol=1e-4) - - -# TODO(flax-dev): add integration tests for RNN cells -class RecurrentTest(absltest.TestCase): - - def test_lstm(self): - rng = random.PRNGKey(0) - key1, key2 = random.split(rng) - x = random.normal(key1, (2, 3)) - c0, h0 = nn.LSTMCell.initialize_carry(rng, (2,), 4) - self.assertEqual(c0.shape, (2, 4)) - self.assertEqual(h0.shape, (2, 4)) - (carry, y), initial_params = nn.LSTMCell.init(key2, (c0, h0), x) - lstm = nn.Model(nn.LSTMCell, initial_params) - self.assertEqual(carry[0].shape, (2, 4)) - self.assertEqual(carry[1].shape, (2, 4)) - np.testing.assert_allclose(y, carry[1]) - param_shapes = jax.tree_map(np.shape, lstm.params) - self.assertEqual(param_shapes, { - 'ii': {'kernel': (3, 4)}, - 'if': {'kernel': (3, 4)}, - 'ig': {'kernel': (3, 4)}, - 'io': {'kernel': (3, 4)}, - 'hi': {'kernel': (4, 4), 'bias': (4,)}, - 'hf': {'kernel': (4, 4), 'bias': (4,)}, - 'hg': {'kernel': (4, 4), 'bias': (4,)}, - 'ho': {'kernel': (4, 4), 'bias': (4,)}, - }) - - def test_gru(self): - rng = random.PRNGKey(0) - key1, key2 = random.split(rng) - x = random.normal(key1, (2, 3)) - carry0 = nn.GRUCell.initialize_carry(rng, (2,), 4) - self.assertEqual(carry0.shape, (2, 4)) - (carry, y), initial_params = nn.GRUCell.init(key2, carry0, x) - gru = nn.Model(nn.GRUCell, initial_params) - self.assertEqual(carry.shape, (2, 4)) - np.testing.assert_allclose(y, carry) - param_shapes = jax.tree_map(np.shape, gru.params) - self.assertEqual(param_shapes, { - 'ir': {'kernel': (3, 4), 'bias': (4,)}, - 'iz': {'kernel': (3, 4), 'bias': (4,)}, - 'in': {'kernel': (3, 4), 'bias': (4,)}, - 'hr': {'kernel': (4, 4)}, - 'hz': {'kernel': (4, 4)}, - 'hn': {'kernel': (4, 4), 'bias': (4,)}, - }) - - def test_conv2dlstm(self): - rng = random.PRNGKey(0) - key1, key2 = random.split(rng) - x = random.normal(key1, (2, 4, 4, 3)) - c0, h0 = nn.ConvLSTM.initialize_carry(rng, (2,), (4, 4, 6)) - self.assertEqual(c0.shape, (2, 4, 4, 6)) - self.assertEqual(h0.shape, (2, 4, 4, 6)) - (carry, y), initial_params = nn.ConvLSTM.init( - key2, (c0, h0), x, features=6, kernel_size=(3, 3)) - lstm = nn.Model(nn.ConvLSTM, initial_params) - self.assertEqual(carry[0].shape, (2, 4, 4, 6)) - self.assertEqual(carry[1].shape, (2, 4, 4, 6)) - np.testing.assert_allclose(y, carry[1]) - param_shapes = jax.tree_map(np.shape, lstm.params) - self.assertEqual(param_shapes, { - 'hh': {'bias': (6*4,), 'kernel': (3, 3, 6, 6*4)}, - 'ih': {'bias': (6*4,), 'kernel': (3, 3, 3, 6*4)}, - }) - - def test_optimized_lstm_cell_matches_regular(self): - - # Create regular LSTMCell. - rng = random.PRNGKey(0) - key1, key2 = random.split(rng) - x = random.normal(key1, (2, 3)) - c0, h0 = nn.LSTMCell.initialize_carry(rng, (2,), 4) - self.assertEqual(c0.shape, (2, 4)) - self.assertEqual(h0.shape, (2, 4)) - (carry, y), initial_params = nn.LSTMCell.init(key2, (c0, h0), x) - lstm = nn.Model(nn.LSTMCell, initial_params) - - # Create OptimizedLSTMCell. - rng = random.PRNGKey(0) - key1, key2 = random.split(rng) - x = random.normal(key1, (2, 3)) - c0, h0 = nn.OptimizedLSTMCell.initialize_carry(rng, (2,), 4) - self.assertEqual(c0.shape, (2, 4)) - self.assertEqual(h0.shape, (2, 4)) - (carry, y_opt), initial_params = nn.OptimizedLSTMCell.partial( - name='LSTMCell').init(key2, (c0, h0), x) - lstm_opt = nn.Model(nn.OptimizedLSTMCell.partial(name='LSTMCell'), - initial_params) - - np.testing.assert_allclose(y, y_opt, rtol=1e-6) - jtu.check_eq(lstm.params, lstm_opt.params) - - -class StochasticTest(absltest.TestCase): - - def test_make_rng_requires_stochastic(self): - with self.assertRaises(ValueError): - nn.make_rng() - - def test_stochastic_rngs(self): - rng = random.PRNGKey(0) - with nn.stochastic(rng): - r1 = nn.make_rng() - r2 = nn.make_rng() - self.assertTrue(np.all(r1 == random.fold_in(rng, 1))) - self.assertTrue(np.all(r2 == random.fold_in(rng, 2))) - - # TODO(jheek): re-introduce this test when the tracer check is revived. - # def test_make_rng_in_jax_transform_check(self): - # with nn.stochastic(random.PRNGKey(0)): - # with self.assertRaises(ValueError): - # jax.jit(nn.make_rng)() - - def test_init_by_shape_lifts_stochastic(self): - class StochasticModule(nn.Module): - def apply(self): - return nn.make_rng() - - with nn.stochastic(random.PRNGKey(0)): - rng, _ = StochasticModule.init_by_shape(random.PRNGKey(1), []) - expected_rng = random.fold_in(random.PRNGKey(0), 1) - expected_rng = random.fold_in(expected_rng, 1) - self.assertTrue(np.all(rng == expected_rng)) - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/optim_test.py b/tests/optim_test.py deleted file mode 100644 index 150f4ac4a..000000000 --- a/tests/optim_test.py +++ /dev/null @@ -1,566 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for flax.optim.""" - -from functools import partial -from absl.testing import absltest -from flax import optim -from flax import traverse_util -from flax.core.frozen_dict import FrozenDict -from flax.deprecated import nn -from flax.optim.adabelief import _AdaBeliefHyperParams, _AdaBeliefParamState -from flax.optim.adadelta import _AdadeltaHyperParams, _AdadeltaParamState -from flax.optim.adafactor import _AdafactorHyperParams, _AdafactorParamState -from flax.optim.adagrad import _AdagradHyperParams, _AdagradParamState -from flax.optim.adam import _AdamHyperParams, _AdamParamState -from flax.optim.momentum import _MomentumHyperParams, _MomentumParamState -from flax.optim.rmsprop import _RMSPropHyperParams, _RMSPropParamState -from flax.optim.sgd import _GradientDescentHyperParams -from flax.optim.weight_norm import _WeightNormParamState -import jax -import jax.numpy as jnp -import numpy as np -import pytest - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - - -def _assert_numpy_allclose(a, b, atol=None, rtol=None): - a, b = jnp.array(a), jnp.array(b) - a = a.astype(np.float32) if a.dtype == jnp.bfloat16 else a - b = b.astype(np.float32) if b.dtype == jnp.bfloat16 else b - kw = {} - if atol: kw["atol"] = atol - if rtol: kw["rtol"] = rtol - np.testing.assert_allclose(a, b, **kw) - - -def check_eq(xs, ys, atol=None, rtol=None): - xs_leaves, xs_tree = jax.tree_flatten(xs) - ys_leaves, ys_tree = jax.tree_flatten(ys) - assert xs_tree == ys_tree, "Tree shapes don't match." - assert jax.tree_util.tree_all(jax.tree_multimap( - lambda x, y: np.array(x).shape == np.array(y).shape, - xs_leaves, ys_leaves)), "Leaves' shapes don't match." - assert jax.tree_multimap( - partial(_assert_numpy_allclose, atol=atol, rtol=rtol), - xs_leaves, ys_leaves) - - -class OptimizerDefTest(absltest.TestCase): - - def test_create(self): - params = np.ones((1,)) - optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2) - optimizer = optimizer_def.create(params) - expected_state = optim.OptimizerState( - 0, _MomentumParamState(np.zeros((1,)))) - self.assertEqual(optimizer.optimizer_def, optimizer_def) - self.assertEqual(optimizer.state, expected_state) - self.assertEqual(optimizer.target, params) - - @pytest.mark.filterwarnings("ignore: compute_gradient()") - def test_compute_grad(self): - params = np.ones(()) - optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2) - optimizer = optimizer_def.create(params) - def loss_fn(x): - return 2. * x - loss, grad = optimizer.compute_gradient(loss_fn) - self.assertEqual(loss, 2.) - self.assertEqual(grad, 2.) - - def loss_aux_fn(x): - return 3. * x, 4. - loss, aux, grad = optimizer.compute_gradient(loss_aux_fn) - self.assertEqual(loss, 3.) - self.assertEqual(grad, 3.) - self.assertEqual(aux, 4.) - - def test_optimizer_with_focus(self): - params = {'a': 0., 'b': 0.} - opt_def = optim.GradientDescent(learning_rate=1.) - t_a = traverse_util.t_identity['a'] - optimizer = opt_def.create(params, focus=t_a) - expected_state = optim.OptimizerState(0, {'a': (), 'b': None}) - self.assertEqual(optimizer.state, expected_state) - grads = {'a': -1., 'b': -2.} - new_optimizer = optimizer.apply_gradient(grads) - expected_params = {'a': 1., 'b': 0.} - expected_state = optim.OptimizerState(1, {'a': (), 'b': None}) - self.assertEqual(new_optimizer.state, expected_state) - self.assertEqual(new_optimizer.target, expected_params) - - def test_empty_optimizer(self): - params = {} - optimizer_def = optim.Momentum(learning_rate=0.1) - optimizer = optimizer_def.create(params) - new_optimizer = optimizer.apply_gradient({}) - expected_state = optim.OptimizerState(1, {}) - self.assertEqual(new_optimizer.state, expected_state) - - -class MultiOptimizerTest(absltest.TestCase): - - def test_multi_optimizer(self): - params = {'a': 0., 'b': 0., 'c': {}} - opt_a = optim.GradientDescent(learning_rate=1.) - opt_b = optim.GradientDescent(learning_rate=10.) - t_a = traverse_util.t_identity['a'] - t_b = traverse_util.t_identity['b'] - optimizer_def = optim.MultiOptimizer((t_a, opt_a), (t_b, opt_b)) - state = optimizer_def.init_state(params) - expected_hyper_params = [ - _GradientDescentHyperParams(1.), - _GradientDescentHyperParams(10.) - ] - self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) - expected_state = optim.OptimizerState(0, {'a': (), 'b': (), 'c': {}}) - self.assertEqual(state, expected_state) - grads = {'a': -1., 'b': -2., 'c': {}} - new_params, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads) - expected_params = {'a': 1., 'b': 20., 'c': {}} - expected_state = optim.OptimizerState(1, {'a': (), 'b': (), 'c': {}}) - self.assertEqual(new_state, expected_state) - self.assertEqual(new_params, expected_params) - # override learning_rate - hp = optimizer_def.update_hyper_params(learning_rate=2.) - new_params, new_state = optimizer_def.apply_gradient( - hp, params, state, grads) - expected_params = {'a': 2., 'b': 4., 'c': {}} - self.assertEqual(new_params, expected_params) - - def test_multi_optimizer_multiple_matches(self): - params = {'a': {'x': 0., 'y': 0.}, 'b': {'y': 0, 'z': 0.}} - opt_a = optim.GradientDescent(learning_rate=1.) - opt_b = optim.GradientDescent(learning_rate=10.) - t_a = traverse_util.ModelParamTraversal( - lambda path, _: path.endswith('/x') or path.endswith('/y') - ) - t_b = traverse_util.ModelParamTraversal( - lambda path, value: value.dtype == jnp.int32 or path.endswith('/z') - ) - optimizer_def = optim.MultiOptimizer((t_a, opt_a), (t_b, opt_b)) - with self.assertRaisesRegex( - ValueError, r"Multiple optimizers match.*'y': \[0, 1\]"): - jax.jit(optimizer_def.init_state)(params) - - -class GradientDescentTest(absltest.TestCase): - - def test_init_state(self): - params = np.zeros((1,)) - optimizer_def = optim.GradientDescent(learning_rate=0.1) - state = optimizer_def.init_state(params) - expected_hyper_params = _GradientDescentHyperParams(0.1) - self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) - expected_state = optim.OptimizerState(0, ()) - self.assertEqual(state, expected_state) - - def test_apply_gradient(self): - optimizer_def = optim.GradientDescent(learning_rate=0.1) - params = np.ones((1,)) - state = optim.OptimizerState(0, ()) - grads = np.array([3.]) - new_params, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads) - expected_new_state = optim.OptimizerState(1, ()) - expected_new_params = np.array([0.7]) - self.assertEqual(new_params, expected_new_params) - self.assertEqual(new_state, expected_new_state) - - -class MomentumTest(absltest.TestCase): - - def test_init_state(self): - params = np.zeros((1,)) - optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2) - state = optimizer_def.init_state(params) - expected_hyper_params = _MomentumHyperParams(0.1, 0.2, 0, False) - self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) - expected_state = optim.OptimizerState( - 0, _MomentumParamState(np.zeros((1,)))) - self.assertEqual(state, expected_state) - - def test_apply_gradient(self): - optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.2) - params = np.ones((1,)) - state = optim.OptimizerState( - 0, _MomentumParamState(np.array([1.]))) - grads = np.array([3.]) - new_params, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads) - expected_new_state = optim.OptimizerState( - 1, _MomentumParamState(np.array([3.2]))) - expected_new_params = np.array([1. - 0.32]) - self.assertEqual(new_params, expected_new_params) - self.assertEqual(new_state, expected_new_state) - - -class AdamTest(absltest.TestCase): - - def test_init_state(self): - params = np.zeros((1,)) - optimizer_def = optim.Adam(learning_rate=0.1, - beta1=0.2, - beta2=0.9, - eps=0.01, - weight_decay=0.0) - state = optimizer_def.init_state(params) - - expected_hyper_params = _AdamHyperParams(0.1, 0.2, 0.9, 0.01, 0.0) - self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) - expected_state = optim.OptimizerState( - 0, _AdamParamState(np.zeros((1,)), np.zeros((1,)))) - self.assertEqual(state, expected_state) - - def test_apply_gradient(self): - optimizer_def = optim.Adam(learning_rate=0.1, - beta1=0.2, - beta2=0.9, - eps=0.01, - weight_decay=0.0) - params = np.array([1.]) - state = optim.OptimizerState( - 1, _AdamParamState(np.array([0.1]), np.array([0.9]))) - grads = np.array([4.]) - new_params, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads) - expected_new_state = optim.OptimizerState( - 2, _AdamParamState(np.array([3.22]), np.array([2.41]))) - expected_new_params = np.array([0.906085]) - np.testing.assert_allclose(new_params, expected_new_params) - self.assertEqual(new_state, expected_new_state) - - -class AdaBeliefTest(absltest.TestCase): - - def test_init_state(self): - params = np.zeros((1,)) - optimizer_def = optim.AdaBelief( - learning_rate=0.1, beta1=0.2, beta2=0.9, eps=0.01, weight_decay=0.0) - state = optimizer_def.init_state(params) - - expected_hyper_params = _AdaBeliefHyperParams(0.1, 0.2, 0.9, 0.01, 0.0) - self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) - expected_state = optim.OptimizerState( - 0, _AdaBeliefParamState(np.zeros((1,)), np.zeros((1,)))) - self.assertEqual(state, expected_state) - - def test_apply_gradient(self): - optimizer_def = optim.AdaBelief( - learning_rate=0.1, beta1=0.2, beta2=0.9, eps=0.01, weight_decay=0.0) - params = np.array([1.]) - state = optim.OptimizerState( - 1, _AdaBeliefParamState(np.array([0.1]), np.array([0.9]))) - grads = np.array([4.]) - new_params, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads) - expected_new_state = optim.OptimizerState( - 2, _AdaBeliefParamState(np.array([3.22]), np.array([0.88084]))) - expected_new_params = np.array([0.8449397]) - np.testing.assert_allclose(new_params, expected_new_params) - self.assertEqual(new_state, expected_new_state) - - -class AdadeltaTest(absltest.TestCase): - - def test_init_state(self): - params = np.zeros((1,)) - optimizer_def = optim.Adadelta(learning_rate=0.1, - rho=0.9, - eps=1e-6, - weight_decay=0.1) - state = optimizer_def.init_state(params) - - expected_hyper_params = _AdadeltaHyperParams(0.1, 0.9, 1e-6, 0.1) - self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) - expected_state = optim.OptimizerState( - 0, _AdadeltaParamState(np.zeros((1,)), np.zeros((1,))) - ) - self.assertEqual(state, expected_state) - - def test_apply_gradient(self): - optimizer_def = optim.Adadelta(learning_rate=0.1, - rho=0.9, - eps=1e-6, - weight_decay=0.1) - params = np.array([1.]) - state = optim.OptimizerState( - 1, _AdadeltaParamState(np.zeros((1,)), np.zeros((1,))) - ) - grads = np.array([1.]) - new_param, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads - ) - expected_new_state = optim.OptimizerState( - 2, _AdadeltaParamState(np.array([0.1]), np.array([9.999902e-7])) - ) - expected_new_params = np.array([0.9896838]) - np.testing.assert_allclose(new_param, expected_new_params) - self.assertEqual(new_state, expected_new_state) - - -class AdafactorTest(absltest.TestCase): - - def test_init_state(self): - params = np.zeros((3, 2)) - optimizer_def = optim.Adafactor(learning_rate=0.1, - decay_rate=0.8, - beta1=None, - min_dim_size_to_factor=0) - state = optimizer_def.init_state(params) - - expected_hyper_params = _AdafactorHyperParams(0.1, True, True, - None, 0.8, 0, 1.0, None, 0, - 1e-30, 1e-3) - self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) - expected_state = optim.OptimizerState( - 0, _AdafactorParamState(np.zeros((2,)), np.zeros((3,)), - np.zeros((1,)), np.zeros((1,)))) - check_eq(state, expected_state) - - # unfactorized - optimizer_def = optim.Adafactor(learning_rate=0.1, - decay_rate=0.8, - beta1=0.0, - min_dim_size_to_factor=32) - state = optimizer_def.init_state(params) - - expected_hyper_params = _AdafactorHyperParams(0.1, True, True, - 0.0, 0.8, 0, 1.0, None, 32, - 1e-30, 1e-3) - self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) - expected_state = optim.OptimizerState( - 0, _AdafactorParamState(np.zeros((1,)), np.zeros((1,)), - np.zeros((3, 2)), np.zeros((3, 2)))) - check_eq(state, expected_state) - - def test_apply_gradient(self): - optimizer_def = optim.Adafactor(learning_rate=0.1, decay_rate=0.8, - min_dim_size_to_factor=0) - params = np.ones((3, 2), np.float32) - state = optim.OptimizerState( - 1, _AdafactorParamState(np.array([0.9, 0.9]), - np.array([0.1, 0.1, 0.1]), - np.zeros((1,)), - np.zeros((1,)))) - grads = np.ones((3, 2), np.float32) - new_params, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads) - expected_new_state = optim.OptimizerState( - 2, _AdafactorParamState( - np.array([0.9574349, 0.9574349]), - np.array([0.6169143, 0.6169143, 0.6169143]), - np.zeros((1,)), - np.zeros((1,)))) - expected_new_params = 0.9 * np.ones((3, 2)) - np.testing.assert_allclose(new_params, expected_new_params) - check_eq(new_state, expected_new_state, rtol=1e-6) - - # unfactored w momentum - optimizer_def = optim.Adafactor(learning_rate=0.1, - beta1=0.0, decay_rate=0.8, - min_dim_size_to_factor=32) - params = np.ones((3, 2), np.float32) - state = optim.OptimizerState( - 1, _AdafactorParamState(np.zeros(1,), - np.zeros(1,), - 0.5*np.ones((3, 2)), - np.zeros((3, 2)))) - grads = np.ones((3, 2), np.float32) - new_params, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads) - expected_new_params = 0.9 * np.ones((3, 2)) - np.testing.assert_allclose(new_params, expected_new_params) - expected_new_state = optim.OptimizerState( - 2, _AdafactorParamState( - np.array([0.0]), - np.array([0.0]), - 0.787174 * np.ones((3, 2)), - 0.1 * np.ones((3,2)))) - check_eq(new_state, expected_new_state, rtol=1e-6) - - def test_factorizes(self): - params = np.zeros((64, 64)) - optimizer_def = optim.Adafactor(learning_rate=0.1, - decay_rate=0.8, - beta1=None, - min_dim_size_to_factor=32) - state = optimizer_def.init_state(params) - self.assertEqual(state.param_states.v.shape, (1,)) - self.assertEqual(state.param_states.m.shape, (1,)) - self.assertEqual(state.param_states.v_row.shape, (64,)) - self.assertEqual(state.param_states.v_col.shape, (64,)) - - params = np.zeros((31, 64)) - optimizer_def = optim.Adafactor(learning_rate=0.1, - decay_rate=0.8, - beta1=None, - min_dim_size_to_factor=32) - state = optimizer_def.init_state(params) - self.assertEqual(state.param_states.v.shape, (31, 64)) - self.assertEqual(state.param_states.m.shape, (1,)) - self.assertEqual(state.param_states.v_row.shape, (1,)) - self.assertEqual(state.param_states.v_col.shape, (1,)) - - -class AdagradTest(absltest.TestCase): - - def test_init_state(self): - params = np.zeros((1,)) - optimizer_def = optim.Adagrad(learning_rate=0.1, eps=0.01) - state = optimizer_def.init_state(params) - - expected_hyper_params = _AdagradHyperParams(0.1, 0.01) - self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) - expected_state = optim.OptimizerState( - 0, _AdagradParamState(np.zeros((1,)))) - self.assertEqual(state, expected_state) - - def test_apply_gradient(self): - optimizer_def = optim.Adagrad(learning_rate=0.1, eps=0.01) - params = np.array([1.]) - state = optim.OptimizerState( - 1, _AdagradParamState(np.array([0.1]))) - grads = np.array([4.]) - new_params, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads) - expected_new_state = optim.OptimizerState( - 2, _AdagradParamState(np.array([16.1]))) - expected_new_params = np.array([0.9005588]) - np.testing.assert_allclose(new_params, expected_new_params) - self.assertEqual(new_state, expected_new_state) - - -class RMSPropTest(absltest.TestCase): - - def test_init_state(self): - params = np.zeros((1,)) - optimizer_def = optim.RMSProp(learning_rate=0.1, - beta2=0.9, - eps=0.01, - centered=False) - state = optimizer_def.init_state(params) - - expected_hyper_params = _RMSPropHyperParams(0.1, 0.9, 0.01, False) - self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) - expected_state = optim.OptimizerState( - 0, _RMSPropParamState(np.zeros((1,)), None)) - self.assertEqual(state, expected_state) - - def test_init_state_centered(self): - params = np.zeros((1,)) - optimizer_def = optim.RMSProp(learning_rate=0.1, - beta2=0.9, - eps=0.01, - centered=True) - state = optimizer_def.init_state(params) - - expected_hyper_params = _RMSPropHyperParams(0.1, 0.9, 0.01, True) - self.assertEqual(optimizer_def.hyper_params, expected_hyper_params) - expected_state = optim.OptimizerState( - 0, _RMSPropParamState(np.zeros((1,)), np.zeros((1,)))) - self.assertEqual(state, expected_state) - - def test_apply_gradient(self): - optimizer_def = optim.RMSProp(learning_rate=0.1, - beta2=0.9, - eps=0.01) - params = np.array([1.]) - state = optim.OptimizerState( - 1, _RMSPropParamState(np.array([0.1]), None)) - grads = np.array([4.]) - new_params, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads) - expected_new_state = optim.OptimizerState( - 2, _RMSPropParamState(np.array([1.69]), None)) - expected_new_params = np.array([0.6946565]) - np.testing.assert_allclose(new_params, expected_new_params) - self.assertEqual(new_state, expected_new_state) - - def test_apply_gradient_centered(self): - optimizer_def = optim.RMSProp(learning_rate=0.1, - beta2=0.9, - eps=0.01, - centered=True) - params = np.array([1.]) - state = optim.OptimizerState( - 1, _RMSPropParamState(np.array([0.1]), np.array([0.1]))) - grads = np.array([4.]) - new_params, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads) - expected_new_state = optim.OptimizerState( - 2, _RMSPropParamState(np.array([1.69]), np.array([0.49]))) - expected_new_params = np.array([0.670543], dtype=np.float32) - np.testing.assert_allclose(new_params, expected_new_params, rtol=1e-6) - np.testing.assert_allclose(new_state.param_states.v, - expected_new_state.param_states.v) - np.testing.assert_allclose(new_state.param_states.mg, - expected_new_state.param_states.mg) - - -class WeightNormTest(absltest.TestCase): - - def test_momentum_with_weight_norm(self): - params = np.ones((2, 2)) * 2. - optimizer_def = optim.WeightNorm(optim.Momentum(0.1)) - state = optimizer_def.init_state(params) - self.assertEqual(jax.tree_map(np.shape, state), optim.OptimizerState( - step=(), - param_states=_WeightNormParamState( - direction_state=_MomentumParamState(momentum=(2, 2)), - scale_state=_MomentumParamState(momentum=(1, 2)), - direction=(2, 2), - scale=(1, 2), - ) - )) - grads = np.ones((2, 2)) - new_params, new_state = optimizer_def.apply_gradient( - optimizer_def.hyper_params, params, state, grads) - np.testing.assert_allclose(new_params, np.full_like(params, 1.9)) - np.testing.assert_allclose(new_state.param_states.direction, np.full_like(params, 2 ** -0.5)) - np.testing.assert_allclose(new_state.param_states.scale, np.full((1, 2), (2 * 1.9 ** 2) ** 0.5)) - - -class DynamicScaleTest(absltest.TestCase): - - def test_dynamic_scale(self): - def loss_fn(p): - return jnp.asarray(p, jnp.float16) ** 2 - p = jnp.array(1., jnp.float32) - - dyn_scale = optim.DynamicScale(growth_interval=2) - step = jax.jit(lambda ds, p: ds.value_and_grad(loss_fn)(p)) - inf = float('inf') - nan = float('nan') - expected_values = [ - (False, nan, 32768.0, inf), - (False, 1.0, 16384.0, inf), - (True, 1.0, 16384.0, 2.0), - (True, 1.0, 16384.0, 2.0), - (True, 1.0, 32768.0, 2.0), - (False, 1.0, 16384.0, inf), - ] - - for expected in expected_values: - dyn_scale, is_fin, loss, grad = step(dyn_scale, p) - values = np.array((is_fin, loss, dyn_scale.scale, grad)) - np.testing.assert_allclose(values, expected) - -if __name__ == '__main__': - absltest.main() diff --git a/tests/serialization_test.py b/tests/serialization_test.py deleted file mode 100644 index f2a6cb424..000000000 --- a/tests/serialization_test.py +++ /dev/null @@ -1,322 +0,0 @@ -# Copyright 2022 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for flax.struct.""" - -import collections - -from typing import Any - -from absl.testing import absltest -from flax import optim -from flax import serialization -from flax import struct -from flax.deprecated import nn -import jax -from jax import random -import jax.numpy as jnp -import msgpack - -import numpy as np - -# Parse absl flags test_srcdir and test_tmpdir. -jax.config.parse_flags_with_absl() - - -@struct.dataclass -class Point: - x: float - y: float - meta: Any = struct.field(pytree_node=False) - - -class SerializationTest(absltest.TestCase): - - def test_dataclass_serialization(self): - p = Point(x=1, y=2, meta={'dummy': True}) - state_dict = serialization.to_state_dict(p) - self.assertEqual(state_dict, { - 'x': 1, - 'y': 2, - }) - restored_p = serialization.from_state_dict(p, {'x': 3, 'y': 4}) - expected_p = Point(x=3, y=4, meta={'dummy': True}) - self.assertEqual(restored_p, expected_p) - - with self.assertRaises(ValueError): # invalid field - serialization.from_state_dict(p, {'z': 3}) - with self.assertRaises(ValueError): # missing field - serialization.from_state_dict(p, {'x': 3}) - - def test_model_serialization(self): - rng = random.PRNGKey(0) - module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones) - _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)]) - model = nn.Model(module, initial_params) - state = serialization.to_state_dict(model) - self.assertEqual(state, { - 'params': { - 'kernel': np.ones((1, 1)), - 'bias': np.zeros((1,)), - } - }) - state = { - 'params': { - 'kernel': np.zeros((1, 1)), - 'bias': np.zeros((1,)), - } - } - restored_model = serialization.from_state_dict(model, state) - self.assertEqual(restored_model.params, state['params']) - - def test_optimizer_serialization(self): - rng = random.PRNGKey(0) - module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones) - _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)]) - model = nn.Model(module, initial_params) - optim_def = optim.Momentum(learning_rate=1.) - optimizer = optim_def.create(model) - state = serialization.to_state_dict(optimizer) - expected_state = { - 'target': { - 'params': { - 'kernel': np.ones((1, 1)), - 'bias': np.zeros((1,)), - } - }, - 'state': { - 'step': 0, - 'param_states': { - 'params': { - 'kernel': {'momentum': np.zeros((1, 1))}, - 'bias': {'momentum': np.zeros((1,))}, - } - } - }, - } - self.assertEqual(state, expected_state) - state = jax.tree_map(lambda x: x + 1, expected_state) - restored_optimizer = serialization.from_state_dict(optimizer, state) - optimizer_plus1 = jax.tree_map(lambda x: x + 1, optimizer) - self.assertEqual(restored_optimizer, optimizer_plus1) - - def test_collection_serialization(self): - - @struct.dataclass - class DummyDataClass: - x: float - - @classmethod - def initializer(cls, key, shape): - del shape, key - return cls(x=0.) - - class StatefulModule(nn.Module): - - def apply(self): - state = self.state('state', (), DummyDataClass.initializer) - state.value = state.value.replace(x=state.value.x + 1.) - - # use stateful - with nn.stateful() as state: - self.assertEqual(state.as_dict(), {}) - StatefulModule.init(random.PRNGKey(0)) - self.assertEqual(state.as_dict(), {'/': {'state': DummyDataClass(x=1.)}}) - with nn.stateful(state) as new_state: - StatefulModule.call({}) - self.assertEqual(new_state.as_dict(), - {'/': { - 'state': DummyDataClass(x=2.) - }}) - serialized_state_dict = serialization.to_state_dict(new_state) - self.assertEqual(serialized_state_dict, {'/': {'state': {'x': 2.}}}) - deserialized_state = serialization.from_state_dict(state, - serialized_state_dict) - self.assertEqual(state.as_dict(), {'/': {'state': DummyDataClass(x=1.)}}) - self.assertEqual(new_state.as_dict(), deserialized_state.as_dict()) - - def test_numpy_serialization(self): - normal_dtypes = ['byte', 'b', 'ubyte', 'short', - 'h', 'ushort', 'i', 'uint', 'intp', - 'p', 'uintp', 'long', 'l', 'longlong', - 'q', 'ulonglong', 'half', 'e', 'f', - 'double', 'd', 'longdouble', 'g', - 'cfloat', 'cdouble', 'clongdouble', 'm', - 'bool8', 'b1', 'int64', 'i8', 'uint64', 'u8', - 'float16', 'f2', 'float32', 'f4', 'float64', - 'f8', 'float128', 'f16', 'complex64', 'c8', - 'complex128', 'c16', 'complex256', 'c32', - 'm8', 'int32', 'i4', 'uint32', 'u4', 'int16', - 'i2', 'uint16', 'u2', 'int8', 'i1', 'uint8', - 'u1', 'complex_', 'int0', 'uint0', 'single', - 'csingle', 'singlecomplex', 'float_', 'intc', - 'uintc', 'int_', 'longfloat', 'clongfloat', - 'longcomplex', 'bool_', 'int', 'float', - 'complex', 'bool'] - np.random.seed(0) - for dtype in normal_dtypes: - v = np.random.uniform(-100, 100, size=()).astype(dtype)[()] - restored_v = serialization.msgpack_restore( - serialization.msgpack_serialize(v)) - self.assertEqual(restored_v.dtype, v.dtype) - np.testing.assert_array_equal(restored_v, v) - - for shape in [(), (5,), (10, 10), (1, 20, 30, 1)]: - arr = np.random.uniform(-100, 100, size=shape).astype(dtype) - restored_arr = serialization.msgpack_restore( - serialization.msgpack_serialize(arr)) - self.assertEqual(restored_arr.dtype, arr.dtype) - np.testing.assert_array_equal(restored_arr, arr) - - def test_jax_numpy_serialization(self): - jax_dtypes = [jnp.bool_, jnp.uint8, jnp.uint16, jnp.uint32, - jnp.int8, jnp.int16, jnp.int32, - jnp.bfloat16, jnp.float16, jnp.float32, - jnp.complex64] - for dtype in jax_dtypes: - v = jnp.array( - np.random.uniform(-100, 100, size=())).astype(dtype)[()] - restored_v = serialization.msgpack_restore( - serialization.msgpack_serialize(v)) - self.assertEqual(restored_v.dtype, v.dtype) - np.testing.assert_array_equal(restored_v, v) - - for shape in [(), (5,), (10, 10), (1, 20, 30, 1)]: - arr = jnp.array( - np.random.uniform(-100, 100, size=shape)).astype(dtype) - restored_arr = serialization.msgpack_restore( - serialization.msgpack_serialize(arr)) - self.assertEqual(restored_arr.dtype, arr.dtype) - np.testing.assert_array_equal(restored_arr, arr) - - def test_complex_serialization(self): - for x in [1j, 1+2j]: - restored_x = serialization.msgpack_restore( - serialization.msgpack_serialize(x)) - self.assertEqual(x, restored_x) - - def test_restore_chunked(self): - old_chunksize = serialization.MAX_CHUNK_SIZE - serialization.MAX_CHUNK_SIZE = 91 * 8 - try: - tmp = np.random.uniform(-100, 100, size=(21, 37)) - serialized = serialization.to_bytes(tmp) - restored = serialization.msgpack_restore(serialized) - finally: - serialization.MAX_CHUNK_SIZE = old_chunksize - - np.testing.assert_array_equal(restored, tmp) - - def test_restore_unchunked(self): - """Check if mgspack_restore works for unchunked inputs.""" - def msgpack_serialize_legacy(pytree): - """Old implementation that was not chunking.""" - return msgpack.packb(pytree, default=serialization._msgpack_ext_pack, - strict_types=True) - - tmp = np.random.uniform(-100, 100, size=(21, 37)) - serialized = msgpack_serialize_legacy(tmp) - old_chunksize = serialization.MAX_CHUNK_SIZE - serialization.MAX_CHUNK_SIZE = 91 * 8 - try: - restored = serialization.msgpack_restore(serialized) - finally: - serialization.MAX_CHUNK_SIZE = old_chunksize - - np.testing.assert_array_equal(restored, tmp) - - def test_namedtuple_serialization(self): - foo_class = collections.namedtuple('Foo', 'a b c') - x1 = foo_class(a=1, b=2, c=3) - x1_serialized = serialization.to_bytes(x1) - x2 = foo_class(a=0, b=0, c=0) - restored_x1 = serialization.from_bytes(x2, x1_serialized) - self.assertEqual(type(x1), type(restored_x1)) - self.assertEqual(x1, restored_x1) - - def test_namedtuple_restore_legacy(self): - foo_class = collections.namedtuple('Foo', 'a b c') - x1 = foo_class(a=1, b=2, c=3) - legacy_encoding = { - 'name': 'Foo', - 'fields': {'0': 'a', '1': 'b', '2': 'c'}, - 'values': {'0': 1, '1': 2, '2': 3}, - } - x2 = foo_class(a=0, b=0, c=0) - restored_x1 = serialization.from_state_dict(x2, legacy_encoding) - self.assertEqual(type(x1), type(restored_x1)) - self.assertEqual(x1, restored_x1) - - def test_model_serialization_to_bytes(self): - rng = random.PRNGKey(0) - module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones) - _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)]) - model = nn.Model(module, initial_params) - serialized_bytes = serialization.to_bytes(model) - restored_model = serialization.from_bytes(model, serialized_bytes) - self.assertEqual(restored_model.params, model.params) - - def test_optimizer_serialization_to_bytes(self): - rng = random.PRNGKey(0) - module = nn.Dense.partial(features=1, kernel_init=nn.initializers.ones) - _, initial_params = module.init_by_shape(rng, [((1, 1), jnp.float32)]) - model = nn.Model(module, initial_params) - optim_def = optim.Momentum(learning_rate=1.) - optimizer = optim_def.create(model) - serialized_bytes = serialization.to_bytes(optimizer) - restored_optimizer = serialization.from_bytes(optimizer, serialized_bytes) - self.assertEqual(restored_optimizer, optimizer) - - def test_serialization_chunking(self): - old_chunksize = serialization.MAX_CHUNK_SIZE - serialization.MAX_CHUNK_SIZE = 91 * 8 - try: - tmp = {'a': np.ones((10, 10))} - tmp = serialization._chunk_array_leaves_in_place(tmp) - finally: - serialization.MAX_CHUNK_SIZE = old_chunksize - test = jax.tree_map(jnp.shape, tmp) - ref = {'a': { - '__msgpack_chunked_array__': (), - 'chunks': {'0': (91,), '1': (9,)}, - 'shape': {'0': (), '1': ()}} - } - self.assertEqual(test, ref) - - def test_serialization_chunking2(self): - old_chunksize = serialization.MAX_CHUNK_SIZE - serialization.MAX_CHUNK_SIZE = 91 * 8 - try: - tmp = {'a': np.ones((10, 10))} - tmpbytes = serialization.to_bytes(tmp) - newtmp = serialization.from_bytes(tmp, tmpbytes) - finally: - serialization.MAX_CHUNK_SIZE = old_chunksize - jax.tree_multimap(np.testing.assert_array_equal, tmp, newtmp) - - def test_serialization_chunking3(self): - old_chunksize = serialization.MAX_CHUNK_SIZE - serialization.MAX_CHUNK_SIZE = 91 * 8 - try: - tmp = {'a': np.ones((10, 10))} - tmpbytes = serialization.msgpack_serialize(tmp) - newtmp = serialization.msgpack_restore(tmpbytes) - finally: - serialization.MAX_CHUNK_SIZE = old_chunksize - - jax.tree_multimap(np.testing.assert_array_equal, tmp, newtmp) - - -if __name__ == '__main__': - absltest.main() diff --git a/tests/traverse_util_test.py b/tests/traverse_util_test.py index 26a801a6b..8ac2b57df 100644 --- a/tests/traverse_util_test.py +++ b/tests/traverse_util_test.py @@ -147,28 +147,25 @@ def test_flatten_dict(self): xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} flat_xs = traverse_util.flatten_dict(xs) self.assertEqual(flat_xs, { - ('foo',): 1, - ('bar', 'a'): 2, + ('foo',): 1, + ('bar', 'a'): 2, }) def test_unflatten_dict(self): flat_xs = { - ('foo',): 1, - ('bar', 'a'): 2, + ('foo',): 1, + ('bar', 'a'): 2, } xs = traverse_util.unflatten_dict(flat_xs) - self.assertEqual(xs, { - 'foo': 1, - 'bar': {'a': 2} - }) + self.assertEqual(xs, {'foo': 1, 'bar': {'a': 2}}) def test_flatten_dict_keep_empty(self): xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} flat_xs = traverse_util.flatten_dict(xs, keep_empty_nodes=True) self.assertEqual(flat_xs, { - ('foo',): 1, - ('bar', 'a'): 2, - ('bar', 'b'): traverse_util.empty_node, + ('foo',): 1, + ('bar', 'a'): 2, + ('bar', 'b'): traverse_util.empty_node, }) xs_restore = traverse_util.unflatten_dict(flat_xs) self.assertEqual(xs, xs_restore) @@ -179,8 +176,11 @@ def test_flatten_dict_is_leaf(self): xs, is_leaf=lambda k, x: len(k) == 1 and len(x) == 2) self.assertEqual(flat_xs, { - ('foo', 'c'): 4, - ('bar',): {'a': 2, 'b': {}}, + ('foo', 'c'): 4, + ('bar',): { + 'a': 2, + 'b': {} + }, }) xs_restore = traverse_util.unflatten_dict(flat_xs) self.assertEqual(xs, xs_restore) @@ -222,13 +222,10 @@ def filter_fn(name, _): return 'kernel' in name traversal = traverse_util.ModelParamTraversal(filter_fn) - # Model - model = flax.nn.Model(None, params) - values = list(traversal.iterate(model)) + values = list(traversal.iterate(params)) configs = [ - (flax.nn.Model(None, params), flax.nn.Model(None, expected_params)), - (params, expected_params), - (flax.core.FrozenDict(params), flax.core.FrozenDict(expected_params)), + (params, expected_params), + (flax.core.FrozenDict(params), flax.core.FrozenDict(expected_params)), ] for model, expected_model in configs: self.assertEqual(values, [1, 3])