Skip to content

Commit 0ecbc7f

Browse files
fehiepsiOlaRonning
authored andcommitted
Improve subsample warning keys (pyro-ppl#1303)
1 parent 2673ef1 commit 0ecbc7f

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

numpyro/infer/autoguide.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import numpy as np
1111

1212
import jax
13-
from jax import grad, hessian, lax, random, tree_map
13+
from jax import grad, hessian, lax, random
14+
from jax.tree_util import tree_map
1415

1516
from numpyro.util import _versiontuple, find_stack_level
1617

numpyro/primitives.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,12 @@ def module(name, nn, input_shape=None):
383383

384384

385385
def _subsample_fn(size, subsample_size, rng_key=None):
386-
assert rng_key is not None, "Missing random key to generate subsample indices."
386+
if rng_key is None:
387+
raise ValueError(
388+
"Missing random key to generate subsample indices."
389+
" Algorithms like HMC/NUTS do not support subsampling."
390+
" You might want to use SVI or HMCECS instead."
391+
)
387392
if jax.default_backend() == "cpu":
388393
# ref: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm
389394
rng_keys = random.split(rng_key, subsample_size)

0 commit comments

Comments
 (0)