From 46ac8625786624423811bc32bd0cb419144f2a1f Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Wed, 16 Oct 2024 22:42:39 +0000 Subject: [PATCH] Update Flax NNX Glossary --- docs_nnx/nnx_glossary.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs_nnx/nnx_glossary.rst b/docs_nnx/nnx_glossary.rst index 1d00fdcc6..864c8a0ad 100644 --- a/docs_nnx/nnx_glossary.rst +++ b/docs_nnx/nnx_glossary.rst @@ -10,7 +10,7 @@ For additional terms, refer to the `JAX glossary ` objects out of a Flax NNX :term:`Module` (``nnx.Module``). This is usually done by calling :meth:`nnx.split ` upon the :class:`nnx.Module`. Refer to the `Filter guide `__ to learn more. Folding in - In Flax, `folding in `__ means generating a new JAX `pseudorandom number generator (PRNG) `__ key, given an input PRNG key and integer. This is typically used when you want to generate a new key but still be able to use the original PRNG key afterwards. You can also do this in JAX with `jax.random.split `__, but this method will effectively create two PRNG keys, which is slower. Learn how Flax generates new PRNG keys automatically in the `Randomness/PRNG guide `__. + In Flax, `folding in `__ means generating a new `JAX pseudorandom number generator (PRNG) `__ key, given an input PRNG key and integer. This is typically used when you want to generate a new key but still be able to use the original PRNG key afterwards. You can also do this in JAX with `jax.random.split `__, but this method will effectively create two PRNG keys, which is slower. Learn how Flax generates new PRNG keys automatically in the `Randomness/PRNG guide `__. GraphDef :class:`nnx.GraphDef` is a class that represents all the static, stateless, and Pythonic parts of a Flax :term:`Module` (:class:`nnx.Module`). @@ -25,19 +25,19 @@ For additional terms, refer to the `JAX glossary ` is a particular subclass of :class:`nnx.Variable ` that generally contains the trainable weights. PRNG states - A Flax :class:`nnx.Module ` can keep a reference of a `pseudorandom number generator (PRNG) `__ state object :class:`nnx.Rngs ` that can generate new JAX `PRNG `__ keys. These keys are used to generate random JAX arrays through `JAX's functional pseudorandom number generators `__. + A Flax :class:`nnx.Module ` can keep a reference of a `pseudorandom number generator (PRNG) `__ state object :class:`nnx.Rngs ` that can generate new `JAX PRNG `__ keys. These keys are used to generate random JAX arrays through `JAX's functional PRNGs `__. You can use a PRNG state with different seeds to add more fine-grained control to your model (for example, to have independent random numbers for parameters and dropout masks). - Refer to the `Flax Randomness/PRNG guide `__ + Refer to the Flax `Randomness/PRNG guide `__ for more details. Split and merge - :meth:`nnx.split ` is a way to represent an :class:`nnx.Module ` by two parts: 1) a static Flax NNX :term:`GraphDef ` that captures its Pythonic static information; and 2) one or more :term:`Variable state(s)` that capture its `JAX arrays `__ (``jax.Array``) in the form of `JAX pytrees `__. They can be merged back to the original ``nnx.Module`` using :meth:`nnx.merge `. + :meth:`nnx.split ` is a way to represent an :class:`nnx.Module ` by two parts: 1) a static Flax NNX :term:`GraphDef ` that captures its Pythonic static information; and 2) one or more :term:`Variable state(s)` that capture its `JAX arrays `__ (``jax.Array``) in the form of `JAX pytrees `__. They can be merged back to the original ``nnx.Module`` using :meth:`nnx.merge `. Transformation A Flax NNX transformation (transform) is a wrapped version of a `JAX transformation `__ that allows the function that is being transformed to take the Flax NNX :term:`Module` (``nnx.Module``) as input or output. For example, a "lifted" version of `jax.jit `__ is :meth:`nnx.jit `. Check out the `Flax NNX transforms guide `__ to learn more. Variable - The weights / parameters / data / array :class:`Variable ` residing in a Flax :term:`Module`. Variables are defined inside modules as :class:`nnx.Variable ` or its subclasses. + The weights / parameters / data / array :class:`nnx.Variable ` residing in a Flax :term:`Module`. Variables are defined inside modules as :class:`nnx.Variable ` or its subclasses. Variable state - :class:`nnx.VariableState ` is a purely functional `JAX pytree `__ of all the :term:`Variables` inside a :term:`Module`. Since it is pure, it can be an input or output of a `JAX transformation `__ function. ``nnx.VariableState`` is obtained by using :meth:`nnx.split ` on the :class:`nnx.Module `. (Refer to :term:`splitting` and :term:`Module` to learn more.) + :class:`nnx.VariableState ` is a purely functional `JAX pytree `__ of all the :term:`Variables` inside a :term:`Module`. Since it is pure, it can be an input or output of a `JAX transformation `__ function. ``nnx.VariableState`` is obtained by using :meth:`nnx.split ` on the :class:`nnx.Module `. (Refer to :term:`splitting` and :term:`Module` to learn more.)