Skip to content

Exclude numba rewrites from JAX Scan rewrites#1825

Merged
ricardoV94 merged 1 commit intopymc-devs:mainfrom
ricardoV94:jax_scan_rewrite_numba
Jan 5, 2026
Merged

Exclude numba rewrites from JAX Scan rewrites#1825
ricardoV94 merged 1 commit intopymc-devs:mainfrom
ricardoV94:jax_scan_rewrite_numba

Conversation

@ricardoV94
Copy link
Member

Now that the default mode can be Numba, we could get numba-specific rewrites (which are never the default otherwise) in the JAX scan optimization.

This is a symptom of Scan having the mode concept, which we should move away from. For now the band-aid is simple enough.

Led to failures in pymc-devs/pymc-extras#615

@ricardoV94 ricardoV94 added bug Something isn't working backend compatibility scan labels Jan 5, 2026
# Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode)
rewriter = (
get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer
get_mode(op.mode)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifically when op.mode is None, it will return a NumbaLinker if pytensor.config.linker="numba" or pytensor.config.mode="NUMBA".

@ricardoV94 ricardoV94 requested review from jessegrabowski and removed request for jessegrabowski January 5, 2026 17:32
@ricardoV94 ricardoV94 merged commit 25e41c3 into pymc-devs:main Jan 5, 2026
64 checks passed
@ricardoV94 ricardoV94 deleted the jax_scan_rewrite_numba branch January 5, 2026 17:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants