Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 09601d1

Browse files
Use Numba's new Generator support for sampling
1 parent 079a165 commit 09601d1

File tree

6 files changed

+207
-327
lines changed

6 files changed

+207
-327
lines changed

aesara/link/numba/dispatch/basic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from aesara.scalar.math import Softplus
3232
from aesara.tensor.blas import BatchedDot
3333
from aesara.tensor.math import Dot
34+
from aesara.tensor.random.type import RandomGeneratorType
3435
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
3536
from aesara.tensor.slinalg import Cholesky, Solve
3637
from aesara.tensor.subtensor import (
@@ -92,6 +93,8 @@ def get_numba_type(
9293
dtype = np.dtype(aesara_type.dtype)
9394
numba_dtype = numba.from_dtype(dtype)
9495
return numba_dtype
96+
elif isinstance(aesara_type, RandomGeneratorType):
97+
return numba.types.npy_rng
9598
else:
9699
raise NotImplementedError(f"Numba type not implemented for {aesara_type}")
97100

0 commit comments

Comments
 (0)