Merge pull request #4141 from google:nnx-landing-page
PiperOrigin-RevId: 668020773
Flax Authors committed Aug 27, 2024
2 parents 3a9d833 + 4e17aa1 commit a0622b0
Showing 8 changed files with 31 additions and 31 deletions.
15 changes: 10 additions & 5 deletions docs/nnx/index.rst
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@

NNX
========
.. div:: sd-text-left sd-font-italic

**N**\ eural **N**\ etworks for JA\ **X**

NNX is a **N**\ eural **N**\ etwork library for JA\ **X** that focuses on providing the best
development experience, so building and experimenting with neural networks is easy and
intuitive. It achieves this by embracing Python’s object-oriented model and making it
compatible with JAX transforms, resulting in code that is easy to inspect, debug, and
analyze.

----

NNX is a new Flax API that is designed to make it easier to create, inspect, debug,
and analyze neural networks in JAX. It achieves this by adding first-class support
for Python reference semantics, allowing users to express their models using regular
Python objects. NNX takes years of feedback from Linen and brings to Flax a simpler
and more user-friendly experience.

Features
^^^^^^^^^
17 changes: 7 additions & 10 deletions docs/nnx/nnx_basics.ipynb
@@ -6,15 +6,12 @@
"source": [
"# NNX Basics\n",
"\n",
"NNX is a **N**eural **N**etwork library for JA**X** that focuses on providing the best \n",
"development experience, so building and experimenting with neural networks is easy and\n",
"intuitive. It achieves this by representing objects as PyGraphs (instead of PyTrees), \n",
"enabling reference sharing and mutability. This design allows your models to resemble \n",
"familiar Python object-oriented code, particularly appealing to users of frameworks\n",
"like PyTorch.\n",
"\n",
"Despite its simplified implementation, NNX supports the same powerful design patterns \n",
"that have allowed Linen to scale effectively to large codebases."
"NNX is a new Flax API that is designed to make it easier to create, inspect, debug,\n",
"and analyze neural networks in JAX. It achieves this by adding first-class support\n",
"for Python reference semantics, allowing users to express their models using regular\n",
"Python objects, which are modeled as PyGraphs (instead of PyTrees), enabling reference\n",
"sharing and mutability. This design should make PyTorch or Keras users feel at\n",
"home."
]
},
{
@@ -68,7 +65,7 @@
}
],
"source": [
"! pip install -U flax treescope"
"# ! pip install -U flax treescope"
]
},
{
17 changes: 7 additions & 10 deletions docs/nnx/nnx_basics.md
@@ -10,20 +10,17 @@ jupytext:

# NNX Basics

NNX is a **N**eural **N**etwork library for JA**X** that focuses on providing the best
development experience, so building and experimenting with neural networks is easy and
intuitive. It achieves this by representing objects as PyGraphs (instead of PyTrees),
enabling reference sharing and mutability. This design allows your models to resemble
familiar Python object-oriented code, particularly appealing to users of frameworks
like PyTorch.

Despite its simplified implementation, NNX supports the same powerful design patterns
that have allowed Linen to scale effectively to large codebases.
NNX is a new Flax API that is designed to make it easier to create, inspect, debug,
and analyze neural networks in JAX. It achieves this by adding first-class support
for Python reference semantics, allowing users to express their models using regular
Python objects, which are modeled as PyGraphs (instead of PyTrees), enabling reference
sharing and mutability. This design should make PyTorch or Keras users feel at
home.
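The reference sharing and mutability described above can be illustrated in plain Python, independent of any library. This is only a conceptual sketch of the "PyGraph, not PyTree" idea: the `Block` and `Model` class names are made up for this example and are not part of the Flax NNX API.

```python
# Plain-Python illustration of reference semantics: two attributes can
# point at the same object (a graph with shared references), and an
# in-place mutation is visible through both paths. A tree structure
# (PyTree) would instead hold two independent copies.

class Block:
    def __init__(self):
        self.scale = 1.0

class Model:
    def __init__(self):
        shared = Block()
        self.encoder = shared  # both attributes reference
        self.decoder = shared  # the same Block instance

m = Model()
m.encoder.scale = 2.0          # mutate through one reference...
print(m.decoder.scale)         # ...observe it through the other: 2.0
print(m.encoder is m.decoder)  # True: shared reference, not a copy
```

NNX models such shared-reference object graphs directly, which is why its objects behave like the familiar PyTorch- or Keras-style modules mentioned above.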

```{code-cell} ipython3
:tags: [skip-execution]
! pip install -U flax treescope
# ! pip install -U flax treescope
```

```{code-cell} ipython3
2 changes: 1 addition & 1 deletion flax/nnx/__init__.py
@@ -125,7 +125,7 @@
from .nnx.variables import (
Param as Param,
register_variable_name_type_pair as register_variable_name_type_pair,
)
)
# this needs to be imported before optimizer to prevent circular import
from .nnx.training import optimizer as optimizer
from .nnx.training.metrics import Metric as Metric
5 changes: 4 additions & 1 deletion flax/nnx/nnx/rnglib.py
@@ -437,7 +437,10 @@ def split_rngs_wrapper(*args, **kwargs):
key = stream()
backups.append((stream, stream.key.value, stream.count.value))
stream.key.value = jax.random.split(key, splits)
counts_shape = (splits, *stream.count.shape)
if isinstance(splits, int):
counts_shape = (splits, *stream.count.shape)
else:
counts_shape = (*splits, *stream.count.shape)
stream.count.value = jnp.zeros(counts_shape, dtype=jnp.uint32)

return SplitBackups(backups)
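The `rnglib.py` hunk above generalizes the counts array shape: `splits` may be an int (one new leading axis) or a tuple of ints (one new axis per entry). The shape logic can be sketched in isolation; `counts_shape` here is a stand-in helper mirroring the diff, not a function in the real `nnx` module.

```python
# Sketch of the counts_shape logic from the diff: an int `splits`
# prepends a single axis, while a tuple `splits` prepends one axis
# per entry, so the counts array lines up with the split keys.

def counts_shape(splits, count_shape):
    """Shape of the per-stream counts array after splitting."""
    if isinstance(splits, int):
        return (splits, *count_shape)  # e.g. 5      -> (5, ...)
    return (*splits, *count_shape)     # e.g. (2, 3) -> (2, 3, ...)

print(counts_shape(5, ()))       # (5,)
print(counts_shape((2, 3), ()))  # (2, 3)
```

Before this fix, a tuple `splits` would have produced a nested shape like `((2, 3), ...)`, which `jnp.zeros` rejects.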
1 change: 0 additions & 1 deletion flax/nnx/nnx/transforms/iteration.py
@@ -40,7 +40,6 @@
from flax.nnx.nnx.transforms.transforms import resolve_kwargs
from flax.typing import Leaf, MISSING, Missing, PytreeDeque
import jax
from jax._src.tree_util import broadcast_prefix
import jax.core
import jax.numpy as jnp
import jax.stages
1 change: 0 additions & 1 deletion flax/nnx/tests/bridge/wrappers_test.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

from absl.testing import absltest
import flax
4 changes: 2 additions & 2 deletions flax/nnx/tests/graph_utils_test.py
@@ -499,8 +499,8 @@ def __init__(self, dout: int, rngs: nnx.Rngs):
self.rngs = rngs

def __call__(self, x):

@partial(nnx.vmap, in_axes=(0, None), axis_size=5)
@nnx.split_rngs(splits=5)
@nnx.vmap(in_axes=(0, None), axis_size=5)
def vmap_fn(inner, x):
return inner(x)

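The test hunk above replaces `@partial(nnx.vmap, ...)` with a stack of `@nnx.split_rngs(splits=5)` over `@nnx.vmap(...)`, so the RNGs are split before the vmapped function runs. The ordering matters because Python applies decorators bottom-up: the one listed first wraps the result of the one below it. A plain-Python stand-in (the decorators here are hypothetical tracers, not the real nnx transforms) shows the call order:

```python
# Decorator stacking order: split_rngs_like is listed first, so it runs
# first and wraps the already-vmapped function, mirroring how
# nnx.split_rngs must sit outside nnx.vmap in the diff above.

calls = []

def split_rngs_like(fn):  # stand-in for nnx.split_rngs(splits=5)
    def wrapper(*args):
        calls.append("split_rngs")
        return fn(*args)
    return wrapper

def vmap_like(fn):        # stand-in for nnx.vmap(in_axes=(0, None))
    def wrapper(*args):
        calls.append("vmap")
        return fn(*args)
    return wrapper

@split_rngs_like
@vmap_like
def vmap_fn(x):
    calls.append("body")
    return x

vmap_fn(1)
print(calls)  # ['split_rngs', 'vmap', 'body']
```

With the outer decorator running first, the keys are already split by the time the mapped body executes, which is the behavior the updated test exercises.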
