Skip to content

Commit 19fc095

Browse files
author
Flax Authors
committed
Merge pull request #2172 from levskaya:condfix
PiperOrigin-RevId: 453935471
2 parents 25f2920 + 294bb41 commit 19fc095

File tree

3 files changed

+34
-9
lines changed

3 files changed

+34
-9
lines changed

flax/linen/module.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,8 @@ def _get_unbound_fn(method_or_fn: Callable[..., Any]) -> Callable[..., Any]:
384384
Returns:
385385
An unbound version of input function.
386386
"""
387-
if inspect.ismethod(method_or_fn):
387+
if (inspect.ismethod(method_or_fn) and
388+
isinstance(method_or_fn.__self__, Module)): # pytype: disable=attribute-error
388389
method_or_fn = method_or_fn.__func__ # pytype: disable=attribute-error
389390

390391
# The method should be callable, and it should have at least one argument

flax/linen/transforms.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from flax.linen.module import Module
3939
from flax.linen.module import Variable
4040
from flax.linen.module import wrap_method_once
41+
from flax.linen.module import _get_unbound_fn
4142
import jax
4243

4344
traceback_util.register_exclusion(__file__)
@@ -419,6 +420,8 @@ def lift_direct_transform(transform: Callable[..., Any],
419420
' That is function that takes a Module instance as its first arg.')
420421
elif not callable(target):
421422
raise ValueError('transform target must be callable')
423+
# normalize self.foo bound methods to class.foo unbound methods.
424+
targets = tuple(_get_unbound_fn(target) for target in targets)
422425
aug_transform = lambda *fns: functools.partial(transform, *fns)
423426
return decorator_lift_transform(
424427
aug_transform, targets, multi_scope=multi_scope)(mdl, *args, **kwargs)
@@ -1068,10 +1071,10 @@ def _cond_wrapper(t_fn, f_fn, scope, pred, *ops, variables, rngs):
10681071

10691072

10701073
def cond(
1071-
pred: Any,
1074+
pred: Any,
10721075
true_fun: Callable[..., C], false_fun: Callable[..., C],
10731076
mdl: Module, *operands,
1074-
variables: lift.CollectionFilter = True,
1077+
variables: lift.CollectionFilter = True,
10751078
rngs: lift.PRNGSequenceFilter = True) -> C:
10761079
"""Lifted version of ``jax.lax.cond``.
10771080
@@ -1082,7 +1085,7 @@ def cond(
10821085
Note that this constraint is violated when
10831086
creating variables or submodules in only one branch.
10841087
Because initializing variables in just one branch
1085-
causes the paramater structure to be different.
1088+
causes the parameter structure to be different.
10861089
10871090
Example::
10881091
@@ -1098,15 +1101,15 @@ def false_fn(mdl, x):
10981101
mdl.variable('state', 'false_count').value += 1
10991102
return -nn.Dense(2, name='dense')(x)
11001103
return nn.cond(pred, true_fn, false_fn, self, x)
1101-
1102-
1104+
1105+
11031106
Args:
11041107
pred: determines if true_fun or false_fun is evaluated.
11051108
true_fun: The function evalauted when ``pred`` is `True`.
1106-
The signature is (Scope, *operands) -> T.
1109+
The signature is (module, *operands) -> T.
11071110
false_fun: The function evalauted when ``pred`` is `False`.
1108-
The signature is (Scope, *operands) -> T.
1109-
scope: A Scope or Pytree of scopes to pass
1111+
The signature is (module, *operands) -> T.
1112+
mdl: A Module target to pass.
11101113
*operands: The arguments passed to ``true_fun`` and ``false_fun``
11111114
variables: The variable collections passed to the conditional
11121115
branches (default: all)

tests/linen/linen_transforms_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,6 +1451,27 @@ def body_fn(mdl, x):
14511451
self.assertEqual(v['params']['Dense_0']['kernel'].shape, (5, 3, 3))
14521452
m.apply(v, x)
14531453

1454+
def test_bound_methods_in_direct_transforms(self):
1455+
class CondModel(nn.Module):
1456+
def setup(self):
1457+
self.dense = nn.Dense(3)
1458+
1459+
def f1(self, arr):
1460+
arr = self.dense(arr)
1461+
return arr
1462+
1463+
def f2(self, arr):
1464+
_ = self.dense(arr)
1465+
return arr
1466+
1467+
def __call__(self, x):
1468+
return nn.cond(x.sum() > 0, self.f1, self.f2, self, x)
1469+
1470+
cond_model = CondModel()
1471+
1472+
output, init_params = jax.jit(cond_model.init_with_output)(
1473+
jax.random.PRNGKey(0),
1474+
x=jnp.ones(3))
14541475

14551476

14561477
if __name__ == '__main__':

0 commit comments

Comments
 (0)