38
38
from flax .linen .module import Module
39
39
from flax .linen .module import Variable
40
40
from flax .linen .module import wrap_method_once
41
+ from flax .linen .module import _get_unbound_fn
41
42
import jax
42
43
43
44
traceback_util .register_exclusion (__file__ )
@@ -419,6 +420,8 @@ def lift_direct_transform(transform: Callable[..., Any],
419
420
' That is function that takes a Module instance as its first arg.' )
420
421
elif not callable (target ):
421
422
raise ValueError ('transform target must be callable' )
423
+ # normalize self.foo bound methods to class.foo unbound methods.
424
+ targets = tuple (_get_unbound_fn (target ) for target in targets )
422
425
aug_transform = lambda * fns : functools .partial (transform , * fns )
423
426
return decorator_lift_transform (
424
427
aug_transform , targets , multi_scope = multi_scope )(mdl , * args , ** kwargs )
@@ -1068,10 +1071,10 @@ def _cond_wrapper(t_fn, f_fn, scope, pred, *ops, variables, rngs):
1068
1071
1069
1072
1070
1073
def cond (
1071
- pred : Any ,
1074
+ pred : Any ,
1072
1075
true_fun : Callable [..., C ], false_fun : Callable [..., C ],
1073
1076
mdl : Module , * operands ,
1074
- variables : lift .CollectionFilter = True ,
1077
+ variables : lift .CollectionFilter = True ,
1075
1078
rngs : lift .PRNGSequenceFilter = True ) -> C :
1076
1079
"""Lifted version of ``jax.lax.cond``.
1077
1080
@@ -1082,7 +1085,7 @@ def cond(
1082
1085
Note that this constraint is violated when
1083
1086
creating variables or submodules in only one branch.
1084
1087
Because initializing variables in just one branch
1085
- causes the paramater structure to be different.
1088
+ causes the parameter structure to be different.
1086
1089
1087
1090
Example::
1088
1091
@@ -1098,15 +1101,15 @@ def false_fn(mdl, x):
1098
1101
mdl.variable('state', 'false_count').value += 1
1099
1102
return -nn.Dense(2, name='dense')(x)
1100
1103
return nn.cond(pred, true_fn, false_fn, self, x)
1101
-
1102
-
1104
+
1105
+
1103
1106
Args:
1104
1107
pred: determines if true_fun or false_fun is evaluated.
1105
1108
true_fun: The function evalauted when ``pred`` is `True`.
1106
- The signature is (Scope , *operands) -> T.
1109
+ The signature is (module , *operands) -> T.
1107
1110
false_fun: The function evalauted when ``pred`` is `False`.
1108
- The signature is (Scope , *operands) -> T.
1109
- scope : A Scope or Pytree of scopes to pass
1111
+ The signature is (module , *operands) -> T.
1112
+ mdl : A Module target to pass.
1110
1113
*operands: The arguments passed to ``true_fun`` and ``false_fun``
1111
1114
variables: The variable collections passed to the conditional
1112
1115
branches (default: all)
0 commit comments