diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bbcd7362..6a298007a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,7 +82,7 @@ vNext 0.8.0 ----- -- Added [NNX](https://github.com/google/flax/tree/main/flax/experimental/nnx#nnx), a neural network library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of PyTorch. +- Added [NNX](https://github.com/google/flax/tree/main/flax/nnx#nnx), a neural network library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of PyTorch. - Added `nn.compact_name_scope` decorator that enables methods to act as compact name scopes as with regular Haiku methods. This makes porting Haiku code easier. - Add copy() method to Module. This is a user-friendly version of the internal clone() method with better defaults for common use cases. diff --git a/README.md b/README.md index f5dd4b09b..6cc80978d 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ | [**What does Flax look like?**](#what-does-flax-look-like) | [**Documentation**](https://flax.readthedocs.io/) +**📣 NEW**: Check out the [**NNX**](https://flax.readthedocs.io/en/latest/nnx/index.html) API! + This README is a very short intro. **To learn everything you need to know about Flax, refer to our [full documentation](https://flax.readthedocs.io/).** Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community. diff --git a/docs/api_reference/flax.experimental.nnx/nn/stochastic.rst b/docs/api_reference/flax.experimental.nnx/nn/stochastic.rst deleted file mode 100644 index 975b8bdb9..000000000 --- a/docs/api_reference/flax.experimental.nnx/nn/stochastic.rst +++ /dev/null @@ -1,8 +0,0 @@ -Stochastic ------------------------- - -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx - -.. autoclass:: Dropout - :members: \ No newline at end of file diff --git a/docs/api_reference/flax.experimental.nnx/training/optimizer.rst b/docs/api_reference/flax.experimental.nnx/training/optimizer.rst deleted file mode 100644 index a17b74e99..000000000 --- a/docs/api_reference/flax.experimental.nnx/training/optimizer.rst +++ /dev/null @@ -1,8 +0,0 @@ -Optimizer ------------------------- - -.. automodule:: flax.experimental.nnx.optimizer -.. currentmodule:: flax.experimental.nnx.optimizer - -.. autoclass:: Optimizer - :members: diff --git a/docs/api_reference/flax.experimental.nnx/visualization.rst b/docs/api_reference/flax.experimental.nnx/visualization.rst deleted file mode 100644 index 0bbbb9872..000000000 --- a/docs/api_reference/flax.experimental.nnx/visualization.rst +++ /dev/null @@ -1,7 +0,0 @@ -visualization ------------------------- - -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx - -.. 
autofunction:: display \ No newline at end of file diff --git a/docs/api_reference/flax.experimental.nnx/graph.rst b/docs/api_reference/flax.nnx/graph.rst similarity index 81% rename from docs/api_reference/flax.experimental.nnx/graph.rst rename to docs/api_reference/flax.nnx/graph.rst index a2d21b60e..35d3939db 100644 --- a/docs/api_reference/flax.experimental.nnx/graph.rst +++ b/docs/api_reference/flax.nnx/graph.rst @@ -1,8 +1,8 @@ graph ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autofunction:: split diff --git a/docs/api_reference/flax.experimental.nnx/helpers.rst b/docs/api_reference/flax.nnx/helpers.rst similarity index 69% rename from docs/api_reference/flax.experimental.nnx/helpers.rst rename to docs/api_reference/flax.nnx/helpers.rst index c0413acf5..f2b67522d 100644 --- a/docs/api_reference/flax.experimental.nnx/helpers.rst +++ b/docs/api_reference/flax.nnx/helpers.rst @@ -1,8 +1,8 @@ helpers ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: Dict :members: diff --git a/docs/api_reference/flax.experimental.nnx/index.rst b/docs/api_reference/flax.nnx/index.rst similarity index 73% rename from docs/api_reference/flax.experimental.nnx/index.rst rename to docs/api_reference/flax.nnx/index.rst index fb90e3d4e..37a22d311 100644 --- a/docs/api_reference/flax.experimental.nnx/index.rst +++ b/docs/api_reference/flax.nnx/index.rst @@ -1,7 +1,7 @@ -flax.experimental.nnx +flax.nnx ------------------------ -Experimental API. See the `NNX page `__ for more details. +Experimental API. See the `NNX page `__ for more details. .. toctree:: :maxdepth: 3 diff --git a/docs/api_reference/flax.experimental.nnx/module.rst b/docs/api_reference/flax.nnx/module.rst similarity index 61% rename from docs/api_reference/flax.experimental.nnx/module.rst rename to docs/api_reference/flax.nnx/module.rst index ffdff78a8..9e58068a8 100644 --- a/docs/api_reference/flax.experimental.nnx/module.rst +++ b/docs/api_reference/flax.nnx/module.rst @@ -1,8 +1,8 @@ module ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: Module :members: diff --git a/docs/api_reference/flax.experimental.nnx/nn/activations.rst b/docs/api_reference/flax.nnx/nn/activations.rst similarity index 89% rename from docs/api_reference/flax.experimental.nnx/nn/activations.rst rename to docs/api_reference/flax.nnx/nn/activations.rst index 0464975fe..db20ceb4d 100644 --- a/docs/api_reference/flax.experimental.nnx/nn/activations.rst +++ b/docs/api_reference/flax.nnx/nn/activations.rst @@ -1,8 +1,8 @@ Activation functions ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autofunction:: celu .. autofunction:: elu diff --git a/docs/api_reference/flax.experimental.nnx/nn/attention.rst b/docs/api_reference/flax.nnx/nn/attention.rst similarity index 74% rename from docs/api_reference/flax.experimental.nnx/nn/attention.rst rename to docs/api_reference/flax.nnx/nn/attention.rst index a2137ac88..3a10c7728 100644 --- a/docs/api_reference/flax.experimental.nnx/nn/attention.rst +++ b/docs/api_reference/flax.nnx/nn/attention.rst @@ -1,8 +1,8 @@ Attention ------------------------ -.. 
automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: MultiHeadAttention :members: diff --git a/docs/api_reference/flax.experimental.nnx/nn/index.rst b/docs/api_reference/flax.nnx/nn/index.rst similarity index 77% rename from docs/api_reference/flax.experimental.nnx/nn/index.rst rename to docs/api_reference/flax.nnx/nn/index.rst index a179948da..abe4da330 100644 --- a/docs/api_reference/flax.experimental.nnx/nn/index.rst +++ b/docs/api_reference/flax.nnx/nn/index.rst @@ -1,7 +1,7 @@ nn ---------------------------- -Experimental API. See the `NNX page `__ for more details. +Experimental API. See the `NNX page `__ for more details. .. toctree:: :maxdepth: 3 diff --git a/docs/api_reference/flax.experimental.nnx/nn/initializers.rst b/docs/api_reference/flax.nnx/nn/initializers.rst similarity index 86% rename from docs/api_reference/flax.experimental.nnx/nn/initializers.rst rename to docs/api_reference/flax.nnx/nn/initializers.rst index 0468f1870..a5734d8a4 100644 --- a/docs/api_reference/flax.experimental.nnx/nn/initializers.rst +++ b/docs/api_reference/flax.nnx/nn/initializers.rst @@ -1,8 +1,8 @@ Initializers ------------------------ -.. automodule:: flax.experimental.nnx.initializers -.. currentmodule:: flax.experimental.nnx.initializers +.. automodule:: flax.nnx.initializers +.. currentmodule:: flax.nnx.initializers .. autofunction:: constant .. autofunction:: delta_orthogonal diff --git a/docs/api_reference/flax.experimental.nnx/nn/linear.rst b/docs/api_reference/flax.nnx/nn/linear.rst similarity index 75% rename from docs/api_reference/flax.experimental.nnx/nn/linear.rst rename to docs/api_reference/flax.nnx/nn/linear.rst index 3206c4e84..057682069 100644 --- a/docs/api_reference/flax.experimental.nnx/nn/linear.rst +++ b/docs/api_reference/flax.nnx/nn/linear.rst @@ -3,8 +3,8 @@ Linear NNX linear layer classes. -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: Conv :members: diff --git a/docs/api_reference/flax.experimental.nnx/nn/normalization.rst b/docs/api_reference/flax.nnx/nn/normalization.rst similarity index 65% rename from docs/api_reference/flax.experimental.nnx/nn/normalization.rst rename to docs/api_reference/flax.nnx/nn/normalization.rst index 402fa8376..c35bc5e0f 100644 --- a/docs/api_reference/flax.experimental.nnx/nn/normalization.rst +++ b/docs/api_reference/flax.nnx/nn/normalization.rst @@ -1,8 +1,8 @@ Normalization ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: BatchNorm :members: diff --git a/docs/api_reference/flax.nnx/nn/stochastic.rst b/docs/api_reference/flax.nnx/nn/stochastic.rst new file mode 100644 index 000000000..70f7c497a --- /dev/null +++ b/docs/api_reference/flax.nnx/nn/stochastic.rst @@ -0,0 +1,8 @@ +Stochastic +------------------------ + +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx + +.. 
autoclass:: Dropout + :members: \ No newline at end of file diff --git a/docs/api_reference/flax.experimental.nnx/rnglib.rst b/docs/api_reference/flax.nnx/rnglib.rst similarity index 57% rename from docs/api_reference/flax.experimental.nnx/rnglib.rst rename to docs/api_reference/flax.nnx/rnglib.rst index 9defbc76f..2db1d6d63 100644 --- a/docs/api_reference/flax.experimental.nnx/rnglib.rst +++ b/docs/api_reference/flax.nnx/rnglib.rst @@ -1,8 +1,8 @@ rnglib ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: Rngs :members: diff --git a/docs/api_reference/flax.experimental.nnx/spmd.rst b/docs/api_reference/flax.nnx/spmd.rst similarity index 69% rename from docs/api_reference/flax.experimental.nnx/spmd.rst rename to docs/api_reference/flax.nnx/spmd.rst index ed7af7f69..3429d898c 100644 --- a/docs/api_reference/flax.experimental.nnx/spmd.rst +++ b/docs/api_reference/flax.nnx/spmd.rst @@ -1,8 +1,8 @@ spmd ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autofunction:: get_partition_spec .. autofunction:: get_named_sharding diff --git a/docs/api_reference/flax.experimental.nnx/training/index.rst b/docs/api_reference/flax.nnx/training/index.rst similarity index 71% rename from docs/api_reference/flax.experimental.nnx/training/index.rst rename to docs/api_reference/flax.nnx/training/index.rst index c9bb4aa39..32404f1de 100644 --- a/docs/api_reference/flax.experimental.nnx/training/index.rst +++ b/docs/api_reference/flax.nnx/training/index.rst @@ -1,7 +1,7 @@ training ---------------------------- -Experimental API. See the `NNX page `__ for more details. +Experimental API. See the `NNX page `__ for more details. .. toctree:: :maxdepth: 3 diff --git a/docs/api_reference/flax.experimental.nnx/training/metrics.rst b/docs/api_reference/flax.nnx/training/metrics.rst similarity index 65% rename from docs/api_reference/flax.experimental.nnx/training/metrics.rst rename to docs/api_reference/flax.nnx/training/metrics.rst index f0e5ea201..e60c9d1c1 100644 --- a/docs/api_reference/flax.experimental.nnx/training/metrics.rst +++ b/docs/api_reference/flax.nnx/training/metrics.rst @@ -1,8 +1,8 @@ Metrics ------------------------ -.. automodule:: flax.experimental.nnx.metrics -.. currentmodule:: flax.experimental.nnx.metrics +.. automodule:: flax.nnx.metrics +.. currentmodule:: flax.nnx.metrics .. autoclass:: Metric :members: diff --git a/docs/api_reference/flax.nnx/training/optimizer.rst b/docs/api_reference/flax.nnx/training/optimizer.rst new file mode 100644 index 000000000..15966a1a2 --- /dev/null +++ b/docs/api_reference/flax.nnx/training/optimizer.rst @@ -0,0 +1,8 @@ +Optimizer +------------------------ + +.. automodule:: flax.nnx.optimizer +.. currentmodule:: flax.nnx.optimizer + +.. autoclass:: Optimizer + :members: diff --git a/docs/api_reference/flax.experimental.nnx/transforms.rst b/docs/api_reference/flax.nnx/transforms.rst similarity index 82% rename from docs/api_reference/flax.experimental.nnx/transforms.rst rename to docs/api_reference/flax.nnx/transforms.rst index bdf105fee..6750a109d 100644 --- a/docs/api_reference/flax.experimental.nnx/transforms.rst +++ b/docs/api_reference/flax.nnx/transforms.rst @@ -1,8 +1,8 @@ transforms ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. 
currentmodule:: flax.nnx .. autoclass:: JIT :members: diff --git a/docs/api_reference/flax.experimental.nnx/variables.rst b/docs/api_reference/flax.nnx/variables.rst similarity index 80% rename from docs/api_reference/flax.experimental.nnx/variables.rst rename to docs/api_reference/flax.nnx/variables.rst index b9f3d1dc5..54e442463 100644 --- a/docs/api_reference/flax.experimental.nnx/variables.rst +++ b/docs/api_reference/flax.nnx/variables.rst @@ -1,8 +1,8 @@ variables ------------------------ -.. automodule:: flax.experimental.nnx -.. currentmodule:: flax.experimental.nnx +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx .. autoclass:: BatchStat :members: diff --git a/docs/api_reference/flax.nnx/visualization.rst b/docs/api_reference/flax.nnx/visualization.rst new file mode 100644 index 000000000..a189aae52 --- /dev/null +++ b/docs/api_reference/flax.nnx/visualization.rst @@ -0,0 +1,7 @@ +visualization +------------------------ + +.. automodule:: flax.nnx +.. currentmodule:: flax.nnx + +.. autofunction:: display \ No newline at end of file diff --git a/docs/api_reference/index.rst b/docs/api_reference/index.rst index 8448f316a..2c0d36025 100644 --- a/docs/api_reference/index.rst +++ b/docs/api_reference/index.rst @@ -8,7 +8,7 @@ API Reference flax.core.frozen_dict flax.cursor flax.errors - flax.experimental.nnx/index + flax.nnx/index flax.jax_utils flax.linen/index flax.serialization diff --git a/docs/conf.py b/docs/conf.py index 2ee6faca2..93d3d7009 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -110,6 +110,16 @@ html_extra_path = ['robots.txt'] +# href with no underline and white bold text color +announcement = """ + + 📣 Check out the new NNX API! + +""" + html_theme_options = { 'repository_url': 'https://github.com/google/flax', 'use_repository_button': True, # add a 'link to repository' button @@ -122,6 +132,7 @@ }, 'prev_next_buttons_location': None, 'show_navbar_depth': 1, + 'announcement': announcement, } # -- Options for myst ---------------------------------------------- @@ -135,7 +146,7 @@ nb_execution_excludepatterns = [ 'quick_start.ipynb', # <-- times out 'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0 - 'flax/experimental/nnx', # exclude nnx + 'flax/nnx', # exclude nnx ] # raise exceptions on execution so CI can catch errors nb_execution_allow_errors = False @@ -151,7 +162,7 @@ doctest_global_setup = """ import jax import jax.numpy as jnp -from flax.experimental import nnx +from flax import nnx import logging as slog from absl import logging as alog diff --git a/docs/experimental/index.rst b/docs/experimental/index.rst deleted file mode 100644 index 368491ce3..000000000 --- a/docs/experimental/index.rst +++ /dev/null @@ -1,7 +0,0 @@ -Experimental -============= - -.. 
toctree:: - :maxdepth: 2 - - nnx/index \ No newline at end of file diff --git a/docs/guides/flax_fundamentals/flax_basics.ipynb b/docs/guides/flax_fundamentals/flax_basics.ipynb index e20069aeb..e8e43f21c 100644 --- a/docs/guides/flax_fundamentals/flax_basics.ipynb +++ b/docs/guides/flax_fundamentals/flax_basics.ipynb @@ -951,7 +951,7 @@ "source": [ "### Exporting to Tensorflow's SavedModel with jax2tf\n", "\n", - "JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax." + "JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax." ] } ], diff --git a/docs/guides/flax_fundamentals/flax_basics.md b/docs/guides/flax_fundamentals/flax_basics.md index 0ce0f6f77..52755e9b5 100644 --- a/docs/guides/flax_fundamentals/flax_basics.md +++ b/docs/guides/flax_fundamentals/flax_basics.md @@ -469,4 +469,4 @@ Flax provides a handy wrapper - `TrainState` - that simplifies the above code. C ### Exporting to Tensorflow's SavedModel with jax2tf -JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax. +JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax. diff --git a/docs/index.rst b/docs/index.rst index be6781e82..75f5d985f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,7 +28,7 @@ both in the open source community (like `Hugging Face `__) and at Google (like -`PaLM `__, +`Gemini `__, `Imagen `__, `Scenic `__, and `Big Vision `__). @@ -309,6 +309,8 @@ Notable examples in Flax include: +.. role:: bold + :class: bold .. toctree:: :hidden: @@ -325,4 +327,4 @@ Notable examples in Flax include: contributing experimental api_reference/index - experimental/index + NNX diff --git a/docs/experimental/nnx/index.rst b/docs/nnx/index.rst similarity index 69% rename from docs/experimental/nnx/index.rst rename to docs/nnx/index.rst index 9a7defeeb..5865e6c17 100644 --- a/docs/experimental/nnx/index.rst +++ b/docs/nnx/index.rst @@ -3,13 +3,11 @@ NNX ======== -NNX is a JAX-based neural network library designed for simplicity and power. 
Its modular -approach follows standard Python conventions, making it both intuitive and compatible with -the broader JAX ecosystem. - -.. note:: - NNX is currently in an experimental state and is subject to change. Linen is still the - recommended option for large-scale projects. Feedback and contributions are welcome! +NNX is a **N**\ eural **N**\ etwork library for JA\ **X** that focuses on providing the best +development experience, so building and experimenting with neural networks is easy and +intuitive. It achieves this by embracing Python’s object-oriented model and making it +compatible with JAX transforms, resulting in code that is easy to inspect, debug, and +analyze. Features ^^^^^^^^^ .. grid-item:: :columns: 12 12 12 6 .. card:: Pythonic :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal - Modules are standard Python classes, promoting ease of use and a more familiar - development experience. + NNX supports the use of regular Python objects, providing an intuitive + and predictable development experience. .. grid-item:: :columns: 12 12 12 6 - .. card:: Compatible + .. card:: Simple :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal - Effortlessly convert between Modules and pytrees using the Functional API for maximum - flexibility. + NNX relies on Python's object model, which results in simplicity for + the user and increases development speed. .. grid-item:: :columns: 12 12 12 6 - .. card:: Control + .. card:: Streamlined :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal - Manage a Module's state with precision using typed Variable collections, enabling fine-grained - control on JAX transformations. + NNX integrates user feedback and hands-on experience with Linen + into a new simplified API. .. grid-item:: :columns: 12 12 12 6 - .. card:: User-friendly + .. card:: Compatible :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal - NNX prioritizes simplicity for common use cases, building upon lessons learned from Linen - to provide a streamlined experience. + NNX makes it very easy to integrate objects with regular JAX code + via the `Functional API `__. Basic usage ^^^^^^^^^^^^ .. testcode:: - from flax.experimental import nnx + from flax import nnx import optax @@ -110,7 +108,14 @@ Basic usage Installation ^^^^^^^^^^^^ -NNX is under active development, we recommend using the latest version from Flax's GitHub repository: + +Install NNX via pip: + +.. code-block:: bash + + pip install flax + +Or install the latest version from the repository: .. code-block:: bash @@ -150,7 +155,7 @@ Learn more ..
card:: :material-regular:`menu_book;2em` API reference :class-card: sd-text-black sd-bg-light - :link: ../../api_reference/index.html + :link: ../api_reference/flax.nnx/index.html ---- diff --git a/docs/experimental/nnx/mnist_tutorial.ipynb b/docs/nnx/mnist_tutorial.ipynb similarity index 99% rename from docs/experimental/nnx/mnist_tutorial.ipynb rename to docs/nnx/mnist_tutorial.ipynb index c143ca57b..6f990696e 100644 --- a/docs/experimental/nnx/mnist_tutorial.ipynb +++ b/docs/nnx/mnist_tutorial.ipynb @@ -132,7 +132,7 @@ } ], "source": [ - "from flax.experimental import nnx # NNX API\n", + "from flax import nnx # NNX API\n", "from functools import partial\n", "\n", "class CNN(nnx.Module):\n", @@ -297,7 +297,7 @@ "id": "17", "metadata": {}, "source": [ - "The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/transforms.html#flax.experimental.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with \n", + "The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with \n", "[XLA](https://www.tensorflow.org/xla), optimizing performance on \n", "hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),\n", "except it can transforms functions that contain NNX objects as inputs and outputs.\n", diff --git a/docs/experimental/nnx/mnist_tutorial.md b/docs/nnx/mnist_tutorial.md similarity index 98% rename from docs/experimental/nnx/mnist_tutorial.md rename to docs/nnx/mnist_tutorial.md index e6510d239..3c4ba0955 100644 --- a/docs/experimental/nnx/mnist_tutorial.md +++ b/docs/nnx/mnist_tutorial.md @@ -77,7 +77,7 @@ test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) Create a convolutional neural network with NNX by subclassing `nnx.Module`. ```{code-cell} ipython3 -from flax.experimental import nnx # NNX API +from flax import nnx # NNX API from functools import partial class CNN(nnx.Module): @@ -163,7 +163,7 @@ def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, b optimizer.update(grads) ``` -The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/transforms.html#flax.experimental.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with +The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with [XLA](https://www.tensorflow.org/xla), optimizing performance on hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), except it can transforms functions that contain NNX objects as inputs and outputs. 
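The `nnx.jit` behavior described above can be condensed into a minimal sketch (a sketch only, built from the `nnx.value_and_grad` and `nnx.Optimizer` usage that appears elsewhere in this diff; treat exact signatures as assumptions against this Flax version):

```python
# Sketch: nnx.jit compiles a function that mutates NNX objects in place,
# something jax.jit cannot do directly with non-pytree Python objects.
from flax import nnx
import jax.numpy as jnp
import optax

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))         # a stateful NNX Module
optimizer = nnx.Optimizer(model, optax.sgd(1e-3))  # references the model's params

@nnx.jit  # accepts NNX objects as arguments and propagates their state updates
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    return jnp.mean((model(x) - y) ** 2)
  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # in-place update, visible to the caller
  return loss

loss = train_step(model, optimizer, jnp.ones((4, 2)), jnp.ones((4, 3)))
```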
diff --git a/docs/experimental/nnx/nnx_basics.ipynb b/docs/nnx/nnx_basics.ipynb similarity index 99% rename from docs/experimental/nnx/nnx_basics.ipynb rename to docs/nnx/nnx_basics.ipynb index 3a90e0d83..d03318442 100644 --- a/docs/experimental/nnx/nnx_basics.ipynb +++ b/docs/nnx/nnx_basics.ipynb @@ -23,7 +23,7 @@ "metadata": {}, "outputs": [], "source": [ - "from flax.experimental import nnx\n", + "from flax import nnx\n", "import jax\n", "import jax.numpy as jnp" ] diff --git a/docs/experimental/nnx/nnx_basics.md b/docs/nnx/nnx_basics.md similarity index 99% rename from docs/experimental/nnx/nnx_basics.md rename to docs/nnx/nnx_basics.md index c27ae068a..b1c8841b9 100644 --- a/docs/experimental/nnx/nnx_basics.md +++ b/docs/nnx/nnx_basics.md @@ -21,7 +21,7 @@ Despite its simplified implementation, NNX supports the same powerful design pat that have allowed Linen to scale effectively to large codebases. ```{code-cell} ipython3 -from flax.experimental import nnx +from flax import nnx import jax import jax.numpy as jnp ``` diff --git a/docs/experimental/nnx/transforms.rst b/docs/nnx/transforms.rst similarity index 96% rename from docs/experimental/nnx/transforms.rst rename to docs/nnx/transforms.rst index 9f35afcc2..76e807f24 100644 --- a/docs/experimental/nnx/transforms.rst +++ b/docs/nnx/transforms.rst @@ -9,7 +9,7 @@ First, let's set up imports and generate some dummy data: .. testcode:: NNX, JAX - from flax.experimental import nnx + from flax import nnx import jax x = jax.random.normal(jax.random.key(0), (1, 2)) @@ -24,7 +24,7 @@ even those whose state will be mutated, whereas they aren't recognized in JAX tr Therefore NNX transformations can transform functions that are not pure and make mutations and side-effects. -NNX's `Functional API `_ +NNX's `Functional API `_ provides a way to convert graph structures to pytrees and back, by doing this at every function boundary you can effectively use graph structures with any JAX transform and propagate state updates in a way consistent with functional purity. NNX custom transforms such as ``nnx.jit`` and ``nnx.grad`` diff --git a/flax/experimental/nnx.py b/flax/experimental/nnx.py new file mode 100644 index 000000000..489991429 --- /dev/null +++ b/flax/experimental/nnx.py @@ -0,0 +1,22 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl import logging + +from flax.nnx import * + + +logging.warning( + "Using 'flax.experimental.nnx' is deprecated. Please use 'flax.nnx' instead." 
+) \ No newline at end of file diff --git a/flax/experimental/nnx/.gitignore b/flax/nnx/.gitignore similarity index 100% rename from flax/experimental/nnx/.gitignore rename to flax/nnx/.gitignore diff --git a/flax/experimental/nnx/README.md b/flax/nnx/README.md similarity index 73% rename from flax/experimental/nnx/README.md rename to flax/nnx/README.md index cc00e1358..854e0971d 100644 --- a/flax/experimental/nnx/README.md +++ b/flax/nnx/README.md @@ -2,7 +2,7 @@ # NNX -_**N**eural **N**etworks for JA**X**_ - | [docs](https://flax.readthedocs.io/en/latest/experimental/nnx/index.html) | +_**N**eural **N**etworks for JA**X**_ - | [docs](https://flax.readthedocs.io/en/latest/nnx/index.html) | NNX is a JAX-based neural network library that focuses on providing the best development experience to make building and experimenting with neural networks as easy and intuitive as possible. @@ -28,7 +28,7 @@ a Module system that uses standard Python classes, and a set of transforms that JAX to handle objects. ```python -from flax.experimental import nnx +from flax import nnx import optax class Model(nnx.Module): @@ -58,7 +58,7 @@ def train_step(model, optimizer, x, y): return loss ``` -To learn more about the `Module` abstraction, check out our [NNX Basics](https://flax.readthedocs.io/en/latest/experimental/nnx/nnx_basics.html#) guide. +To learn more about the `Module` abstraction, check out our [NNX Basics](https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#) guide. ## Installation @@ -69,10 +69,10 @@ pip install git+https://github.com/google/flax.git ### Examples -* [LM1B](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/lm1b): A language model trained on the 1 Billion Word Benchmark dataset. +* [LM1B](https://github.com/google/flax/tree/main/flax/nnx/examples/lm1b): A language model trained on the 1 Billion Word Benchmark dataset. #### Toy Examples -* [Basic Example](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py): Shows how to train a simple model using NNX. -* [Using the Functional API](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/01_functional_api.py): Shows how to train a simple model using the functional API. -* [Training a VAE](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset. -* [Scan over layers](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py): An contrived example that implements scan over layers with dropout and a share BatcNorm layer to showcase how lifted transforms can be implemented. It uses the functional API along with `jax.vmap` and `jax.lax.scan`. +* [Basic Example](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/02_lifted_transforms.py): Shows how to train a simple model using NNX. +* [Using the Functional API](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/01_functional_api.py): Shows how to train a simple model using the functional API. +* [Training a VAE](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset. 
+* [Scan over layers](https://github.com/google/flax/tree/main/flax/nnx/examples/toy_examples/06_scan_over_layers.py): A contrived example that implements scan over layers with dropout and a shared BatchNorm layer to showcase how lifted transforms can be implemented. It uses the functional API along with `jax.vmap` and `jax.lax.scan`. diff --git a/flax/experimental/nnx/__init__.py b/flax/nnx/__init__.py similarity index 100% rename from flax/experimental/nnx/__init__.py rename to flax/nnx/__init__.py diff --git a/flax/experimental/nnx/docs/blog.md b/flax/nnx/docs/blog.md similarity index 100% rename from flax/experimental/nnx/docs/blog.md rename to flax/nnx/docs/blog.md diff --git a/flax/experimental/nnx/docs/demo.ipynb b/flax/nnx/docs/demo.ipynb similarity index 99% rename from flax/experimental/nnx/docs/demo.ipynb rename to flax/nnx/docs/demo.ipynb index ae71ad479..a2521ef10 100644 --- a/flax/experimental/nnx/docs/demo.ipynb +++ b/flax/nnx/docs/demo.ipynb @@ -17,7 +17,7 @@ "source": [ "import jax\n", "from jax import numpy as jnp\n", - "from flax.experimental import nnx" + "from flax import nnx" ] }, { diff --git a/flax/experimental/nnx/docs/demo.md b/flax/nnx/docs/demo.md similarity index 99% rename from flax/experimental/nnx/docs/demo.md rename to flax/nnx/docs/demo.md index 5d02e5da7..f507f9c48 100644 --- a/flax/experimental/nnx/docs/demo.md +++ b/flax/nnx/docs/demo.md @@ -13,7 +13,7 @@ jupytext: ```{code-cell} ipython3 import jax from jax import numpy as jnp -from flax.experimental import nnx +from flax import nnx ``` ### [1] NNX is Pythonic diff --git a/flax/experimental/nnx/docs/images/stateful-transforms.png b/flax/nnx/docs/images/stateful-transforms.png similarity index 100% rename from flax/experimental/nnx/docs/images/stateful-transforms.png rename to flax/nnx/docs/images/stateful-transforms.png diff --git a/flax/experimental/nnx/docs/quick_start.ipynb b/flax/nnx/docs/quick_start.ipynb similarity index 99% rename from flax/experimental/nnx/docs/quick_start.ipynb rename to flax/nnx/docs/quick_start.ipynb index fc617db8a..df64361b4 100644 --- a/flax/experimental/nnx/docs/quick_start.ipynb +++ b/flax/nnx/docs/quick_start.ipynb @@ -146,7 +146,7 @@ "source": [ "import jax\n", "import jax.numpy as jnp\n", - "from flax.experimental import nnx\n", + "from flax import nnx\n", "\n", "\n", "class CNN(nnx.Module):\n", diff --git a/flax/experimental/nnx/docs/tiny_nnx.ipynb b/flax/nnx/docs/tiny_nnx.ipynb similarity index 100% rename from flax/experimental/nnx/docs/tiny_nnx.ipynb rename to flax/nnx/docs/tiny_nnx.ipynb diff --git a/flax/experimental/nnx/docs/why.ipynb b/flax/nnx/docs/why.ipynb similarity index 99% rename from flax/experimental/nnx/docs/why.ipynb rename to flax/nnx/docs/why.ipynb index 04cad17da..46caf8c4e 100644 --- a/flax/experimental/nnx/docs/why.ipynb +++ b/flax/nnx/docs/why.ipynb @@ -7,7 +7,7 @@ "# Why NNX?\n", "\n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/flax/experimental/nnx/docs/why.ipynb)\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/flax/nnx/docs/why.ipynb)\n", "\n", "Four years ago we developed the Flax \"Linen\" API to support modeling research on JAX, with a focus on scaling scaling and performance.
We've learned a lot from our users over these years.\n", "\n", @@ -25,8 +25,8 @@ "\n", "We'd love to hear from any of our users about their thoughts on these ideas.\n", "\n", - "[[nnx on github](https://github.com/google/flax/tree/main/flax/experimental/nnx)]\n", - "[[this doc on github](https://github.com/google/flax/blob/main/flax/experimental/nnx/docs/why.ipynb)]" + "[[nnx on github](https://github.com/google/flax/tree/main/flax/nnx)]\n", + "[[this doc on github](https://github.com/google/flax/blob/main/flax/nnx/docs/why.ipynb)]" ] }, { @@ -39,7 +39,7 @@ "from functools import partial\n", "import jax\n", "from jax import random, numpy as jnp\n", - "from flax.experimental import nnx" + "from flax import nnx" ] }, { diff --git a/flax/experimental/nnx/docs/why.md b/flax/nnx/docs/why.md similarity index 98% rename from flax/experimental/nnx/docs/why.md rename to flax/nnx/docs/why.md index 3dce4ad63..07142c0f4 100644 --- a/flax/experimental/nnx/docs/why.md +++ b/flax/nnx/docs/why.md @@ -11,7 +11,7 @@ jupytext: # Why NNX? -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/flax/experimental/nnx/docs/why.ipynb) +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/flax/nnx/docs/why.ipynb) Four years ago we developed the Flax "Linen" API to support modeling research on JAX, with a focus on scaling scaling and performance. We've learned a lot from our users over these years. @@ -29,15 +29,15 @@ NNX is an attempt to keep the features that made Linen useful while introducing We'd love to hear from any of our users about their thoughts on these ideas. -[[nnx on github](https://github.com/google/flax/tree/main/flax/experimental/nnx)] -[[this doc on github](https://github.com/google/flax/blob/main/flax/experimental/nnx/docs/why.ipynb)] +[[nnx on github](https://github.com/google/flax/tree/main/flax/nnx)] +[[this doc on github](https://github.com/google/flax/blob/main/flax/nnx/docs/why.ipynb)] ```{code-cell} ! 
pip install -U git+https://github.com/google/flax.git from functools import partial import jax from jax import random, numpy as jnp -from flax.experimental import nnx +from flax import nnx ``` ### NNX is Pythonic diff --git a/flax/experimental/nnx/examples/lm1b/README.md b/flax/nnx/examples/lm1b/README.md similarity index 100% rename from flax/experimental/nnx/examples/lm1b/README.md rename to flax/nnx/examples/lm1b/README.md diff --git a/flax/experimental/nnx/examples/lm1b/configs/default.py b/flax/nnx/examples/lm1b/configs/default.py similarity index 100% rename from flax/experimental/nnx/examples/lm1b/configs/default.py rename to flax/nnx/examples/lm1b/configs/default.py diff --git a/flax/experimental/nnx/examples/lm1b/input_pipeline.py b/flax/nnx/examples/lm1b/input_pipeline.py similarity index 100% rename from flax/experimental/nnx/examples/lm1b/input_pipeline.py rename to flax/nnx/examples/lm1b/input_pipeline.py diff --git a/flax/experimental/nnx/examples/lm1b/input_pipeline_test.py b/flax/nnx/examples/lm1b/input_pipeline_test.py similarity index 98% rename from flax/experimental/nnx/examples/lm1b/input_pipeline_test.py rename to flax/nnx/examples/lm1b/input_pipeline_test.py index 4ead911fe..e6287fac0 100644 --- a/flax/experimental/nnx/examples/lm1b/input_pipeline_test.py +++ b/flax/nnx/examples/lm1b/input_pipeline_test.py @@ -46,7 +46,7 @@ def _get_datasets(self): vocab_path = os.path.join(tempfile.mkdtemp(), 'sentencepiece_model') # Go two directories up to the root of the flax directory. - flax_root_dir = pathlib.Path(__file__).parents[5] + flax_root_dir = pathlib.Path(__file__).parents[4] data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable with tfds.testing.mock_data(num_examples=128, data_dir=data_dir): diff --git a/flax/experimental/nnx/examples/lm1b/main.py b/flax/nnx/examples/lm1b/main.py similarity index 100% rename from flax/experimental/nnx/examples/lm1b/main.py rename to flax/nnx/examples/lm1b/main.py diff --git a/flax/experimental/nnx/examples/lm1b/models.py b/flax/nnx/examples/lm1b/models.py similarity index 99% rename from flax/experimental/nnx/examples/lm1b/models.py rename to flax/nnx/examples/lm1b/models.py index 1731ec7f3..bb80e1eee 100644 --- a/flax/experimental/nnx/examples/lm1b/models.py +++ b/flax/nnx/examples/lm1b/models.py @@ -32,7 +32,7 @@ import numpy as np from jax import lax -from flax.experimental import nnx +from flax import nnx from configs import default Shape = tuple[int, ...] 
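The "NNX is Pythonic" heading above names the design point these renames carry over: module state lives in plain attributes. A minimal sketch, assuming the `kernel`/`value` attribute layout of `nnx.Linear` seen elsewhere in this diff:

```python
# Sketch: NNX parameters are regular attributes that can be read and mutated
# eagerly, with no separate params dict to thread through an apply() call.
from flax import nnx
import jax.numpy as jnp

layer = nnx.Linear(2, 4, rngs=nnx.Rngs(0))
print(layer.kernel.value.shape)                # (2, 4): state is inspectable
layer.kernel.value = layer.kernel.value * 0.0  # eager, in-place mutation
y = layer(jnp.ones((1, 2)))                    # forward pass sees the update
```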
diff --git a/flax/experimental/nnx/examples/lm1b/models_test.py b/flax/nnx/examples/lm1b/models_test.py similarity index 99% rename from flax/experimental/nnx/examples/lm1b/models_test.py rename to flax/nnx/examples/lm1b/models_test.py index 76296ae50..cc377eb33 100644 --- a/flax/experimental/nnx/examples/lm1b/models_test.py +++ b/flax/nnx/examples/lm1b/models_test.py @@ -27,7 +27,7 @@ from jax import random from flax import traverse_util -from flax.experimental import nnx +from flax import nnx from configs import default from models import TransformerConfig, TransformerLM from utils import HasCache @@ -35,7 +35,7 @@ jax.config.update('jax_disable_most_optimizations', True) # add project_root to import lm1b Linen model -project_root = str(Path(__file__).absolute().parents[5]) +project_root = str(Path(__file__).absolute().parents[4]) sys.path.append(project_root) from examples.lm1b.models import TransformerLM as TransformerLinen # type: ignore[import-error] diff --git a/flax/experimental/nnx/examples/lm1b/requirements.txt b/flax/nnx/examples/lm1b/requirements.txt similarity index 100% rename from flax/experimental/nnx/examples/lm1b/requirements.txt rename to flax/nnx/examples/lm1b/requirements.txt diff --git a/flax/experimental/nnx/examples/lm1b/temperature_sampler.py b/flax/nnx/examples/lm1b/temperature_sampler.py similarity index 100% rename from flax/experimental/nnx/examples/lm1b/temperature_sampler.py rename to flax/nnx/examples/lm1b/temperature_sampler.py diff --git a/flax/experimental/nnx/examples/lm1b/temperature_sampler_test.py b/flax/nnx/examples/lm1b/temperature_sampler_test.py similarity index 100% rename from flax/experimental/nnx/examples/lm1b/temperature_sampler_test.py rename to flax/nnx/examples/lm1b/temperature_sampler_test.py diff --git a/flax/experimental/nnx/examples/lm1b/tokenizer.py b/flax/nnx/examples/lm1b/tokenizer.py similarity index 100% rename from flax/experimental/nnx/examples/lm1b/tokenizer.py rename to flax/nnx/examples/lm1b/tokenizer.py diff --git a/flax/experimental/nnx/examples/lm1b/train.py b/flax/nnx/examples/lm1b/train.py similarity index 99% rename from flax/experimental/nnx/examples/lm1b/train.py rename to flax/nnx/examples/lm1b/train.py index ed3f5986d..a137b9da1 100644 --- a/flax/experimental/nnx/examples/lm1b/train.py +++ b/flax/nnx/examples/lm1b/train.py @@ -42,7 +42,7 @@ from utils import HasCache, TrainState from flax import linen as nn -from flax.experimental import nnx +from flax import nnx from flax.training import checkpoints, common_utils @@ -605,9 +605,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array): lambda x: x / denominator, metrics_sums ) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr - summary['perplexity'] = jnp.clip( - jnp.exp(summary['loss']), max=1.0e4 - ) + summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), max=1.0e4) summary = {'train_' + k: v for k, v in summary.items()} writer.write_scalars(step, summary) train_metrics = [] diff --git a/flax/experimental/nnx/examples/lm1b/train_test.py b/flax/nnx/examples/lm1b/train_test.py similarity index 97% rename from flax/experimental/nnx/examples/lm1b/train_test.py rename to flax/nnx/examples/lm1b/train_test.py index 9040c4f26..1f135048d 100644 --- a/flax/experimental/nnx/examples/lm1b/train_test.py +++ b/flax/nnx/examples/lm1b/train_test.py @@ -59,7 +59,7 @@ def test_train_and_evaluate(self): workdir = tempfile.mkdtemp() # Go two directories up to the root of the flax directory. 
- flax_root_dir = pathlib.Path(__file__).parents[5] + flax_root_dir = pathlib.Path(__file__).parents[4] data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable print('data_dir: ', data_dir) diff --git a/flax/experimental/nnx/examples/lm1b/utils.py b/flax/nnx/examples/lm1b/utils.py similarity index 99% rename from flax/experimental/nnx/examples/lm1b/utils.py rename to flax/nnx/examples/lm1b/utils.py index 1bf2d7d8c..d2afc3c3b 100644 --- a/flax/experimental/nnx/examples/lm1b/utils.py +++ b/flax/nnx/examples/lm1b/utils.py @@ -25,7 +25,7 @@ from configs import default from models import TransformerConfig, TransformerLM -from flax.experimental import nnx +from flax import nnx from flax.training import train_state Dtype = Any @@ -38,8 +38,7 @@ class TrainState(train_state.TrainState): @runtime_checkable class HasCache(Protocol): - def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32): - ... + def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32): ... # Mesh utils. diff --git a/flax/experimental/nnx/examples/toy_examples/01_functional_api.py b/flax/nnx/examples/toy_examples/01_functional_api.py similarity index 98% rename from flax/experimental/nnx/examples/toy_examples/01_functional_api.py rename to flax/nnx/examples/toy_examples/01_functional_api.py index bd6451555..8f90a24ef 100644 --- a/flax/experimental/nnx/examples/toy_examples/01_functional_api.py +++ b/flax/nnx/examples/toy_examples/01_functional_api.py @@ -18,7 +18,7 @@ import matplotlib.pyplot as plt import numpy as np -from flax.experimental import nnx +from flax import nnx X = np.linspace(0, 1, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) diff --git a/flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py b/flax/nnx/examples/toy_examples/02_lifted_transforms.py similarity index 98% rename from flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py rename to flax/nnx/examples/toy_examples/02_lifted_transforms.py index a29efe153..bb2238f7a 100644 --- a/flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py +++ b/flax/nnx/examples/toy_examples/02_lifted_transforms.py @@ -19,7 +19,7 @@ import numpy as np import optax -from flax.experimental import nnx +from flax import nnx X = np.linspace(0, 1, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) @@ -62,6 +62,7 @@ def __call__(self, x): tx = optax.sgd(1e-3) optimizer = nnx.Optimizer(model, tx) + @nnx.jit def train_step(model: MLP, optimizer: nnx.Optimizer, batch): x, y = batch diff --git a/flax/experimental/nnx/examples/toy_examples/05_vae.py b/flax/nnx/examples/toy_examples/05_vae.py similarity index 99% rename from flax/experimental/nnx/examples/toy_examples/05_vae.py rename to flax/nnx/examples/toy_examples/05_vae.py index 895dcd894..7819c8dbe 100644 --- a/flax/experimental/nnx/examples/toy_examples/05_vae.py +++ b/flax/nnx/examples/toy_examples/05_vae.py @@ -22,7 +22,7 @@ import optax from datasets import load_dataset -from flax.experimental import nnx +from flax import nnx np.random.seed(42) latent_size = 32 diff --git a/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py b/flax/nnx/examples/toy_examples/06_scan_over_layers.py similarity index 98% rename from flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py rename to flax/nnx/examples/toy_examples/06_scan_over_layers.py index 9a2b01727..ad2b2edce 100644 --- a/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py +++ 
b/flax/nnx/examples/toy_examples/06_scan_over_layers.py @@ -17,7 +17,7 @@ import jax import jax.numpy as jnp -from flax.experimental import nnx +from flax import nnx class Block(nnx.Module): diff --git a/flax/experimental/nnx/examples/toy_examples/08_save_load_checkpoints.py b/flax/nnx/examples/toy_examples/08_save_load_checkpoints.py similarity index 98% rename from flax/experimental/nnx/examples/toy_examples/08_save_load_checkpoints.py rename to flax/nnx/examples/toy_examples/08_save_load_checkpoints.py index 281a290f1..ea6907964 100644 --- a/flax/experimental/nnx/examples/toy_examples/08_save_load_checkpoints.py +++ b/flax/nnx/examples/toy_examples/08_save_load_checkpoints.py @@ -18,7 +18,7 @@ import jax.numpy as jnp import orbax.checkpoint as orbax -from flax.experimental import nnx +from flax import nnx class MLP(nnx.Module): diff --git a/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py b/flax/nnx/examples/toy_examples/09_parameter_surgery.py similarity index 98% rename from flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py rename to flax/nnx/examples/toy_examples/09_parameter_surgery.py index c7f5dd07f..11a785aaa 100644 --- a/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py +++ b/flax/nnx/examples/toy_examples/09_parameter_surgery.py @@ -15,7 +15,7 @@ import jax -from flax.experimental import nnx +from flax import nnx # lets pretend this function loads a pretrained model from a checkpoint diff --git a/flax/experimental/nnx/examples/toy_examples/requirements.txt b/flax/nnx/examples/toy_examples/requirements.txt similarity index 100% rename from flax/experimental/nnx/examples/toy_examples/requirements.txt rename to flax/nnx/examples/toy_examples/requirements.txt diff --git a/flax/experimental/nnx/nnx/__init__.py b/flax/nnx/nnx/__init__.py similarity index 100% rename from flax/experimental/nnx/nnx/__init__.py rename to flax/nnx/nnx/__init__.py diff --git a/flax/experimental/nnx/nnx/compat/__init__.py b/flax/nnx/nnx/compat/__init__.py similarity index 100% rename from flax/experimental/nnx/nnx/compat/__init__.py rename to flax/nnx/nnx/compat/__init__.py diff --git a/flax/experimental/nnx/nnx/compat/module.py b/flax/nnx/nnx/compat/module.py similarity index 94% rename from flax/experimental/nnx/nnx/compat/module.py rename to flax/nnx/nnx/compat/module.py index c152811a1..0af4d38f5 100644 --- a/flax/experimental/nnx/nnx/compat/module.py +++ b/flax/nnx/nnx/compat/module.py @@ -21,13 +21,13 @@ import typing as tp import typing_extensions as tpe -from flax.experimental.nnx.nnx import graph, rnglib -import flax.experimental.nnx.nnx.module as nnx_module -from flax.experimental.nnx.nnx.proxy_caller import ( +from flax.nnx.nnx import graph, rnglib +import flax.nnx.nnx.module as nnx_module +from flax.nnx.nnx.proxy_caller import ( CallableProxy, DelayedAccessor, ) -from flax.experimental.nnx.nnx.object import Object +from flax.nnx.nnx.object import Object M = tp.TypeVar('M', bound='Module') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) @@ -141,8 +141,8 @@ def init(self: M) -> M: Example:: - >>> from flax.experimental import nnx - >>> from flax.experimental.nnx import compat as nnc + >>> from flax import nnx + >>> from flax.nnx import compat as nnc >>> import jax >>> import jax.numpy as jnp ... 
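Several of the renamed toy examples above (notably `01_functional_api.py`) lean on the Functional API that `graph.rst` documents. A minimal sketch of the split/merge round trip, assuming `nnx.split` returns a `(GraphDef, State)` pair as the `graph` module suggests:

```python
# Sketch: split an NNX object into pytree-compatible pieces, run a plain JAX
# transform over the pure function, and rebuild the object inside the trace.
from flax import nnx
import jax
import jax.numpy as jnp

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)    # static structure + variable state

@jax.jit                              # vanilla jax.jit now suffices
def forward(state, x):
  model = nnx.merge(graphdef, state)  # reconstruct the object
  return model(x)

y = forward(state, jnp.ones((1, 2)))
```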
diff --git a/flax/experimental/nnx/nnx/compat/wrappers.py b/flax/nnx/nnx/compat/wrappers.py similarity index 90% rename from flax/experimental/nnx/nnx/compat/wrappers.py rename to flax/nnx/nnx/compat/wrappers.py index 50a954e65..27c889c41 100644 --- a/flax/experimental/nnx/nnx/compat/wrappers.py +++ b/flax/nnx/nnx/compat/wrappers.py @@ -16,12 +16,12 @@ import typing as tp from typing import Any -from flax.experimental import nnx +from flax import nnx from flax import linen -from flax.experimental.nnx.nnx import variables as variableslib -from flax.experimental.nnx.nnx.module import GraphDef, Module -from flax.experimental.nnx.nnx.rnglib import Rngs -from flax.experimental.nnx.nnx.state import State +from flax.nnx.nnx import variables as variableslib +from flax.nnx.nnx.module import GraphDef, Module +from flax.nnx.nnx.rnglib import Rngs +from flax.nnx.nnx.state import State M = tp.TypeVar('M', bound=Module) @@ -107,5 +107,4 @@ def __call__( return out -class NNXWrapper(linen.Module): - ... +class NNXWrapper(linen.Module): ... diff --git a/flax/experimental/nnx/nnx/errors.py b/flax/nnx/nnx/errors.py similarity index 100% rename from flax/experimental/nnx/nnx/errors.py rename to flax/nnx/nnx/errors.py diff --git a/flax/experimental/nnx/nnx/filterlib.py b/flax/nnx/nnx/filterlib.py similarity index 100% rename from flax/experimental/nnx/nnx/filterlib.py rename to flax/nnx/nnx/filterlib.py diff --git a/flax/experimental/nnx/nnx/graph.py b/flax/nnx/nnx/graph.py similarity index 99% rename from flax/experimental/nnx/nnx/graph.py rename to flax/nnx/nnx/graph.py index 957418513..b1efb090a 100644 --- a/flax/experimental/nnx/nnx/graph.py +++ b/flax/nnx/nnx/graph.py @@ -26,22 +26,22 @@ import numpy as np import typing_extensions as tpe -from flax.experimental.nnx.nnx import ( +from flax.nnx.nnx import ( filterlib, reprlib, ) -from flax.experimental.nnx.nnx.proxy_caller import ( +from flax.nnx.nnx.proxy_caller import ( ApplyCaller, CallableProxy, DelayedAccessor, ) -from flax.experimental.nnx.nnx.state import ( +from flax.nnx.nnx.state import ( FlatState, State, StateLeaf, is_state_leaf, ) -from flax.experimental.nnx.nnx.variables import Variable, VariableState +from flax.nnx.nnx.variables import Variable, VariableState from flax.typing import Key, PathParts A = tp.TypeVar('A') @@ -69,6 +69,7 @@ NodeLeaf = tp.Union[Variable[tp.Any], np.ndarray, jax.Array] + @dataclasses.dataclass class GraphContext(threading.local): update_context_stacks: defaultdict[str, list[UpdateContext]] = ( @@ -831,6 +832,12 @@ def _graph_update_static( node_impl.set_key(node, name, value_updates) + +# -------------------------------------------------------- +# UpdateContext +# -------------------------------------------------------- + + # -------------------------------------------------------- # UpdateContext # -------------------------------------------------------- @@ -987,6 +994,7 @@ def merge( jax.tree_util.register_static(UpdateContext) + @dataclasses.dataclass class UpdateContextManager: tag: str @@ -1054,7 +1062,7 @@ def update_context(tag: str): Here is a simple example showing the use of ``update_context``:: - >>> from flax.experimental import nnx + >>> from flax import nnx ... >>> m1 = nnx.Dict({}) >>> with nnx.update_context('example') as ctx: @@ -1078,7 +1086,7 @@ def update_context(tag: str): current active context. current_update_context can be used as a way of accessing the current active context without having to pass it as a capture:: - >>> from flax.experimental import nnx + >>> from flax import nnx ... 
>>> m1 = nnx.Dict({}) >>> @jax.jit @@ -1378,7 +1386,7 @@ def iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]: root. Repeated nodes are visited only once. Leaves include static values. Example:: - >>> from flax.experimental import nnx + >>> from flax import nnx >>> import jax.numpy as jnp ... >>> class Linear(nnx.Module): diff --git a/flax/experimental/nnx/nnx/helpers.py b/flax/nnx/nnx/helpers.py similarity index 95% rename from flax/experimental/nnx/nnx/helpers.py rename to flax/nnx/nnx/helpers.py index 90901c1f8..5667e38df 100644 --- a/flax/experimental/nnx/nnx/helpers.py +++ b/flax/nnx/nnx/helpers.py @@ -34,11 +34,11 @@ import jax.numpy as jnp import optax -from flax.experimental.nnx.nnx.graph import Key -from flax.experimental.nnx.nnx.module import GraphDef, Module -from flax.experimental.nnx.nnx.proxy_caller import ApplyCaller -from flax.experimental.nnx.nnx.rnglib import Rngs -from flax.experimental.nnx.nnx.state import State +from flax.nnx.nnx.graph import Key +from flax.nnx.nnx.module import GraphDef, Module +from flax.nnx.nnx.proxy_caller import ApplyCaller +from flax.nnx.nnx.rnglib import Rngs +from flax.nnx.nnx.state import State from flax.training.train_state import struct A = tp.TypeVar('A') diff --git a/flax/experimental/nnx/nnx/ids.py b/flax/nnx/nnx/ids.py similarity index 100% rename from flax/experimental/nnx/nnx/ids.py rename to flax/nnx/nnx/ids.py diff --git a/flax/experimental/nnx/nnx/module.py b/flax/nnx/nnx/module.py similarity index 95% rename from flax/experimental/nnx/nnx/module.py rename to flax/nnx/nnx/module.py index 1cb578d83..6f99558e7 100644 --- a/flax/experimental/nnx/nnx/module.py +++ b/flax/nnx/nnx/module.py @@ -19,14 +19,14 @@ import jax.tree_util as jtu -from flax.experimental.nnx.nnx import ( +from flax.nnx.nnx import ( filterlib, graph, ) -from flax.experimental.nnx.nnx import variables as variableslib -from flax.experimental.nnx.nnx.graph import GraphDef -from flax.experimental.nnx.nnx.object import Object, ObjectMeta -from flax.experimental.nnx.nnx.state import State, StateLeaf +from flax.nnx.nnx import variables as variableslib +from flax.nnx.nnx.graph import GraphDef +from flax.nnx.nnx.object import Object, ObjectMeta +from flax.nnx.nnx.state import State, StateLeaf from flax.typing import Path, PathParts A = tp.TypeVar('A') @@ -83,7 +83,7 @@ def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]: Example:: - >>> from flax.experimental import nnx + >>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): @@ -116,7 +116,7 @@ def set_attributes( Example:: - >>> from flax.experimental import nnx + >>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): @@ -174,7 +174,7 @@ def train(self, **attributes): Example:: - >>> from flax.experimental import nnx + >>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): @@ -210,7 +210,7 @@ def eval(self, **attributes): Example:: - >>> from flax.experimental import nnx + >>> from flax import nnx ... >>> class Block(nnx.Module): ... 
def __init__(self, din, dout, *, rngs: nnx.Rngs): diff --git a/flax/experimental/nnx/nnx/nn/__init__.py b/flax/nnx/nnx/nn/__init__.py similarity index 100% rename from flax/experimental/nnx/nnx/nn/__init__.py rename to flax/nnx/nnx/nn/__init__.py diff --git a/flax/experimental/nnx/nnx/nn/activations.py b/flax/nnx/nnx/nn/activations.py similarity index 100% rename from flax/experimental/nnx/nnx/nn/activations.py rename to flax/nnx/nnx/nn/activations.py diff --git a/flax/experimental/nnx/nnx/nn/attention.py b/flax/nnx/nnx/nn/attention.py similarity index 98% rename from flax/experimental/nnx/nnx/nn/attention.py rename to flax/nnx/nnx/nn/attention.py index 8e66567c6..d66400bc3 100644 --- a/flax/experimental/nnx/nnx/nn/attention.py +++ b/flax/nnx/nnx/nn/attention.py @@ -23,16 +23,16 @@ import jax.numpy as jnp from jax import lax, random -from flax.experimental import nnx -from flax.experimental.nnx.nnx import rnglib -from flax.experimental.nnx.nnx.module import Module, first_from -from flax.experimental.nnx.nnx.nn import initializers -from flax.experimental.nnx.nnx.nn.dtypes import promote_dtype -from flax.experimental.nnx.nnx.nn.linear import ( +from flax import nnx +from flax.nnx.nnx import rnglib +from flax.nnx.nnx.module import Module, first_from +from flax.nnx.nnx.nn import initializers +from flax.nnx.nnx.nn.dtypes import promote_dtype +from flax.nnx.nnx.nn.linear import ( LinearGeneral, default_kernel_init, ) -from flax.experimental.nnx.nnx.nn.normalization import LayerNorm +from flax.nnx.nnx.nn.normalization import LayerNorm from flax.typing import ( Dtype, Shape, @@ -40,6 +40,7 @@ PrecisionLike, DotGeneralT, ) + Array = jax.Array @@ -590,7 +591,7 @@ def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32): Example usage:: - >>> from flax.experimental import nnx + >>> from flax import nnx >>> import jax.numpy as jnp ... >>> rngs = nnx.Rngs(42) diff --git a/flax/experimental/nnx/nnx/nn/dtypes.py b/flax/nnx/nnx/nn/dtypes.py similarity index 100% rename from flax/experimental/nnx/nnx/nn/dtypes.py rename to flax/nnx/nnx/nn/dtypes.py diff --git a/flax/experimental/nnx/nnx/nn/initializers.py b/flax/nnx/nnx/nn/initializers.py similarity index 95% rename from flax/experimental/nnx/nnx/nn/initializers.py rename to flax/nnx/nnx/nn/initializers.py index ce7371824..2a44d7147 100644 --- a/flax/experimental/nnx/nnx/nn/initializers.py +++ b/flax/nnx/nnx/nn/initializers.py @@ -42,7 +42,7 @@ def zeros_init() -> Initializer: """Builds an initializer that returns a constant array full of zeros. >>> import jax, jax.numpy as jnp - >>> from flax.experimental.nnx import initializers + >>> from flax.nnx import initializers >>> zeros_initializer = initializers.zeros_init() >>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[0., 0., 0.], @@ -55,7 +55,7 @@ def ones_init() -> Initializer: """Builds an initializer that returns a constant array full of ones. 
>>> import jax, jax.numpy as jnp - >>> from flax.experimental.nnx import initializers + >>> from flax.nnx import initializers >>> ones_initializer = initializers.ones_init() >>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32) Array([[1., 1.], diff --git a/flax/experimental/nnx/nnx/nn/linear.py b/flax/nnx/nnx/nn/linear.py similarity index 97% rename from flax/experimental/nnx/nnx/nn/linear.py rename to flax/nnx/nnx/nn/linear.py index 0c8f0cd91..696aeac54 100644 --- a/flax/experimental/nnx/nnx/nn/linear.py +++ b/flax/nnx/nnx/nn/linear.py @@ -36,10 +36,10 @@ import opt_einsum from flax.core.frozen_dict import FrozenDict -from flax.experimental import nnx -from flax.experimental.nnx.nnx import rnglib, variables -from flax.experimental.nnx.nnx.module import Module, first_from -from flax.experimental.nnx.nnx.nn import dtypes, initializers +from flax import nnx +from flax.nnx.nnx import rnglib, variables +from flax.nnx.nnx.module import Module, first_from +from flax.nnx.nnx.nn import dtypes, initializers from flax.typing import ( Dtype, Shape, @@ -110,7 +110,7 @@ class LinearGeneral(Module): Example usage:: - >>> from flax.experimental import nnx + >>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> # equivalent to `nnx.Linear(2, 4)` @@ -270,7 +270,7 @@ def __call__(self, inputs: Array) -> Array: contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims)) inputs, kernel, bias = dtypes.promote_dtype( - (inputs, kernel, bias), dtype=self.dtype + (inputs, kernel, bias), dtype=self.dtype ) if self.dot_general_cls is not None: @@ -355,7 +355,7 @@ def __call__(self, inputs: Array) -> Array: bias = self.bias.value inputs, kernel, bias = dtypes.promote_dtype( - (inputs, kernel, bias), dtype=self.dtype + (inputs, kernel, bias), dtype=self.dtype ) y = self.dot_general( inputs, @@ -373,7 +373,7 @@ class Einsum(Module): Example usage:: - >>> from flax.experimental import nnx + >>> from flax import nnx >>> import jax.numpy as jnp ... >>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0)) @@ -463,12 +463,12 @@ def __call__( self._einsum_str_check(einsum_str) inputs, kernel, bias = dtypes.promote_dtype( - ( - inputs, - self.kernel.value, - self.bias.value if self.bias is not None else self.bias, - ), - dtype=self.dtype, + ( + inputs, + self.kernel.value, + self.bias.value if self.bias is not None else self.bias, + ), + dtype=self.dtype, ) y = jnp.einsum(einsum_str, inputs, kernel, precision=self.precision) @@ -706,7 +706,7 @@ def maybe_broadcast( bias = self.bias.value inputs, kernel, bias = dtypes.promote_dtype( - (inputs, kernel, bias), dtype=self.dtype + (inputs, kernel, bias), dtype=self.dtype ) y = self.conv_general_dilated( @@ -730,6 +730,7 @@ def maybe_broadcast( y = jnp.reshape(y, output_shape) return y + class ConvTranspose(Module): # features: int # kernel_size: Union[int, Sequence[int]] @@ -869,7 +870,7 @@ def maybe_broadcast( bias = self.bias.value if self.bias is not None else None inputs, kernel, bias = dtypes.promote_dtype( - (inputs, kernel, bias), dtype=self.dtype + (inputs, kernel, bias), dtype=self.dtype ) y = lax.conv_transpose( @@ -997,7 +998,7 @@ def __call__(self, inputs: Array) -> Array: # Use take because fancy indexing numpy arrays with JAX indices does not # work correctly. 
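# Note (a sketch of the intent here, not text from the original source):
# promote_dtype below casts the embedding table to the layer's computation
# dtype; inexact=False keeps an integer embedding table from being promoted
# to a floating dtype, unlike the matmul-style layers above.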
(embedding,) = dtypes.promote_dtype( - (self.embedding.value,), dtype=self.dtype, inexact=False + (self.embedding.value,), dtype=self.dtype, inexact=False ) if self.num_embeddings == 1: return jnp.where( @@ -1022,6 +1023,6 @@ def attend(self, query: Array) -> Array: in NLP models. """ query, embedding = dtypes.promote_dtype( - (query, self.embedding.value), dtype=self.dtype + (query, self.embedding.value), dtype=self.dtype ) return jnp.dot(query, embedding.T) diff --git a/flax/experimental/nnx/nnx/nn/lora.py b/flax/nnx/nnx/nn/lora.py similarity index 89% rename from flax/experimental/nnx/nnx/nn/lora.py rename to flax/nnx/nnx/nn/lora.py index 2ac217efd..96d495db5 100644 --- a/flax/experimental/nnx/nnx/nn/lora.py +++ b/flax/nnx/nnx/nn/lora.py @@ -32,14 +32,11 @@ import jax import jax.numpy as jnp -from flax.experimental import nnx -from flax.experimental.nnx.nnx import rnglib, variables -from flax.experimental.nnx.nnx.module import Module -from flax.experimental.nnx.nnx.nn import initializers -from flax.typing import ( - Dtype, - Initializer, -) +from flax.nnx.nnx import rnglib, variables +from flax.nnx.nnx.module import Module +from flax.nnx.nnx.nn import initializers +from flax.nnx.nnx.nn.linear import Linear +from flax.typing import Dtype, Initializer Array = jax.Array Axis = int @@ -49,7 +46,8 @@ default_kernel_init = initializers.lecun_normal() -class LoRAParam(variables.Variable[A]): pass +class LoRAParam(variables.Variable[A]): + pass class LoRA(Module): @@ -88,13 +86,14 @@ class LoRA(Module): kernel_init: initializer function for the weight matrices. lora_param_type: the type of the LoRA params. """ + def __init__( self, in_features: int, lora_rank: int, out_features: int, *, - base_module: tp.Optional[nnx.Module] = None, + base_module: tp.Optional[Module] = None, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, kernel_init: Initializer = default_kernel_init, @@ -124,7 +123,7 @@ def __call__(self, x: jax.Array): return out -class LoRALinear(nnx.Linear): +class LoRALinear(Linear): """An `nnx.Linear` layer in which the output will be LoRAified. The model state structure will be compatible with that of Linear. @@ -159,6 +158,7 @@ class LoRALinear(nnx.Linear): kernel_init: initializer function for the weight matrices. lora_param_type: the type of the LoRA params. 
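  Example usage (a minimal sketch added for illustration; the seed and
  shapes are arbitrary, and output values are omitted because they depend
  on initialization)::

    >>> from flax import nnx
    >>> import jax.numpy as jnp
    ...
    >>> # the result is the base Linear output plus the low-rank update
    >>> # x @ lora_a @ lora_b
    >>> layer = nnx.LoRALinear(3, 4, lora_rank=2, rngs=nnx.Rngs(0))
    >>> y = layer(jnp.ones((1, 3)))
    >>> y.shape
    (1, 4)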
""" + def __init__( self, in_features: int, @@ -173,18 +173,18 @@ def __init__( **kwargs, ): super().__init__(in_features, out_features, rngs=rngs, **kwargs) - self.lora = LoRA(in_features, lora_rank, out_features, - dtype=lora_dtype, param_dtype=lora_param_dtype, - kernel_init=lora_kernel_init, lora_param_type=lora_param_type, - rngs=rngs) + self.lora = LoRA( + in_features, + lora_rank, + out_features, + dtype=lora_dtype, + param_dtype=lora_param_dtype, + kernel_init=lora_kernel_init, + lora_param_type=lora_param_type, + rngs=rngs, + ) def __call__(self, x: jax.Array): y = super().__call__(x) y += self.lora(x) return y - - - - - - diff --git a/flax/experimental/nnx/nnx/nn/normalization.py b/flax/nnx/nnx/nn/normalization.py similarity index 98% rename from flax/experimental/nnx/nnx/nn/normalization.py rename to flax/nnx/nnx/nn/normalization.py index f27d6b279..c65754fda 100644 --- a/flax/experimental/nnx/nnx/nn/normalization.py +++ b/flax/nnx/nnx/nn/normalization.py @@ -18,10 +18,10 @@ import jax.numpy as jnp from jax import lax -from flax.experimental import nnx -from flax.experimental.nnx.nnx import rnglib -from flax.experimental.nnx.nnx.module import Module, first_from -from flax.experimental.nnx.nnx.nn import dtypes, initializers +from flax import nnx +from flax.nnx.nnx import rnglib +from flax.nnx.nnx.module import Module, first_from +from flax.nnx.nnx.nn import dtypes, initializers from flax.typing import ( Array, Dtype, diff --git a/flax/experimental/nnx/nnx/nn/stochastic.py b/flax/nnx/nnx/nn/stochastic.py similarity index 96% rename from flax/experimental/nnx/nnx/nn/stochastic.py rename to flax/nnx/nnx/nn/stochastic.py index efd8f94f3..a2ee77bc2 100644 --- a/flax/experimental/nnx/nnx/nn/stochastic.py +++ b/flax/nnx/nnx/nn/stochastic.py @@ -34,8 +34,8 @@ import jax.numpy as jnp from jax import lax, random -from flax.experimental.nnx.nnx import rnglib -from flax.experimental.nnx.nnx.module import Module, first_from +from flax.nnx.nnx import rnglib +from flax.nnx.nnx.module import Module, first_from @dataclasses.dataclass diff --git a/flax/experimental/nnx/nnx/object.py b/flax/nnx/nnx/object.py similarity index 97% rename from flax/experimental/nnx/nnx/object.py rename to flax/nnx/nnx/object.py index cd0284cb3..9b2ae9a43 100644 --- a/flax/experimental/nnx/nnx/object.py +++ b/flax/nnx/nnx/object.py @@ -24,13 +24,13 @@ import jax import numpy as np -from flax.experimental.nnx.nnx import ( +from flax.nnx.nnx import ( errors, reprlib, tracers, ) -from flax.experimental.nnx.nnx import graph -from flax.experimental.nnx.nnx.variables import Variable, VariableState +from flax.nnx.nnx import graph +from flax.nnx.nnx.variables import Variable, VariableState from flax.typing import Key G = tp.TypeVar('G', bound='Object') diff --git a/flax/experimental/nnx/nnx/proxy_caller.py b/flax/nnx/nnx/proxy_caller.py similarity index 100% rename from flax/experimental/nnx/nnx/proxy_caller.py rename to flax/nnx/nnx/proxy_caller.py diff --git a/flax/experimental/nnx/nnx/reprlib.py b/flax/nnx/nnx/reprlib.py similarity index 100% rename from flax/experimental/nnx/nnx/reprlib.py rename to flax/nnx/nnx/reprlib.py diff --git a/flax/experimental/nnx/nnx/rnglib.py b/flax/nnx/nnx/rnglib.py similarity index 94% rename from flax/experimental/nnx/nnx/rnglib.py rename to flax/nnx/nnx/rnglib.py index 8554f85f4..2f93457f1 100644 --- a/flax/experimental/nnx/nnx/rnglib.py +++ b/flax/nnx/nnx/rnglib.py @@ -33,12 +33,12 @@ import jax import jax.numpy as jnp -from flax.experimental.nnx.nnx import graph -from 
flax.experimental.nnx.nnx.state import State -from flax.experimental.nnx.nnx.variables import Variable -from flax.experimental.nnx.nnx import filterlib -from flax.experimental.nnx.nnx.filterlib import All -from flax.experimental.nnx.nnx.object import Object +from flax.nnx.nnx import graph +from flax.nnx.nnx.state import State +from flax.nnx.nnx.variables import Variable +from flax.nnx.nnx import filterlib +from flax.nnx.nnx.filterlib import All +from flax.nnx.nnx.object import Object Counts = list[int] AxesValue = tp.Union[int, None] @@ -63,6 +63,7 @@ class RngCount(RngState): class RngKey(RngState): tag: str + class RngKeyBackup(RngState): pass @@ -155,6 +156,7 @@ def __len__(self) -> int: def __contains__(self, name: tp.Any) -> bool: return name in vars(self) + class ForkStates(tp.NamedTuple): split_keys: State split_counts: State diff --git a/flax/experimental/nnx/nnx/spmd.py b/flax/nnx/nnx/spmd.py similarity index 98% rename from flax/experimental/nnx/nnx/spmd.py rename to flax/nnx/nnx/spmd.py index 20c063017..fd7067c0a 100644 --- a/flax/experimental/nnx/nnx/spmd.py +++ b/flax/nnx/nnx/spmd.py @@ -19,8 +19,8 @@ from jax.interpreters import pxla from jax.sharding import Mesh, PartitionSpec -from flax.experimental.nnx.nnx import variables -from flax.experimental.nnx.nnx.state import State +from flax.nnx.nnx import variables +from flax.nnx.nnx.state import State from flax.typing import ( Array, ArrayPytree, # pylint: disable=invalid-name diff --git a/flax/experimental/nnx/nnx/state.py b/flax/nnx/nnx/state.py similarity index 96% rename from flax/experimental/nnx/nnx/state.py rename to flax/nnx/nnx/state.py index dff6fec5d..c9edd5a75 100644 --- a/flax/experimental/nnx/nnx/state.py +++ b/flax/nnx/nnx/state.py @@ -35,8 +35,8 @@ import numpy as np from flax import traverse_util -from flax.experimental.nnx.nnx import filterlib, reprlib -from flax.experimental.nnx.nnx.variables import VariableState +from flax.nnx.nnx import filterlib, reprlib +from flax.nnx.nnx.variables import VariableState from flax.typing import Key, PathParts A = tp.TypeVar('A') @@ -142,8 +142,7 @@ def from_flat_path( return cls(nested_state) @tp.overload - def split(self, first: filterlib.Filter, /) -> 'State': - ... + def split(self, first: filterlib.Filter, /) -> 'State': ... @tp.overload def split( @@ -152,8 +151,7 @@ def split( second: filterlib.Filter, /, *filters: filterlib.Filter, - ) -> tuple['State', ...]: - ... + ) -> tuple['State', ...]: ... def split( self, first: filterlib.Filter, /, *filters: filterlib.Filter @@ -179,8 +177,7 @@ def filter( self, first: filterlib.Filter, /, - ) -> 'State': - ... + ) -> 'State': ... @tp.overload def filter( @@ -189,8 +186,7 @@ def filter( second: filterlib.Filter, /, *filters: filterlib.Filter, - ) -> tuple['State', ...]: - ... + ) -> tuple['State', ...]: ... 
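  # The @tp.overload stubs above only refine static typing: filtering with a
  # single Filter yields one State, while passing two or more yields a tuple
  # of States (mirroring the overloads of split() above).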
def filter( self, @@ -239,6 +235,7 @@ def __sub__(self, other: 'State') -> 'State': return State.from_flat_path(diff) + def _state_flatten_with_keys(x: State): items = sorted(x._mapping.items()) children = tuple((jtu.DictKey(key), value) for key, value in items) diff --git a/flax/experimental/nnx/nnx/tracers.py b/flax/nnx/nnx/tracers.py similarity index 97% rename from flax/experimental/nnx/nnx/tracers.py rename to flax/nnx/nnx/tracers.py index 1e8688f4e..c73e627e5 100644 --- a/flax/experimental/nnx/nnx/tracers.py +++ b/flax/nnx/nnx/tracers.py @@ -20,7 +20,7 @@ import jax.core from jax.core import MainTrace -from flax.experimental.nnx.nnx import reprlib +from flax.nnx.nnx import reprlib @tp.runtime_checkable diff --git a/flax/experimental/nnx/nnx/training/__init__.py b/flax/nnx/nnx/training/__init__.py similarity index 100% rename from flax/experimental/nnx/nnx/training/__init__.py rename to flax/nnx/nnx/training/__init__.py diff --git a/flax/experimental/nnx/nnx/training/metrics.py b/flax/nnx/nnx/training/metrics.py similarity index 88% rename from flax/experimental/nnx/nnx/training/metrics.py rename to flax/nnx/nnx/training/metrics.py index 87ec7831b..41b130fbf 100644 --- a/flax/experimental/nnx/nnx/training/metrics.py +++ b/flax/nnx/nnx/training/metrics.py @@ -28,13 +28,14 @@ from __future__ import annotations import jax, jax.numpy as jnp -from flax.experimental.nnx.nnx.object import Object -from flax.experimental.nnx.nnx.variables import Variable -from flax.experimental.nnx.nnx import filterlib, graph +from flax.nnx.nnx.object import Object +from flax.nnx.nnx.variables import Variable +from flax.nnx.nnx import filterlib, graph import typing as tp -#TODO: add tests and docstrings +# TODO: add tests and docstrings + class MetricState(Variable): """Wrapper class for Metric Variables.""" @@ -45,13 +46,16 @@ class MetricState(Variable): class Metric(Object): def __init__(self): raise NotImplementedError('Must override `__init__()` method.') + def reset(self): raise NotImplementedError('Must override `reset()` method.') def update(self, **kwargs) -> None: raise NotImplementedError('Must override `update()` method.') + def compute(self): raise NotImplementedError('Must override `compute()` method.') + def split(self, *filters: filterlib.Filter): return graph.split(self, *filters) @@ -61,6 +65,7 @@ def __init__(self, argname: str = 'values'): self.argname = argname self.total = MetricState(jnp.array(0, dtype=jnp.float32)) self.count = MetricState(jnp.array(0, dtype=jnp.int32)) + def reset(self): self.total.value = jnp.array(0, dtype=jnp.float32) self.count.value = jnp.array(0, dtype=jnp.int32) @@ -69,19 +74,24 @@ def update(self, **kwargs): if self.argname not in kwargs: raise TypeError(f"Expected keyword argument '{self.argname}'") values: tp.Union[int, float, jax.Array] = kwargs[self.argname] - self.total.value += values if isinstance(values, (int, float)) else values.sum() + self.total.value += ( + values if isinstance(values, (int, float)) else values.sum() + ) self.count.value += 1 if isinstance(values, (int, float)) else values.size + def compute(self): return self.total.value / self.count.value + class Accuracy(Average): def update(self, *, logits: jax.Array, labels: jax.Array, **_): # type: ignore[override] if logits.ndim != labels.ndim + 1 or labels.dtype != jnp.int32: raise ValueError( - f"Expected labels.dtype==jnp.int32 and logits.ndim={logits.ndim}==" - f"labels.ndim+1={labels.ndim + 1}" + f'Expected labels.dtype==jnp.int32 and logits.ndim={logits.ndim}==' + 
f'labels.ndim+1={labels.ndim + 1}' ) - super().update(values=(logits.argmax(axis=-1)==labels)) + super().update(values=(logits.argmax(axis=-1) == labels)) + class MultiMetric(Metric): """MultiMetric class to store multiple metrics and update them in a single call. @@ -89,7 +99,7 @@ class MultiMetric(Metric): Example usage:: >>> import jax, jax.numpy as jnp - >>> from flax.experimental import nnx + >>> from flax import nnx ... >>> logits = jax.random.normal(jax.random.key(0), (5, 2)) >>> labels = jnp.array([1, 1, 0, 1, 0]) @@ -114,19 +124,26 @@ class MultiMetric(Metric): >>> metrics.compute() {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)} """ + def __init__(self, **metrics): # TODO: raise error if a kwarg is passed that is in ('reset', 'update', 'compute'), since these names are reserved for methods self._metric_names = [] for metric_name, metric in metrics.items(): self._metric_names.append(metric_name) vars(self)[metric_name] = metric + def reset(self): for metric_name in self._metric_names: getattr(self, metric_name).reset() + def update(self, **updates): # TODO: should we give the option of updating only some of the metrics and not all? e.g. if for some kwargs==None, don't do update # TODO: should we raise an error if a kwarg is passed into **updates that has no match with any underlying metric? e.g. user typo for metric_name in self._metric_names: getattr(self, metric_name).update(**updates) + def compute(self): - return {f'{metric_name}': getattr(self, metric_name).compute() for metric_name in self._metric_names} \ No newline at end of file + return { + f'{metric_name}': getattr(self, metric_name).compute() + for metric_name in self._metric_names + } \ No newline at end of file diff --git a/flax/experimental/nnx/nnx/training/optimizer.py b/flax/nnx/nnx/training/optimizer.py similarity index 92% rename from flax/experimental/nnx/nnx/training/optimizer.py rename to flax/nnx/nnx/training/optimizer.py index c20fc35d4..00215c006 100644 --- a/flax/experimental/nnx/nnx/training/optimizer.py +++ b/flax/nnx/nnx/training/optimizer.py @@ -30,24 +30,27 @@ import jax.numpy as jnp import optax -from flax.experimental import nnx -from flax.experimental.nnx.nnx import filterlib, graph -from flax.experimental.nnx.nnx.object import Object -from flax.experimental.nnx.nnx.variables import Variable +from flax import nnx +from flax.nnx.nnx import filterlib, graph +from flax.nnx.nnx.object import Object +from flax.nnx.nnx.variables import Variable + +# TODO: add tests and docstrings -#TODO: add tests and docstrings class OptState(Variable): """Wrapper class for Optimizer Variables.""" + pass + class Optimizer(Object): """Simple train state for the common case with a single Optax optimizer. Example usage:: >>> import jax, jax.numpy as jnp - >>> from flax.experimental import nnx + >>> from flax import nnx >>> import optax ... 
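    >>> # Optimizer pairs a module with an optax gradient transformation,
    >>> # tracking the wrapped `model`, the `tx` transform, its `opt_state`,
    >>> # and a `step` counter; `update(grads)` applies the transform and
    >>> # writes the new parameters back into the model.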
>>> class Model(nnx.Module): @@ -134,13 +137,10 @@ def update(self, grads): """ params = nnx.state(self.model, nnx.Param) - updates, new_opt_state = self.tx.update( - grads, self.opt_state, params - ) + updates, new_opt_state = self.tx.update(grads, self.opt_state, params) new_params = optax.apply_updates(params, updates) assert isinstance(new_params, nnx.State) self.step.value += 1 nnx.update(self.model, new_params) self.opt_state = new_opt_state - diff --git a/flax/experimental/nnx/nnx/transforms.py b/flax/nnx/nnx/transforms.py similarity index 96% rename from flax/experimental/nnx/nnx/transforms.py rename to flax/nnx/nnx/transforms.py index 653d3b116..8fc4a8979 100644 --- a/flax/experimental/nnx/nnx/transforms.py +++ b/flax/nnx/nnx/transforms.py @@ -35,19 +35,19 @@ from flax import struct from flax.core.frozen_dict import FrozenDict -from flax.experimental.nnx.nnx import ( +from flax.nnx.nnx import ( filterlib, graph, rnglib, spmd, variables, ) -from flax.experimental.nnx.nnx.module import GraphDef, Module -from flax.experimental.nnx.nnx.proxy_caller import ( +from flax.nnx.nnx.module import GraphDef, Module +from flax.nnx.nnx.proxy_caller import ( CallableProxy, DelayedAccessor, ) -from flax.experimental.nnx.nnx.state import State +from flax.nnx.nnx.state import State from flax.typing import Leaf import jax from jax._src.tree_util import broadcast_prefix @@ -115,6 +115,7 @@ def check_and_call(accessor: DelayedAccessor, *args, **kwargs): UNSPECIFIED = object() + def _default_constrain_state(state: State) -> State: state_spec = spmd.get_partition_spec(state) state = jax.lax.with_sharding_constraint(state, state_spec) @@ -619,7 +620,7 @@ def grad( Example:: - >>> from flax.experimental import nnx + >>> from flax import nnx ... >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 2)) @@ -779,6 +780,7 @@ def _call(self, accessor: DelayedAccessor, *args) -> tp.Any: # scan # ------------------------------- + @dataclasses.dataclass(frozen=True) class FlatDef(tp.Generic[A]): type: type[A] @@ -1333,22 +1335,21 @@ def __post_init__(self): class Remat(tp.Generic[M], LiftedModule[M]): - @staticmethod def constructor( - module_constructor: tp.Callable[..., MA], - prevent_cse: bool = True, - static_argnums: int | tuple[int, ...] = (), - policy: tp.Callable[..., bool] | None = None, + module_constructor: tp.Callable[..., MA], + prevent_cse: bool = True, + static_argnums: int | tuple[int, ...] = (), + policy: tp.Callable[..., bool] | None = None, ) -> tp.Callable[..., 'Remat[MA]']: def create_remat(*args, **kwargs): return Remat( - module_constructor=module_constructor, - module_init_args=args, - module_init_kwargs=kwargs, - prevent_cse=prevent_cse, - static_argnums=static_argnums, - policy=policy, + module_constructor=module_constructor, + module_init_args=args, + module_init_kwargs=kwargs, + prevent_cse=prevent_cse, + static_argnums=static_argnums, + policy=policy, ) return create_remat @@ -1461,38 +1462,37 @@ class VmapOptions: class Vmap(tp.Generic[M], LiftedModule[M]): - @staticmethod def constructor( - module_constructor: tp.Callable[..., MA], - *, - in_axes: int | None | tp.Sequence[tp.Any] = 0, - out_axes: tp.Any = 0, - axis_name: AxisName | None = None, - axis_size: int | None = None, - spmd_axis_name: AxisName | tuple[AxisName, ...] 
| None = None, - # nnx specific - in_axes_kwargs: tp.Any = 0, - state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), - split_rngs: filterlib.Filter = ..., - transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), + module_constructor: tp.Callable[..., MA], + *, + in_axes: int | None | tp.Sequence[tp.Any] = 0, + out_axes: tp.Any = 0, + axis_name: AxisName | None = None, + axis_size: int | None = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), ) -> tp.Callable[..., 'Vmap[MA]']: def _create_vmap(*args, **kwargs): return Vmap( - module_constructor=module_constructor, - in_axes=in_axes, - out_axes=out_axes, - axis_size=axis_size, - axis_name=axis_name, - spmd_axis_name=spmd_axis_name, - # nnx specific - in_axes_kwargs=in_axes_kwargs, - state_axes=state_axes, - split_rngs=split_rngs, - transform_metadata=transform_metadata, - # submodule args - module_init_args=args, - module_init_kwargs=kwargs, + module_constructor=module_constructor, + in_axes=in_axes, + out_axes=out_axes, + axis_size=axis_size, + axis_name=axis_name, + spmd_axis_name=spmd_axis_name, + # nnx specific + in_axes_kwargs=in_axes_kwargs, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + # submodule args + module_init_args=args, + module_init_kwargs=kwargs, ) return _create_vmap @@ -1577,7 +1577,7 @@ def vmap_apply( # split module state filters = (*options.state_axes.keys(), ...) graphdef, rng_state, *vectorized_states, broadcast_state = ctx.split( # type: ignore[misc] - input_graph_nodes, rnglib.RngState, *filters + input_graph_nodes, rnglib.RngState, *filters ) # infer length @@ -1682,14 +1682,14 @@ def vmap_fn( # split module state ( - graphdef_out, - rng_state_out, - *vectorized_states_out, - broadcast_state_out, + graphdef_out, + rng_state_out, + *vectorized_states_out, + broadcast_state_out, ) = ctx.split( # type: ignore[misc] - (input_graph_nodes, output_graph_nodes), - rnglib.RngState, - *filters, + (input_graph_nodes, output_graph_nodes), + rnglib.RngState, + *filters, ) not_keys_out, split_keys_out, broadcast_keys_out = rng_state_out.split( @@ -1804,6 +1804,7 @@ def _eval_shape_fn(state: State, *args, **kwargs): out = graph.insert_graph_nodes(out, output_nodes) return out + # ------------------------------- # cond # ------------------------------- @@ -1814,6 +1815,7 @@ class CondStaticInputs(tp.Generic[A]): true_fun: tp.Callable[..., A] false_fun: tp.Callable[..., A] + jax.tree_util.register_static(CondStaticInputs) @@ -1867,4 +1869,4 @@ def cond( **kwargs, ) _operands_out, out = ctx.merge(graphdef_out, state_out) - return out \ No newline at end of file + return out diff --git a/flax/experimental/nnx/nnx/variables.py b/flax/nnx/nnx/variables.py similarity index 98% rename from flax/experimental/nnx/nnx/variables.py rename to flax/nnx/nnx/variables.py index bef44f5dd..dafe286c7 100644 --- a/flax/experimental/nnx/nnx/variables.py +++ b/flax/nnx/nnx/variables.py @@ -34,8 +34,8 @@ import typing as tp from typing import Any -from flax.experimental import nnx -from flax.experimental.nnx.nnx import reprlib, tracers +from flax import nnx +from flax.nnx.nnx import reprlib, tracers import jax.tree_util as jtu A = tp.TypeVar('A') @@ -76,6 +76,7 @@ def __hash__(self): class _Missing: pass + MISSING = _Missing() @@ -224,8 +225,7 
@@ def __init__( if tp.TYPE_CHECKING: - def __getattr__(self, name: str) -> tp.Any: - ... + def __getattr__(self, name: str) -> tp.Any: ... else: def __setattr__(self, name: str, value: Any) -> None: @@ -304,12 +304,10 @@ def __eq__(self, other: object) -> bool: return type(self) is type(other) and vars(other) == vars(self) @tp.overload - def replace(self, value: B, **kwargs) -> 'Variable[B]': - ... + def replace(self, value: B, **kwargs) -> 'Variable[B]': ... @tp.overload - def replace(self, **kwargs) -> 'Variable[A]': - ... + def replace(self, **kwargs) -> 'Variable[A]': ... def replace(self, value: tp.Any = MISSING, **kwargs) -> 'Variable[tp.Any]': if value is not MISSING: @@ -571,12 +569,11 @@ class Intermediate(Variable[A]): class VariableState(tp.Generic[A], reprlib.Representable): - def __init__( - self, - type: type[Variable[tp.Any]], - value: A, - **metadata, + self, + type: type[Variable[tp.Any]], + value: A, + **metadata, ): self.type = type self.value = value @@ -584,8 +581,7 @@ def __init__( if tp.TYPE_CHECKING: - def __getattr__(self, name: str) -> tp.Any: - ... + def __getattr__(self, name: str) -> tp.Any: ... def __nnx_repr__(self): yield reprlib.Object(type=type(self)) diff --git a/flax/experimental/nnx/nnx/visualization.py b/flax/nnx/nnx/visualization.py similarity index 99% rename from flax/experimental/nnx/nnx/visualization.py rename to flax/nnx/nnx/visualization.py index 0f657363c..03317b3e2 100644 --- a/flax/experimental/nnx/nnx/visualization.py +++ b/flax/nnx/nnx/visualization.py @@ -18,7 +18,7 @@ import jax -from flax.experimental import nnx +from flax import nnx penzai_installed = importlib.util.find_spec('penzai') is not None try: diff --git a/flax/experimental/nnx/scripts/requirements.txt b/flax/nnx/scripts/requirements.txt similarity index 100% rename from flax/experimental/nnx/scripts/requirements.txt rename to flax/nnx/scripts/requirements.txt diff --git a/flax/experimental/nnx/scripts/run-all-examples.bash b/flax/nnx/scripts/run-all-examples.bash similarity index 91% rename from flax/experimental/nnx/scripts/run-all-examples.bash rename to flax/nnx/scripts/run-all-examples.bash index 523fa3cf4..570e9c98e 100644 --- a/flax/experimental/nnx/scripts/run-all-examples.bash +++ b/flax/nnx/scripts/run-all-examples.bash @@ -2,7 +2,7 @@ set -e cd ../../.. 
source .venv/bin/activate -cd flax/experimental/nnx +cd flax/nnx for f in $(find examples/toy_examples -name "*.py" -maxdepth 1); do echo -e "\n---------------------------------" diff --git a/flax/experimental/nnx/tests/__init__.py b/flax/nnx/tests/__init__.py similarity index 100% rename from flax/experimental/nnx/tests/__init__.py rename to flax/nnx/tests/__init__.py diff --git a/flax/experimental/nnx/tests/compat/test_module.py b/flax/nnx/tests/compat/test_module.py similarity index 97% rename from flax/experimental/nnx/tests/compat/test_module.py rename to flax/nnx/tests/compat/test_module.py index 70bd403c5..df7603351 100644 --- a/flax/experimental/nnx/tests/compat/test_module.py +++ b/flax/nnx/tests/compat/test_module.py @@ -17,8 +17,8 @@ import jax import jax.numpy as jnp -from flax.experimental import nnx -from flax.experimental.nnx import compat +from flax import nnx +from flax.nnx import compat class TestCompatModule: diff --git a/flax/experimental/nnx/tests/compat/test_wrappers.py b/flax/nnx/tests/compat/test_wrappers.py similarity index 93% rename from flax/experimental/nnx/tests/compat/test_wrappers.py rename to flax/nnx/tests/compat/test_wrappers.py index 1b5cd2bf7..64f8c7743 100644 --- a/flax/experimental/nnx/tests/compat/test_wrappers.py +++ b/flax/nnx/tests/compat/test_wrappers.py @@ -15,8 +15,8 @@ import jax from flax import linen -from flax.experimental import nnx -from flax.experimental.nnx import compat +from flax import nnx +from flax.nnx import compat class TestCompatibility: diff --git a/flax/experimental/nnx/tests/nn/test_attention.py b/flax/nnx/tests/nn/test_attention.py similarity index 99% rename from flax/experimental/nnx/tests/nn/test_attention.py rename to flax/nnx/tests/nn/test_attention.py index 489c786a5..9c45264d9 100644 --- a/flax/experimental/nnx/tests/nn/test_attention.py +++ b/flax/nnx/tests/nn/test_attention.py @@ -16,7 +16,7 @@ from jax.lax import Precision from flax import linen -from flax.experimental import nnx +from flax import nnx from flax.typing import Dtype, PrecisionLike from numpy.testing import assert_array_equal diff --git a/flax/experimental/nnx/tests/nn/test_conv.py b/flax/nnx/tests/nn/test_conv.py similarity index 99% rename from flax/experimental/nnx/tests/nn/test_conv.py rename to flax/nnx/tests/nn/test_conv.py index f6a773905..41a3a8044 100644 --- a/flax/experimental/nnx/tests/nn/test_conv.py +++ b/flax/nnx/tests/nn/test_conv.py @@ -22,7 +22,7 @@ from numpy.testing import assert_array_equal from flax import linen -from flax.experimental import nnx +from flax import nnx from flax.typing import PaddingLike, Dtype, PrecisionLike diff --git a/flax/experimental/nnx/tests/nn/test_embed.py b/flax/nnx/tests/nn/test_embed.py similarity index 98% rename from flax/experimental/nnx/tests/nn/test_embed.py rename to flax/nnx/tests/nn/test_embed.py index bed5ab1a8..faababe00 100644 --- a/flax/experimental/nnx/tests/nn/test_embed.py +++ b/flax/nnx/tests/nn/test_embed.py @@ -20,7 +20,7 @@ from numpy.testing import assert_array_equal from flax import linen -from flax.experimental import nnx +from flax import nnx from flax.typing import Dtype diff --git a/flax/experimental/nnx/tests/nn/test_linear.py b/flax/nnx/tests/nn/test_linear.py similarity index 99% rename from flax/experimental/nnx/tests/nn/test_linear.py rename to flax/nnx/tests/nn/test_linear.py index 944f03b97..aa55eb642 100644 --- a/flax/experimental/nnx/tests/nn/test_linear.py +++ b/flax/nnx/tests/nn/test_linear.py @@ -21,7 +21,7 @@ from numpy.testing import assert_array_equal from flax 
import linen -from flax.experimental import nnx +from flax import nnx from flax.typing import Dtype, PrecisionLike, Shape diff --git a/flax/experimental/nnx/tests/nn/test_lora.py b/flax/nnx/tests/nn/test_lora.py similarity index 94% rename from flax/experimental/nnx/tests/nn/test_lora.py rename to flax/nnx/tests/nn/test_lora.py index d619f1e81..b58db0245 100644 --- a/flax/experimental/nnx/tests/nn/test_lora.py +++ b/flax/nnx/tests/nn/test_lora.py @@ -17,7 +17,7 @@ from absl.testing import absltest import numpy as np -from flax.experimental import nnx +from flax import nnx class TestLora(absltest.TestCase): @@ -31,7 +31,6 @@ def test_basic(self): assert module.lora_b.value.shape == (2, 4) np.testing.assert_allclose(y, x @ module.lora_a.value @ module.lora_b.value) - def test_lora_base_module(self): rngs = nnx.Rngs(0) linear = nnx.Linear(3, 4, use_bias=False, rngs=rngs) @@ -45,14 +44,16 @@ def test_lora_base_module(self): assert module.base_module.bias.value == None assert module.lora_a.value.shape == (3, 2) assert module.lora_b.value.shape == (2, 4) - np.testing.assert_allclose(y, x @ linear.kernel.value + x @ module.lora_a.value @ module.lora_b.value) - + np.testing.assert_allclose( + y, x @ linear.kernel.value + x @ module.lora_a.value @ module.lora_b.value + ) def test_layer_swap_lora(self): class MLP(nnx.Module): def __init__(self, dim, rngs: nnx.Rngs): self.linear1 = nnx.Linear(dim, dim, rngs=rngs) self.linear2 = nnx.Linear(dim, dim, rngs=rngs) + def __call__(self, x): x = self.linear1(x) return self.linear2(x) @@ -72,12 +73,12 @@ def __call__(self, x): a, b = model.linear2.lora_a.value, model.linear2.lora_b.value np.testing.assert_allclose(y + model.linear1(x) @ a @ b, lora_y) - def test_layer_swap_loralinear(self): class MLP(nnx.Module): def __init__(self, dim, rngs: nnx.Rngs): self.linear1 = nnx.Linear(dim, dim, rngs=rngs) self.linear2 = nnx.Linear(dim, dim, rngs=rngs) + def __call__(self, x): x = self.linear1(x) return self.linear2(x) @@ -88,7 +89,9 @@ def __call__(self, x): y = model(x) # Replace one of the linear layers with a LoRA linear layer.
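# (nnx.split captures linear2's existing kernel and bias state so that
# nnx.update below can graft it into the replacement LoRALinear, keeping
# the swapped layer output-comparable with the original.)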
- _, state = nnx.split(model.linear2) # To keep the kernel and bias of linear2 + _, state = nnx.split( + model.linear2 + ) # To keep the kernel and bias of linear2 model.linear2 = nnx.LoRALinear(3, 3, lora_rank=4, rngs=rngs) nnx.update(model.linear2, state) lora_y = model(x) @@ -99,7 +102,6 @@ def __call__(self, x): a, b = model.linear2.lora.lora_a.value, model.linear2.lora.lora_b.value np.testing.assert_allclose(y + model.linear1(x) @ a @ b, lora_y) - def test_lora_param_type(self): rngs = nnx.Rngs(0) model = nnx.LoRA(3, 4, 2, lora_param_type=nnx.LoRAParam, rngs=rngs) @@ -117,4 +119,3 @@ def test_lora_param_type(self): if __name__ == '__main__': absltest.main() - diff --git a/flax/experimental/nnx/tests/nn/test_normalization.py b/flax/nnx/tests/nn/test_normalization.py similarity index 99% rename from flax/experimental/nnx/tests/nn/test_normalization.py rename to flax/nnx/tests/nn/test_normalization.py index 854c367ae..3e30febcf 100644 --- a/flax/experimental/nnx/tests/nn/test_normalization.py +++ b/flax/nnx/tests/nn/test_normalization.py @@ -20,7 +20,7 @@ from numpy.testing import assert_array_equal from flax import linen -from flax.experimental import nnx +from flax import nnx from flax.typing import Dtype diff --git a/flax/experimental/nnx/tests/nn/test_stochastic.py b/flax/nnx/tests/nn/test_stochastic.py similarity index 98% rename from flax/experimental/nnx/tests/nn/test_stochastic.py rename to flax/nnx/tests/nn/test_stochastic.py index f302a34f1..1ba6944ae 100644 --- a/flax/experimental/nnx/tests/nn/test_stochastic.py +++ b/flax/nnx/tests/nn/test_stochastic.py @@ -16,7 +16,7 @@ import jax.numpy as jnp import numpy as np -from flax.experimental import nnx +from flax import nnx import pytest diff --git a/flax/experimental/nnx/tests/test_containers.py b/flax/nnx/tests/test_containers.py similarity index 97% rename from flax/experimental/nnx/tests/test_containers.py rename to flax/nnx/tests/test_containers.py index 582d661ab..4757d494e 100644 --- a/flax/experimental/nnx/tests/test_containers.py +++ b/flax/nnx/tests/test_containers.py @@ -13,7 +13,7 @@ # limitations under the License. 
-from flax.experimental import nnx +from flax import nnx class TestContainers: diff --git a/flax/experimental/nnx/tests/test_graph_utils.py b/flax/nnx/tests/test_graph_utils.py similarity index 99% rename from flax/experimental/nnx/tests/test_graph_utils.py rename to flax/nnx/tests/test_graph_utils.py index 64a07a193..52ebcba75 100644 --- a/flax/experimental/nnx/tests/test_graph_utils.py +++ b/flax/nnx/tests/test_graph_utils.py @@ -17,7 +17,7 @@ import jax import pytest -from flax.experimental import nnx +from flax import nnx from flax import struct diff --git a/flax/experimental/nnx/tests/test_helpers.py b/flax/nnx/tests/test_helpers.py similarity index 90% rename from flax/experimental/nnx/tests/test_helpers.py rename to flax/nnx/tests/test_helpers.py index 4e84f3b30..8a7cec4db 100644 --- a/flax/experimental/nnx/tests/test_helpers.py +++ b/flax/nnx/tests/test_helpers.py @@ -19,7 +19,8 @@ from numpy.testing import assert_array_equal from flax import linen -from flax.experimental import nnx +from flax import nnx + class TrainState(nnx.TrainState): batch_stats: nnx.State @@ -76,13 +77,17 @@ def test_nnx_linen_sequential_equivalence(self): rngs = nnx.Rngs(0) x = jax.random.uniform(key1, (3, 1, 5)) - model_nnx = nnx.Sequential(nnx.Linear(5, 4, rngs=rngs), nnx.Linear(4, 2, rngs=rngs)) + model_nnx = nnx.Sequential( + nnx.Linear(5, 4, rngs=rngs), nnx.Linear(4, 2, rngs=rngs) + ) model = linen.Sequential([linen.Dense(4), linen.Dense(2)]) variables = model.init(key2, x) for layer_index in range(2): for param in ('kernel', 'bias'): - variables['params'][f'layers_{layer_index}'][param] = getattr(model_nnx.layers[layer_index], param).value + variables['params'][f'layers_{layer_index}'][param] = getattr( + model_nnx.layers[layer_index], param + ).value out_nnx = model_nnx(x) out = model.apply(variables, x) assert_array_equal(out, out_nnx) @@ -90,7 +95,9 @@ def test_nnx_linen_sequential_equivalence(self): variables = model.init(key2, x) for layer_index in range(2): for param in ('kernel', 'bias'): - getattr(model_nnx.layers[layer_index], param).value = variables['params'][f'layers_{layer_index}'][param] + getattr(model_nnx.layers[layer_index], param).value = variables[ + 'params' + ][f'layers_{layer_index}'][param] out_nnx = model_nnx(x) out = model.apply(variables, x) - assert_array_equal(out, out_nnx) \ No newline at end of file + assert_array_equal(out, out_nnx) diff --git a/flax/experimental/nnx/tests/test_ids.py b/flax/nnx/tests/test_ids.py similarity index 95% rename from flax/experimental/nnx/tests/test_ids.py rename to flax/nnx/tests/test_ids.py index 9460e6724..d72490c83 100644 --- a/flax/experimental/nnx/tests/test_ids.py +++ b/flax/nnx/tests/test_ids.py @@ -14,7 +14,7 @@ import copy -from flax.experimental.nnx.nnx import ids +from flax.nnx.nnx import ids class TestIds: diff --git a/flax/experimental/nnx/tests/test_integration.py b/flax/nnx/tests/test_integration.py similarity index 99% rename from flax/experimental/nnx/tests/test_integration.py rename to flax/nnx/tests/test_integration.py index 49b58af2a..c473562b8 100644 --- a/flax/experimental/nnx/tests/test_integration.py +++ b/flax/nnx/tests/test_integration.py @@ -18,7 +18,7 @@ import jax.numpy as jnp import numpy as np -from flax.experimental import nnx +from flax import nnx A = tp.TypeVar('A') diff --git a/flax/experimental/nnx/tests/test_metrics.py b/flax/nnx/tests/test_metrics.py similarity index 96% rename from flax/experimental/nnx/tests/test_metrics.py rename to flax/nnx/tests/test_metrics.py index 2e0188ee7..9a84cceb9 100644 --- 
a/flax/experimental/nnx/tests/test_metrics.py +++ b/flax/nnx/tests/test_metrics.py @@ -15,7 +15,7 @@ import jax import jax.numpy as jnp -from flax.experimental import nnx +from flax import nnx from absl.testing import parameterized @@ -58,7 +58,7 @@ def test_multimetric(self): metrics.update(logits=logits2, labels=labels2, values=batch_loss2) values = metrics.compute() self.assertEqual(values['accuracy'], 0.7) - self.assertEqual(values['loss'], 2.) + self.assertEqual(values['loss'], 2.0) metrics.reset() values = metrics.compute() diff --git a/flax/experimental/nnx/tests/test_module.py b/flax/nnx/tests/test_module.py similarity index 99% rename from flax/experimental/nnx/tests/test_module.py rename to flax/nnx/tests/test_module.py index 1d4724c6f..1590d4934 100644 --- a/flax/experimental/nnx/tests/test_module.py +++ b/flax/nnx/tests/test_module.py @@ -21,15 +21,14 @@ import numpy as np import pytest -from flax.experimental import nnx +from flax import nnx A = TypeVar('A') class TestModule: def test_has_module_state(self): - class Foo(nnx.Module): - ... + class Foo(nnx.Module): ... foo = Foo() @@ -475,6 +474,7 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs): raise_if_not_found=False, ) + class TestModulePytree: def test_tree_map(self): class Foo(nnx.Module, experimental_pytree=True): diff --git a/flax/experimental/nnx/tests/test_optimizer.py b/flax/nnx/tests/test_optimizer.py similarity index 94% rename from flax/experimental/nnx/tests/test_optimizer.py rename to flax/nnx/tests/test_optimizer.py index d1de7cc55..a7e0310f1 100644 --- a/flax/experimental/nnx/tests/test_optimizer.py +++ b/flax/nnx/tests/test_optimizer.py @@ -17,7 +17,7 @@ import numpy as np import optax -from flax.experimental import nnx +from flax import nnx from absl.testing import parameterized @@ -26,6 +26,7 @@ class Model(nnx.Module): def __init__(self, in_features, out_features, rngs): self.linear1 = nnx.Linear(in_features, 3, rngs=rngs) self.linear2 = nnx.Linear(3, out_features, rngs=rngs) + def __call__(self, x): return self.linear2(self.linear1(x)) @@ -54,7 +55,9 @@ def test_jit(self, module_cls, jit_decorator, optimizer): x = jax.random.normal(jax.random.key(0), (1, 2)) y = jnp.ones((1, 4)) model = module_cls(2, 4, rngs=nnx.Rngs(0)) - tx = optimizer(1e-3) # TODO: this doesn't work with adam optimizer for some reason + tx = optimizer( + 1e-3 + ) # TODO: this doesn't work with adam optimizer for some reason state = nnx.Optimizer(model, tx) if jit_decorator == jax.jit: @@ -76,7 +79,7 @@ def jax_jit_train_step(graphdef, state, x, y): new_loss = loss_fn(*nnx.split(state.model), x, y) else: - loss_fn = lambda model, x, y: ((model(x)-y)**2).mean() + loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() initial_loss = loss_fn(state.model, x, y) def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y): @@ -109,7 +112,7 @@ def update(self, *, grads, **updates): # type: ignore[signature-mismatch] metrics = nnx.metrics.Average() state = TrainState(model, tx, metrics) - loss_fn = lambda model: ((model(x)-y)**2).mean() + loss_fn = lambda model: ((model(x) - y) ** 2).mean() grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model) state.update(grads=grads, values=loss_fn(state.model)) initial_loss = state.metrics.compute() diff --git a/flax/experimental/nnx/tests/test_partitioning.py b/flax/nnx/tests/test_partitioning.py similarity index 99% rename from flax/experimental/nnx/tests/test_partitioning.py rename to flax/nnx/tests/test_partitioning.py index b2d5fdfdc..e390887ae 100644 --- 
a/flax/experimental/nnx/tests/test_partitioning.py +++ b/flax/nnx/tests/test_partitioning.py @@ -16,7 +16,7 @@ import jax import pytest -from flax.experimental import nnx +from flax import nnx class TestPartitioning: diff --git a/flax/experimental/nnx/tests/test_rngs.py b/flax/nnx/tests/test_rngs.py similarity index 99% rename from flax/experimental/nnx/tests/test_rngs.py rename to flax/nnx/tests/test_rngs.py index 23c505c30..918a1be1e 100644 --- a/flax/experimental/nnx/tests/test_rngs.py +++ b/flax/nnx/tests/test_rngs.py @@ -19,7 +19,7 @@ import jax.numpy as jnp import pytest -from flax.experimental import nnx +from flax import nnx class TestRngs: @@ -51,7 +51,6 @@ def test_rng_stream(self): assert rngs.params.key.value is key0 assert not jnp.allclose(key1, key2) - def test_rng_trace_level_constraints(self): rngs = nnx.Rngs(0) diff --git a/flax/experimental/nnx/tests/test_spmd.py b/flax/nnx/tests/test_spmd.py similarity index 98% rename from flax/experimental/nnx/tests/test_spmd.py rename to flax/nnx/tests/test_spmd.py index 83e2f1b10..0353bfc53 100644 --- a/flax/experimental/nnx/tests/test_spmd.py +++ b/flax/nnx/tests/test_spmd.py @@ -19,7 +19,7 @@ from jax.experimental import mesh_utils from jax.sharding import Mesh, PartitionSpec -from flax.experimental import nnx +from flax import nnx class TestSPMD: diff --git a/flax/experimental/nnx/tests/test_state.py b/flax/nnx/tests/test_state.py similarity index 98% rename from flax/experimental/nnx/tests/test_state.py rename to flax/nnx/tests/test_state.py index be8a8a178..e1884134f 100644 --- a/flax/experimental/nnx/tests/test_state.py +++ b/flax/nnx/tests/test_state.py @@ -14,7 +14,7 @@ from absl.testing import absltest -from flax.experimental import nnx +from flax import nnx class StateTest(absltest.TestCase): diff --git a/flax/experimental/nnx/tests/test_transforms.py b/flax/nnx/tests/test_transforms.py similarity index 99% rename from flax/experimental/nnx/tests/test_transforms.py rename to flax/nnx/tests/test_transforms.py index 18721d941..1d9c0b707 100644 --- a/flax/experimental/nnx/tests/test_transforms.py +++ b/flax/nnx/tests/test_transforms.py @@ -22,7 +22,7 @@ import pytest from jax.experimental import mesh_utils -from flax.experimental import nnx +from flax import nnx class TestJIT: @@ -345,7 +345,6 @@ def constrain_object(m): m.kernel.value.sharding - class TestGrad: def test_grad(self): p1 = nnx.Param(10.0) @@ -1221,6 +1220,7 @@ def __call__(self, x: jax.Array) -> jax.Array: assert module.vmap_module.graphdef == 'hello' + class TestCond: def test_basic(self): class TimeStep(tp.NamedTuple): diff --git a/flax/experimental/nnx/tests/test_variable.py b/flax/nnx/tests/test_variable.py similarity index 98% rename from flax/experimental/nnx/tests/test_variable.py rename to flax/nnx/tests/test_variable.py index af297eeae..de5c5c52c 100644 --- a/flax/experimental/nnx/tests/test_variable.py +++ b/flax/nnx/tests/test_variable.py @@ -17,7 +17,7 @@ import jax import jax.numpy as jnp -from flax.experimental import nnx +from flax import nnx A = tp.TypeVar('A') diff --git a/pyproject.toml b/pyproject.toml index 9b3992306..f10838f0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,7 @@ ignore_missing_imports = true disable_error_code = "annotation-unchecked" # exclude nnx examples [[tool.mypy.overrides]] -module = "flax.experimental.nnx.examples.*" +module = "flax.nnx.examples.*" ignore_errors = true [tool.pytest.ini_options] diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index 8106f2c5c..a130c26ca 100755 --- 
a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -85,7 +85,7 @@ if $RUN_DOCTEST; then pytest -n auto flax \ --doctest-modules \ --suppress-no-test-exit-code \ - --ignore=flax/experimental/nnx/examples + --ignore=flax/nnx/examples fi # check that flax is running on editable mode @@ -112,7 +112,7 @@ if $RUN_PYTEST; then echo "pytest -n auto tests $PYTEST_OPTS $PYTEST_IGNORE" pytest -n auto tests $PYTEST_OPTS $PYTEST_IGNORE # Run nnx tests - pytest -n auto flax/experimental/nnx/tests $PYTEST_OPTS $PYTEST_IGNORE + pytest -n auto flax/nnx/tests $PYTEST_OPTS $PYTEST_IGNORE pytest -n auto docs/_ext/codediff_test.py $PYTEST_OPTS $PYTEST_IGNORE # Per-example tests. @@ -128,7 +128,7 @@ if $RUN_PYTEST; then pytest $egd done - for egd in $(find flax/experimental/nnx/examples -maxdepth 1 -mindepth 1 -type d); do + for egd in $(find flax/nnx/examples -maxdepth 1 -mindepth 1 -type d); do # skip if folder starts with "_" or is "toy_examples" if [[ $egd == *"_"* ]] || [[ $egd == *"toy_examples"* ]]; then continue @@ -140,7 +140,7 @@ fi if $RUN_PYTYPE; then echo "=== RUNNING PYTYPE ===" # Validate types in NNX examples. - for egd in $(find flax/experimental/nnx/examples -maxdepth 1 -mindepth 1 -type d); do + for egd in $(find flax/nnx/examples -maxdepth 1 -mindepth 1 -type d); do # skip if folder starts with "_" or is "toy_examples" if [[ $egd == *"_"* ]] || [[ $egd == *"toy_examples"* ]]; then continue @@ -148,11 +148,11 @@ if $RUN_PYTYPE; then # use cd to make sure pytype cache lives in example dir and doesn't name clash # use *.py to avoid importing configs as a top-level import which leads to import errors # because config files use relative imports (e.g. from config import ...). - (cd $egd ; pytype "*.py" --jobs auto --config ../../../../../pyproject.toml) + (cd $egd ; pytype "*.py" --jobs auto --config ../../../../pyproject.toml) done # Validate types in library code. pytype --jobs auto --config pyproject.toml flax/ \ - --exclude flax/experimental/nnx/examples + --exclude flax/nnx/examples # Validate types in examples. for egd in $(find examples -maxdepth 1 -mindepth 1 -type d); do