From 4e17aa10c10e3a97e86322d91ccc46825393c29f Mon Sep 17 00:00:00 2001
From: Cristian Garcia
Date: Sat, 24 Aug 2024 10:38:49 +0100
Subject: [PATCH] [nnx] improve landing page

---
 docs/nnx/index.rst                     | 15 ++++++++++-----
 docs/nnx/nnx_basics.ipynb              | 17 +++++++----------
 docs/nnx/nnx_basics.md                 | 17 +++++++----------
 flax/nnx/__init__.py                   |  2 +-
 flax/nnx/nnx/rnglib.py                 |  5 ++++-
 flax/nnx/nnx/transforms/iteration.py   |  1 -
 flax/nnx/tests/bridge/wrappers_test.py |  1 -
 flax/nnx/tests/graph_utils_test.py     |  4 ++--
 8 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/docs/nnx/index.rst b/docs/nnx/index.rst
index d85a5a15b..1b6067b32 100644
--- a/docs/nnx/index.rst
+++ b/docs/nnx/index.rst
@@ -1,13 +1,18 @@
 NNX
 ========
+.. div:: sd-text-left sd-font-italic
 
+   **N**\ eural **N**\ etworks for JA\ **X**
 
-NNX is a **N**\ eural **N**\ etwork library for JA\ **X** that focuses on providing the best
-development experience, so building and experimenting with neural networks is easy and
-intuitive. It achieves this by embracing Python’s object-oriented model and making it
-compatible with JAX transforms, resulting in code that is easy to inspect, debug, and
-analyze.
+
+----
+
+NNX is a new Flax API that is designed to make it easier to create, inspect, debug,
+and analyze neural networks in JAX. It achieves this by adding first class support
+for Python reference semantics, allowing users to express their models using regular
+Python objects. NNX takes years of feedback from Linen and brings to Flax a simpler
+and more user-friendly experience.
 
 Features
 ^^^^^^^^^
diff --git a/docs/nnx/nnx_basics.ipynb b/docs/nnx/nnx_basics.ipynb
index d0f438d2a..a44337586 100644
--- a/docs/nnx/nnx_basics.ipynb
+++ b/docs/nnx/nnx_basics.ipynb
@@ -6,15 +6,12 @@
    "source": [
     "# NNX Basics\n",
     "\n",
-    "NNX is a **N**eural **N**etwork library for JA**X** that focuses on providing the best \n",
-    "development experience, so building and experimenting with neural networks is easy and\n",
-    "intuitive. It achieves this by representing objects as PyGraphs (instead of PyTrees), \n",
-    "enabling reference sharing and mutability. This design allows your models to resemble \n",
-    "familiar Python object-oriented code, particularly appealing to users of frameworks\n",
-    "like PyTorch.\n",
-    "\n",
-    "Despite its simplified implementation, NNX supports the same powerful design patterns \n",
-    "that have allowed Linen to scale effectively to large codebases."
+    "NNX is a new Flax API that is designed to make it easier to create, inspect, debug,\n",
+    "and analyze neural networks in JAX. It achieves this by adding first class support\n",
+    "for Python reference semantics, allowing users to express their models using regular\n",
+    "Python objects, which are modeled as PyGraphs (instead of PyTrees), enabling reference\n",
+    "sharing and mutability. This design should make PyTorch or Keras users feel at\n",
+    "home."
    ]
   },
   {
@@ -68,7 +65,7 @@
     }
    ],
    "source": [
-    "! pip install -U flax treescope"
+    "# ! pip install -U flax treescope"
    ]
   },
   {
diff --git a/docs/nnx/nnx_basics.md b/docs/nnx/nnx_basics.md
index 7cdddaf69..70b3ff754 100644
--- a/docs/nnx/nnx_basics.md
+++ b/docs/nnx/nnx_basics.md
@@ -10,20 +10,17 @@ jupytext:
 
 # NNX Basics
 
-NNX is a **N**eural **N**etwork library for JA**X** that focuses on providing the best
-development experience, so building and experimenting with neural networks is easy and
-intuitive. It achieves this by representing objects as PyGraphs (instead of PyTrees),
-enabling reference sharing and mutability. This design allows your models to resemble
-familiar Python object-oriented code, particularly appealing to users of frameworks
-like PyTorch.
-
-Despite its simplified implementation, NNX supports the same powerful design patterns
-that have allowed Linen to scale effectively to large codebases.
+NNX is a new Flax API that is designed to make it easier to create, inspect, debug,
+and analyze neural networks in JAX. It achieves this by adding first class support
+for Python reference semantics, allowing users to express their models using regular
+Python objects, which are modeled as PyGraphs (instead of PyTrees), enabling reference
+sharing and mutability. This design should make PyTorch or Keras users feel at
+home.
 
 ```{code-cell} ipython3
 :tags: [skip-execution]
 
-! pip install -U flax treescope
+# ! pip install -U flax treescope
 ```
 
 ```{code-cell} ipython3
diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py
index d0a76e52f..5cf1c1667 100644
--- a/flax/nnx/__init__.py
+++ b/flax/nnx/__init__.py
@@ -125,7 +125,7 @@ from .nnx.variables import (
   Param as Param,
   register_variable_name_type_pair as register_variable_name_type_pair,
-)
+) # this needs to be imported before optimizer to prevent circular import
 from .nnx.training import optimizer as optimizer
 from .nnx.training.metrics import Metric as Metric
diff --git a/flax/nnx/nnx/rnglib.py b/flax/nnx/nnx/rnglib.py
index c1e19da2f..8013c69ae 100644
--- a/flax/nnx/nnx/rnglib.py
+++ b/flax/nnx/nnx/rnglib.py
@@ -437,7 +437,10 @@ def split_rngs_wrapper(*args, **kwargs):
       key = stream()
       backups.append((stream, stream.key.value, stream.count.value))
       stream.key.value = jax.random.split(key, splits)
-      counts_shape = (splits, *stream.count.shape)
+      if isinstance(splits, int):
+        counts_shape = (splits, *stream.count.shape)
+      else:
+        counts_shape = (*splits, *stream.count.shape)
       stream.count.value = jnp.zeros(counts_shape, dtype=jnp.uint32)
 
   return SplitBackups(backups)
diff --git a/flax/nnx/nnx/transforms/iteration.py b/flax/nnx/nnx/transforms/iteration.py
index a8193b3de..bf7650d4f 100644
--- a/flax/nnx/nnx/transforms/iteration.py
+++ b/flax/nnx/nnx/transforms/iteration.py
@@ -40,7 +40,6 @@
 from flax.nnx.nnx.transforms.transforms import resolve_kwargs
 from flax.typing import Leaf, MISSING, Missing, PytreeDeque
 import jax
-from jax._src.tree_util import broadcast_prefix
 import jax.core
 import jax.numpy as jnp
 import jax.stages
diff --git a/flax/nnx/tests/bridge/wrappers_test.py b/flax/nnx/tests/bridge/wrappers_test.py
index eb3dfaceb..d28d84ec1 100644
--- a/flax/nnx/tests/bridge/wrappers_test.py
+++ b/flax/nnx/tests/bridge/wrappers_test.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from functools import partial
 from absl.testing import absltest
 
 import flax
diff --git a/flax/nnx/tests/graph_utils_test.py b/flax/nnx/tests/graph_utils_test.py
index 0273eb959..9917e2645 100644
--- a/flax/nnx/tests/graph_utils_test.py
+++ b/flax/nnx/tests/graph_utils_test.py
@@ -499,8 +499,8 @@ def __init__(self, dout: int, rngs: nnx.Rngs):
       self.rngs = rngs
 
     def __call__(self, x):
-
-      @partial(nnx.vmap, in_axes=(0, None), axis_size=5)
+      @nnx.split_rngs(splits=5)
+      @nnx.vmap(in_axes=(0, None), axis_size=5)
       def vmap_fn(inner, x):
         return inner(x)
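
Note (not part of the patch): the new landing-page text describes "Python reference semantics" as the core design idea. Below is a hypothetical sketch of what that means in practice; nnx.Linear, nnx.Rngs, and the .kernel parameter attribute are assumptions about the flax.nnx public API and do not appear in this diff.

  import jax.numpy as jnp
  from flax import nnx

  model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))  # assumed flax.nnx API, not taken from this diff
  alias = model                               # a plain Python reference; nothing is copied
  assert alias is model                       # reference sharing: two names, one object
  print(model.kernel.value.shape)             # (2, 3): parameters are ordinary attributes
  model.kernel.value = jnp.zeros((2, 3))      # mutability: the model can be edited in place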
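
Note (not part of the patch): the rnglib.py hunk lets split_rngs accept a tuple of ints for splits as well as a single int, so the count array can take an arbitrary split shape. A hypothetical sketch of the underlying shape behavior, assuming a recent JAX version where jax.random.split accepts a tuple for num and new-style typed keys from jax.random.key:

  import jax

  key = jax.random.key(0)                     # new-style typed PRNG key (assumed available)
  print(jax.random.split(key, 5).shape)       # (5,)   -> counts_shape == (5, *count.shape)
  print(jax.random.split(key, (2, 3)).shape)  # (2, 3) -> counts_shape == (2, 3, *count.shape)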
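
Note (not part of the patch): a hypothetical sketch of the split_rngs + vmap decorator pattern that graph_utils_test.py switches to. Only nnx.split_rngs, nnx.vmap, and nnx.Rngs appear in the diff; nnx.Linear and the expected output shape are assumptions about the flax.nnx API.

  import jax.numpy as jnp
  from flax import nnx

  @nnx.split_rngs(splits=5)                  # fork each RNG stream in the arguments into 5 keys
  @nnx.vmap(in_axes=(0, None), axis_size=5)  # map over the split keys, broadcast x
  def create_and_apply(rngs: nnx.Rngs, x):
    # each mapped instance sees its own key, so the five layers initialize differently
    layer = nnx.Linear(2, 3, rngs=rngs)      # assumed flax.nnx API
    return layer(x)

  y = create_and_apply(nnx.Rngs(0), jnp.ones((2,)))
  print(y.shape)  # expected (5, 3)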