Skip to content

Commit

Permalink
[bridge] improve sow and add initializers
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 729649600
  • Loading branch information
Cristian Garcia authored and Flax Authors committed Feb 28, 2025
1 parent 6af8fcb commit 6f59e6f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 16 deletions.
3 changes: 2 additions & 1 deletion flax/nnx/bridge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@
from .variables import with_partitioning as with_partitioning
from .module import Module as Module
from .module import Scope as Scope
from .module import compact as compact
from .module import compact as compact
from flax.nnx.nn import initializers as initializers
7 changes: 3 additions & 4 deletions flax/nnx/bridge/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,6 @@ def _bind_module(parent: Module, module: Module) -> Module:


class ModuleMeta(nnx_module.ModuleMeta):
if not tp.TYPE_CHECKING:

def __call__(cls, *args, **kwargs):
return _module_meta_call(cls, *args, **kwargs)

def _object_meta_construct(cls, self, *args, **kwargs):
vars(self)['scope'] = None
Expand Down Expand Up @@ -159,6 +155,9 @@ def _module_meta_call(cls: type[M], *args, **kwargs) -> M:

return module # type: ignore

# set ModuleMeta.__call__ like this because pytype doesn't understand
# the use of TYPE_CHECKING conditionals for metaclass methods
ModuleMeta.__call__ = _module_meta_call

class ModuleBase:
if tp.TYPE_CHECKING:
Expand Down
35 changes: 24 additions & 11 deletions flax/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ class Module(Object, metaclass=ModuleMeta):
"""

def sow(
self,
variable_type: tp.Type[variableslib.Variable[tp.Any]],
name: str,
value: A,
reduce_fn: tp.Callable[[B, A], B] = tuple_reduce,
init_fn: tp.Callable[[], B] = tuple_init, # type: ignore
) -> None:
self,
variable_type: type[variableslib.Variable[A]] | str,
name: str,
value: A,
reduce_fn: tp.Callable[[B, A], B] = tuple_reduce,
init_fn: tp.Callable[[], B] = tuple_init, # type: ignore
) -> bool:
"""``sow()`` can be used to collect intermediate values without
the overhead of explicitly passing a container through each Module call.
``sow()`` stores a value in a new ``Module`` attribute, denoted by ``name``.
Expand Down Expand Up @@ -169,6 +169,11 @@ def sow(
of ``init_fn`` together with the value to be stored. The default is an
empty tuple.
"""
if isinstance(variable_type, str):
variable_type = variableslib.variable_type_from_name(
variable_type, allow_register=True
)

if hasattr(self, name):
variable = getattr(self, name)
if not isinstance(variable, variableslib.Variable):
Expand All @@ -185,11 +190,15 @@ def sow(
reduced_value = reduce_fn(init_fn(), value)
setattr(self, name, variable_type(reduced_value))

return True

def perturb(
self,
name: str,
value: tp.Any,
variable_type: tp.Type[variableslib.Variable[tp.Any]] = variableslib.Perturbation,
self,
name: str,
value: tp.Any,
variable_type: (
str | type[variableslib.Variable[tp.Any]]
) = variableslib.Perturbation,
):
"""Add an zero-value variable ("perturbation") to the intermediate value.
Expand Down Expand Up @@ -246,6 +255,10 @@ def perturb(
variable_type: The :class:`Variable` type for the stored perturbation.
Defaulted at :class:`nnx.Perturbation`.
"""
if isinstance(variable_type, str):
variable_type = variableslib.variable_type_from_name(
variable_type, allow_register=True
)
if not hasattr(self, name):
zeros = jax.tree.map(jnp.zeros_like, value)
setattr(self, name, variable_type(zeros))
Expand Down

0 comments on commit 6f59e6f

Please sign in to comment.