Skip to content

Commit

Permalink
Merge pull request #25 from arviz-devs/do
Browse files Browse the repository at this point in the history
Use do operator
  • Loading branch information
aloctavodia authored Feb 14, 2025
2 parents 117ae41 + ae36578 commit 956c73e
Showing 1 changed file with 5 additions and 25 deletions.
30 changes: 5 additions & 25 deletions simuk/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,6 @@
import arviz as az
import numpy as np
import pymc as pm

try:
import bambi as bmb
except ImportError:
pass

from tqdm import tqdm

from simuk.plots import plot_results
Expand Down Expand Up @@ -64,7 +58,6 @@ def __init__(
-------
with pm.Model() as model:
obs = pm.MutableData('obs', data)
x = pm.Normal('x')
y = pm.Normal('y', mu=2 * x, observed=obs)
Expand All @@ -75,20 +68,15 @@ def __init__(
if isinstance(model, pm.Model):
self.engine = "pymc"
self.model = model
self.observed_vars = {
model.rvs_to_values[rv].name: rv.name for rv in model.observed_RVs
}
else:
self.engine = "bambi"
model.build()
self.bambi_model = model
self.model = model.backend.model
self.formula = model.formula
self.new_data = copy(model.data)
self.observed_vars = {
model.response_component.term.name: model.response_component.term.name
}

self.observed_var = self.model.observed_RVs[0].name
self.num_simulations = num_simulations

self.var_names = [v.name for v in self.model.free_RVs]
Expand Down Expand Up @@ -119,14 +107,9 @@ def _get_prior_predictive_samples(self):

def _get_posterior_samples(self, prior_predictive_draw):
"""Generate posterior samples conditioned to a prior predictive sample."""
if self.engine == "pymc":
with self.model:
pm.set_data(prior_predictive_draw)
check = pm.sample(**self.sample_kwargs)
else:
for k, v in prior_predictive_draw.items():
self.new_data[k] = v
check = bmb.Model(self.formula, self.new_data).fit(**self.sample_kwargs)
model_do = pm.do(self.model, {self.observed_var: prior_predictive_draw})
with model_do:
check = pm.sample(**self.sample_kwargs)

posterior = az.extract(check, group="posterior")
return posterior
Expand All @@ -150,10 +133,7 @@ def run_simulations(self):
try:
while self._simulations_complete < self.num_simulations:
idx = self._simulations_complete
prior_predictive_draw = {
k: prior_pred[v].sel(chain=0, draw=idx).values
for k, v in self.observed_vars.items()
}
prior_predictive_draw = prior_pred[self.observed_var].sel(chain=0, draw=idx).values
np.random.seed(seeds[idx])

posterior = self._get_posterior_samples(prior_predictive_draw)
Expand Down

0 comments on commit 956c73e

Please sign in to comment.