Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for numpyro models in SBC #30

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions docs/examples/gallery/sbc.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,52 @@ sbc.plot_results(kind="hist")
```

:::::

:::::{tab-item} Numpyro
:sync: numpyro

We define a Numpyro Model, we use the centered eight schools model.

```{jupyter-execute}

import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import NUTS

y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

def eight_schools_cauchy_prior(J, sigma, y=None):
mu = numpyro.sample("mu", dist.Normal(0, 5))
tau = numpyro.sample("tau", dist.HalfCauchy(5))
with numpyro.plate("J", J):
theta = numpyro.sample("theta", dist.Normal(mu, tau))
numpyro.sample("y", dist.Normal(theta, sigma), obs=y)

# We use the NUTS sampler
nuts_kernel = NUTS(eight_schools_cauchy_prior)
```

Pass the model to the `SBC` class, set the number of simulations to 8, and run the simulations. For numpyro model,
we pass in the ``data_dir`` parameter.

```{jupyter-execute}

sbc = simuk.SBC(nuts_kernel,
sample_kwargs={"num_warmup": 1000, "num_samples": 1000, "progress_bar": False},
num_simulations=8,
seed=random.PRNGKey(10),
data_dir={"J": 8, "sigma": sigma, "y": y},
)

sbc.run_simulations()
```

To compare the prior and posterior distributions, we will plot the results. You can customize the visualization type
using the `kind` parameter. The example below displays a histogram.

```{jupyter-execute}

sbc.plot_results(kind="hist")
```
2 changes: 1 addition & 1 deletion environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ dependencies:
- pytest>=4.4.0
- pre-commit>=2.19
- ruff==0.9.1

- numpyro>=0.17.0
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ pymc @ git+https://github.com/pymc-devs/pymc@main
bambi>=0.15.0
arviz>=0.20.0
ruff==0.9.1
numpyro>=0.17.0
1 change: 1 addition & 0 deletions requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ sphinx_tabs
sphinx-design
numpydoc
jupyter-sphinx
numpyro>=0.17.0
86 changes: 68 additions & 18 deletions simuk/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
import arviz as az
import numpy as np
import pymc as pm

try:
from numpyro.handlers import seed, trace
from numpyro.infer import MCMC, Predictive
from numpyro.infer.mcmc import MCMCKernel
except ImportError:
pass
from tqdm import tqdm

from simuk.plots import plot_results
Expand Down Expand Up @@ -36,8 +43,8 @@ class SBC:

Parameters
----------
model : function
A PyMC or Bambi model. If a PyMC model the data needs to be defined as
model : pymc.Model, bambi.Model or numpyro.infer.mcmc.MCMCKernel
A PyMC, Bambi model or Numpyro MCMC kernel. If a PyMC model the data needs to be defined as
mutable data.
num_simulations : int
How many simulations to run
Expand All @@ -46,6 +53,9 @@ class SBC:
seed : int (optional)
Random seed. This persists even if running the simulations is
paused for whatever reason.
data_dir : dict
Keyword arguments passed to numpyro.sample, intended for use when providing
an MCMC Kernel model.

Example
-------
Expand All @@ -62,38 +72,61 @@ class SBC:

"""

def __init__(
self,
model,
num_simulations=1000,
sample_kwargs=None,
seed=None,
):
def __init__(self, model, num_simulations=1000, sample_kwargs=None, seed=None, data_dir=None):
if isinstance(model, pm.Model):
self.engine = "pymc"
self.model = model
elif isinstance(model, MCMCKernel):
self.engine = "numpyro"
self.numpyro_model = model
self.model = self.numpyro_model.model
self._get_posterior_samples = self._get_posterior_samples_numpyro
self._get_prior_predictive_samples = self._get_prior_predictive_samples_numpyro
self.data_dir = data_dir
else:
self.engine = "bambi"
model.build()
self.bambi_model = model
self.model = model.backend.model
self.formula = model.formula
self.new_data = copy(model.data)

self.observed_vars = [obs_rvs.name for obs_rvs in self.model.observed_RVs]
self.num_simulations = num_simulations

self.var_names = [v.name for v in self.model.free_RVs]

if sample_kwargs is None:
sample_kwargs = {}
sample_kwargs.setdefault("progressbar", False)
sample_kwargs.setdefault("compute_convergence_checks", False)
if self.engine == "numpyro":
sample_kwargs.setdefault("num_warmup", 1000)
sample_kwargs.setdefault("num_samples", 1000)
sample_kwargs.setdefault("progress_bar", False)
else:
sample_kwargs.setdefault("progressbar", False)
sample_kwargs.setdefault("compute_convergence_checks", False)
self.sample_kwargs = sample_kwargs
self._seed = seed
self._extract_variable_names()

self.simulations = {name: [] for name in self.var_names}
self._simulations_complete = 0
self._seed = seed

def _extract_variable_names(self):
"""Extract observed and free variables from the model."""
if self.engine == "numpyro":
with trace() as tr:
with seed(rng_seed=self._seed):
self.numpyro_model.model(**self.data_dir)

self.var_names = [
name
for name, site in tr.items()
if site["type"] == "sample" and not site.get("is_observed", False)
]
self.observed_vars = [
name
for name, site in tr.items()
if site["type"] == "sample" and site.get("is_observed", False)
]
else:
self.observed_vars = [obs.name for obs in self.model.observed_RVs]
self.var_names = [v.name for v in self.model.free_RVs]

def _get_seeds(self):
"""Set the random seed, and generate seeds for all the simulations."""
Expand All @@ -109,6 +142,15 @@ def _get_prior_predictive_samples(self):
prior = az.extract(idata, group="prior")
return prior, prior_pred

def _get_prior_predictive_samples_numpyro(self):
"""Generate samples to use for the simulations using numpyro."""
predictive = Predictive(self.model, num_samples=self.num_simulations)
samples = predictive(self._seed, **self.data_dir)
prior = {k: v for k, v in samples.items() if k not in self.observed_vars}
prior_pred = {k: v for k, v in samples.items() if k in self.observed_vars}
idata = az.from_dict(prior=prior, prior_predictive=prior_pred)
return az.extract(idata, group="prior"), az.extract(idata, group="prior_predictive")

def _get_posterior_samples(self, prior_predictive_draw):
"""Generate posterior samples conditioned to a prior predictive sample."""
new_model = pm.observe(self.model, prior_predictive_draw)
Expand All @@ -118,7 +160,15 @@ def _get_posterior_samples(self, prior_predictive_draw):
posterior = az.extract(check, group="posterior")
return posterior

@quiet_logging("pymc", "pytensor.gof.compilelock", "bambi")
def _get_posterior_samples_numpyro(self, prior_predictive_draw):
"""Generate posterior samples using numpyro conditioned to a prior predictive sample."""
mcmc = MCMC(self.numpyro_model, **self.sample_kwargs)
free_vars_data = {k: v for k, v in self.data_dir.items() if k not in self.observed_vars}
mcmc.run(self._seed, **free_vars_data, **prior_predictive_draw)
idata = az.from_dict(posterior=mcmc.get_samples())
return az.extract(idata, group="posterior")

@quiet_logging("pymc", "pytensor.gof.compilelock", "bambi", "numpyro")
def run_simulations(self):
"""Run all the simulations.

Expand Down
24 changes: 24 additions & 0 deletions simuk/tests/test_sbc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import bambi as bmb
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
import pymc as pm
import pytest
from jax import random
from numpyro.infer import NUTS

import simuk

Expand Down Expand Up @@ -32,3 +36,23 @@ def test_sbc(model, kind):
)
sbc.run_simulations()
sbc.plot_results(kind=kind)


@pytest.mark.parametrize("kind, num_simulations", [("ecdf", 5), ("hist", 8)])
def test_sbc_numpyro(kind, num_simulations):
def eight_schools_cauchy_prior(J, sigma, y=None):
mu = numpyro.sample("mu", dist.Normal(0, 5))
tau = numpyro.sample("tau", dist.HalfCauchy(5))
with numpyro.plate("J", J):
theta = numpyro.sample("theta", dist.Normal(mu, tau))
numpyro.sample("y", dist.Normal(theta, sigma), obs=y)

sbc = simuk.SBC(
NUTS(eight_schools_cauchy_prior),
sample_kwargs={"num_warmup": 1000, "num_samples": 1000, "progress_bar": False},
num_simulations=num_simulations,
seed=random.PRNGKey(10),
data_dir={"J": 8, "sigma": sigma, "y": y},
)
sbc.run_simulations()
sbc.plot_results(kind=kind)