Skip to content

Commit

Permalink
Merge pull request #2172 from levskaya:condfix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 453935471
  • Loading branch information
Flax Authors committed Jun 9, 2022
2 parents 25f2920 + 294bb41 commit 19fc095
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 9 deletions.
3 changes: 2 additions & 1 deletion flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,8 @@ def _get_unbound_fn(method_or_fn: Callable[..., Any]) -> Callable[..., Any]:
Returns:
An unbound version of input function.
"""
if inspect.ismethod(method_or_fn):
if (inspect.ismethod(method_or_fn) and
isinstance(method_or_fn.__self__, Module)): # pytype: disable=attribute-error
method_or_fn = method_or_fn.__func__ # pytype: disable=attribute-error

# The method should be callable, and it should have at least one argument
Expand Down
19 changes: 11 additions & 8 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from flax.linen.module import Module
from flax.linen.module import Variable
from flax.linen.module import wrap_method_once
from flax.linen.module import _get_unbound_fn
import jax

traceback_util.register_exclusion(__file__)
Expand Down Expand Up @@ -419,6 +420,8 @@ def lift_direct_transform(transform: Callable[..., Any],
' That is function that takes a Module instance as its first arg.')
elif not callable(target):
raise ValueError('transform target must be callable')
# normalize self.foo bound methods to class.foo unbound methods.
targets = tuple(_get_unbound_fn(target) for target in targets)
aug_transform = lambda *fns: functools.partial(transform, *fns)
return decorator_lift_transform(
aug_transform, targets, multi_scope=multi_scope)(mdl, *args, **kwargs)
Expand Down Expand Up @@ -1068,10 +1071,10 @@ def _cond_wrapper(t_fn, f_fn, scope, pred, *ops, variables, rngs):


def cond(
pred: Any,
pred: Any,
true_fun: Callable[..., C], false_fun: Callable[..., C],
mdl: Module, *operands,
variables: lift.CollectionFilter = True,
variables: lift.CollectionFilter = True,
rngs: lift.PRNGSequenceFilter = True) -> C:
"""Lifted version of ``jax.lax.cond``.
Expand All @@ -1082,7 +1085,7 @@ def cond(
Note that this constraint is violated when
creating variables or submodules in only one branch.
Because initializing variables in just one branch
causes the paramater structure to be different.
causes the parameter structure to be different.
Example::
Expand All @@ -1098,15 +1101,15 @@ def false_fn(mdl, x):
mdl.variable('state', 'false_count').value += 1
return -nn.Dense(2, name='dense')(x)
return nn.cond(pred, true_fn, false_fn, self, x)
Args:
pred: determines if true_fun or false_fun is evaluated.
true_fun: The function evalauted when ``pred`` is `True`.
The signature is (Scope, *operands) -> T.
The signature is (module, *operands) -> T.
false_fun: The function evalauted when ``pred`` is `False`.
The signature is (Scope, *operands) -> T.
scope: A Scope or Pytree of scopes to pass
The signature is (module, *operands) -> T.
mdl: A Module target to pass.
*operands: The arguments passed to ``true_fun`` and ``false_fun``
variables: The variable collections passed to the conditional
branches (default: all)
Expand Down
21 changes: 21 additions & 0 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1451,6 +1451,27 @@ def body_fn(mdl, x):
self.assertEqual(v['params']['Dense_0']['kernel'].shape, (5, 3, 3))
m.apply(v, x)

def test_bound_methods_in_direct_transforms(self):
class CondModel(nn.Module):
def setup(self):
self.dense = nn.Dense(3)

def f1(self, arr):
arr = self.dense(arr)
return arr

def f2(self, arr):
_ = self.dense(arr)
return arr

def __call__(self, x):
return nn.cond(x.sum() > 0, self.f1, self.f2, self, x)

cond_model = CondModel()

output, init_params = jax.jit(cond_model.init_with_output)(
jax.random.PRNGKey(0),
x=jnp.ones(3))


if __name__ == '__main__':
Expand Down

0 comments on commit 19fc095

Please sign in to comment.