diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py index 881eb8123b..32a7e3a988 100644 --- a/pytensor/link/jax/dispatch/scan.py +++ b/pytensor/link/jax/dispatch/scan.py @@ -29,7 +29,10 @@ def jax_funcify_Scan(op: Scan, **kwargs): # Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode) rewriter = ( - get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer + get_mode(op.mode) + .including("jax") + .excluding("numba", *JAX._optimizer.exclude) + .optimizer ) rewriter(op.fgraph) scan_inner_func = jax_funcify(op.fgraph, **kwargs)