Skip to content

Commit

Permalink
Use nnx.filterlib.Filter instead of simple Iterable.
Browse files Browse the repository at this point in the history
  • Loading branch information
nilq committed Mar 6, 2025
1 parent e4c6d89 commit e8cce00
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions flax/nnx/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ class WeightNorm(nnx.Module):
... def __init__(self, rngs: nnx.Rngs):
... self.normed_linear = nnx.WeightNorm(
... nnx.Linear(8, 4, rngs=rngs),
... variable_filter=('kernel',),
... variable_filter=nnx.PathContains('kernel'),
... rngs=rngs,
... )
...
Expand All @@ -901,7 +901,7 @@ class WeightNorm(nnx.Module):
epsilon: The epsilon value for the normalization, by default 1e-12.
dtype: The dtype of the result, by default infer from input and params.
param_dtype: The dtype of the parameters, by default float32.
variable_filter: The variable filter, by default ``('kernel',)``.
variable_filter: The variable filter, by default ``nnx.PathContains('kernel')``.
rngs: The rng key.
"""
def __init__(
Expand All @@ -914,7 +914,7 @@ def __init__(
epsilon: float = 1e-12,
dtype: tp.Optional[Dtype] = None,
param_dtype: Dtype = jnp.float32,
variable_filter: tp.Iterable[str] | None = ('kernel',),
variable_filter: nnx.filterlib.Filter = nnx.PathContains('kernel'),
rngs: rnglib.Rngs,
):
self.layer_instance = layer_instance
Expand Down Expand Up @@ -944,14 +944,8 @@ def __call__(self, x: Array, *args, **kwargs) -> Array:
state = nnx.state(self.layer_instance)

def apply_weightnorm(path, var_state):
path_str = '/'.join(str(k) for k in path)

if self.variable_filter:
for variable_name in self.variable_filter:
if variable_name in path_str:
break
else:
return var_state
if not self.variable_filter(path, var_state):
return var_state

param_val = jnp.asarray(var_state.value)
if self.feature_axes is None:
Expand Down

0 comments on commit e8cce00

Please sign in to comment.