diff --git a/flax/core/axes_scan.py b/flax/core/axes_scan.py index bc09170d2..23437dd1d 100644 --- a/flax/core/axes_scan.py +++ b/flax/core/axes_scan.py @@ -132,7 +132,7 @@ def body_fn(c, xs, init_mode=False): f_flat, out_tree = jax.api_util.flatten_fun_nokwargs( lu.wrap_init(broadcast_body), in_tree) in_pvals = list(map(pe.PartialVal.unknown, in_avals)) - _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals) + _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals) out_flat = [] for pv, const in out_pvals: