Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 77c077a

Browse files
Add a Numba implementation for Generator.gumbel
1 parent 6178042 commit 77c077a

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

aesara/link/numba/dispatch/random.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from copy import copy
2+
from math import log
23
from textwrap import dedent, indent
34
from typing import Callable, Optional
45

@@ -8,6 +9,7 @@
89
from numba.core import types
910
from numba.core.extending import overload, overload_method, register_jitable
1011
from numba.np.random.distributions import random_beta, random_standard_gamma
12+
from numba.np.random.generator_core import next_double
1113
from numba.np.random.generator_methods import check_size, check_types, is_nonelike
1214

1315
import aesara.tensor.random.basic as aer
@@ -364,3 +366,39 @@ def impl(inst, alphas, size=None):
364366
return random_dirichlet(inst.bit_generator, alphas, size)
365367

366368
return impl
369+
370+
371+
@register_jitable
372+
def random_gumbel(bitgen, loc, scale):
373+
"""
374+
This implementation is adapted from ``numpy/random/src/distributions/distributions.c``.
375+
"""
376+
while True:
377+
u = 1.0 - next_double(bitgen)
378+
if u < 1.0:
379+
return loc - scale * log(-log(u))
380+
381+
382+
@overload_method(types.NumPyRandomGeneratorType, "gumbel")
383+
def NumPyRandomGeneratorType_gumbel(inst, loc=0.0, scale=1.0, size=None):
384+
check_types(loc, [types.Float, types.Integer, int, float], "loc")
385+
check_types(scale, [types.Float, types.Integer, int, float], "scale")
386+
387+
if isinstance(size, types.Omitted):
388+
size = size.value
389+
390+
if is_nonelike(size):
391+
392+
def impl(inst, loc=0.0, scale=1.0, size=None):
393+
return random_gumbel(inst.bit_generator, loc, scale)
394+
395+
else:
396+
check_size(size)
397+
398+
def impl(inst, loc=0.0, scale=1.0, size=None):
399+
out = np.empty(size)
400+
for i in np.ndindex(size):
401+
out[i] = random_gumbel(inst.bit_generator, loc, scale)
402+
return out
403+
404+
return impl

tests/link/numba/test_random.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
362362
"chi2",
363363
lambda *args: args,
364364
),
365-
pytest.param(
365+
(
366366
aer.gumbel,
367367
[
368368
set_test_value(
@@ -377,9 +377,6 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
377377
(2,),
378378
"gumbel_r",
379379
lambda *args: args,
380-
marks=pytest.mark.skip(
381-
reason="Not yet supported in Numba via `Generator`s"
382-
),
383380
),
384381
(
385382
aer.negative_binomial,

0 commit comments

Comments
 (0)