From fa9dd175a5f58ed9ee866f09b8f78ca12cde9dc6 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 29 Apr 2022 20:03:25 -0700 Subject: [PATCH] [remove-units] partial_eval_jaxpr -> partial_eval_jaxpr_nounits --- flax/core/axes_scan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flax/core/axes_scan.py b/flax/core/axes_scan.py index bc09170d2..23437dd1d 100644 --- a/flax/core/axes_scan.py +++ b/flax/core/axes_scan.py @@ -132,7 +132,7 @@ def body_fn(c, xs, init_mode=False): f_flat, out_tree = jax.api_util.flatten_fun_nokwargs( lu.wrap_init(broadcast_body), in_tree) in_pvals = list(map(pe.PartialVal.unknown, in_avals)) - _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals) + _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals) out_flat = [] for pv, const in out_pvals: