Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 6178042

Browse files
Add a Numba implementation for Generator.dirichlet
1 parent 09601d1 commit 6178042

File tree

2 files changed

+78
-2
lines changed

2 files changed

+78
-2
lines changed

aesara/link/numba/dispatch/random.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import numba.np.unsafe.ndarray as numba_ndarray
77
import numpy as np
88
from numba.core import types
9-
from numba.core.extending import overload
9+
from numba.core.extending import overload, overload_method, register_jitable
10+
from numba.np.random.distributions import random_beta, random_standard_gamma
11+
from numba.np.random.generator_methods import check_size, check_types, is_nonelike
1012

1113
import aesara.tensor.random.basic as aer
1214
from aesara.graph.basic import Apply
@@ -287,3 +289,78 @@ def dirichlet_rv(rng, size, dtype, alphas):
287289
return (rng, rng.dirichlet(alphas, size))
288290

289291
return dirichlet_rv
292+
293+
294+
@register_jitable
295+
def random_dirichlet(bitgen, alpha, size):
296+
"""
297+
This implementation is straight from ``numpy/random/_generator.pyx``.
298+
"""
299+
300+
k = len(alpha)
301+
alpha_arr = np.asarray(alpha, dtype=np.float64)
302+
303+
if np.any(np.less_equal(alpha_arr, 0)):
304+
raise ValueError("alpha <= 0")
305+
306+
shape = size + (k,)
307+
308+
diric = np.zeros(shape, np.float64)
309+
310+
i = 0
311+
totsize = diric.size
312+
313+
if (k > 0) and (alpha_arr.max() < 0.1):
314+
alpha_csum_arr = np.empty_like(alpha_arr)
315+
csum = 0.0
316+
for j in range(k - 1, -1, -1):
317+
csum += alpha_arr[j]
318+
alpha_csum_arr[j] = csum
319+
320+
while i < totsize:
321+
acc = 1.0
322+
for j in range(k - 1):
323+
v = random_beta(bitgen, alpha_arr[j], alpha_csum_arr[j + 1])
324+
diric[i + j] = acc * v
325+
acc *= 1.0 - v
326+
diric[i + k - 1] = acc
327+
i = i + k
328+
329+
else:
330+
while i < totsize:
331+
acc = 0.0
332+
for j in range(k):
333+
diric[i + j] = random_standard_gamma(bitgen, alpha_arr[j])
334+
acc = acc + diric[i + j]
335+
invacc = 1.0 / acc
336+
for j in range(k):
337+
diric[i + j] = diric[i + j] * invacc
338+
i = i + k
339+
340+
return diric
341+
342+
343+
@overload_method(types.NumPyRandomGeneratorType, "dirichlet")
344+
def NumPyRandomGeneratorType_dirichlet(inst, alphas, size=None):
345+
check_types(alphas, [types.Array, types.List], "alphas")
346+
347+
if isinstance(size, types.Omitted):
348+
size = size.value
349+
350+
if is_nonelike(size):
351+
352+
def impl(inst, alphas, size=None):
353+
return random_dirichlet(inst.bit_generator, alphas, ())
354+
355+
elif isinstance(size, (int, types.Integer)):
356+
357+
def impl(inst, alphas, size=None):
358+
return random_dirichlet(inst.bit_generator, alphas, (size,))
359+
360+
else:
361+
check_size(size)
362+
363+
def impl(inst, alphas, size=None):
364+
return random_dirichlet(inst.bit_generator, alphas, size)
365+
366+
return impl

tests/link/numba/test_random.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,6 @@ def test_CategoricalRV(dist_args, size, cm):
520520
)
521521

522522

523-
@pytest.mark.skip(reason="Not yet supported in Numba via `Generator`s")
524523
@pytest.mark.parametrize(
525524
"a, size, cm",
526525
[

0 commit comments

Comments
 (0)