From e8cce0096fe96c8ed0e1e764ec8cb09776549355 Mon Sep 17 00:00:00 2001 From: Niels Horn Date: Thu, 6 Mar 2025 10:07:29 +0100 Subject: [PATCH] Use nnx.filterlib.Filter instead of simple Iterable. --- flax/nnx/nn/normalization.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index 8b747af4..c1a3dc2b 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -874,7 +874,7 @@ class WeightNorm(nnx.Module): ... def __init__(self, rngs: nnx.Rngs): ... self.normed_linear = nnx.WeightNorm( ... nnx.Linear(8, 4, rngs=rngs), - ... variable_filter=('kernel',), + ... variable_filter=nnx.PathContains('kernel'), ... rngs=rngs, ... ) ... @@ -901,7 +901,7 @@ class WeightNorm(nnx.Module): epsilon: The epsilon value for the normalization, by default 1e-12. dtype: The dtype of the result, by default infer from input and params. param_dtype: The dtype of the parameters, by default float32. - variable_filter: The variable filter, by default ``('kernel',)``. + variable_filter: The variable filter, by default ``nnx.PathContains('kernel')``. rngs: The rng key. """ def __init__( @@ -914,7 +914,7 @@ def __init__( epsilon: float = 1e-12, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, - variable_filter: tp.Iterable[str] | None = ('kernel',), + variable_filter: nnx.filterlib.Filter = nnx.PathContains('kernel'), rngs: rnglib.Rngs, ): self.layer_instance = layer_instance @@ -944,14 +944,8 @@ def __call__(self, x: Array, *args, **kwargs) -> Array: state = nnx.state(self.layer_instance) def apply_weightnorm(path, var_state): - path_str = '/'.join(str(k) for k in path) - - if self.variable_filter: - for variable_name in self.variable_filter: - if variable_name in path_str: - break - else: - return var_state + if not self.variable_filter(path, var_state): + return var_state param_val = jnp.asarray(var_state.value) if self.feature_axes is None: