Skip to content

Commit

Permalink
Merge pull request #27 from arviz-devs/obs
Browse files Browse the repository at this point in the history
Use observe, allow multiple observed variables
  • Loading branch information
aloctavodia authored Feb 14, 2025
2 parents 62568d6 + 0f74a86 commit a189866
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
6 changes: 3 additions & 3 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ click==8.0.4
pytest-cov>=2.6.1
pytest>=4.4.0
pre-commit>=2.19
ipytest==0.13.0
pymc>=5.20.1
ipytest==0.13.0
pymc @ git+https://github.com/pymc-devs/pymc@main
bambi>=0.15.0
arviz>=0.20.0
ruff==0.9.1
ruff==0.9.1
12 changes: 8 additions & 4 deletions simuk/sbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
self.formula = model.formula
self.new_data = copy(model.data)

self.observed_var = self.model.observed_RVs[0].name
self.observed_vars = [obs_rvs.name for obs_rvs in self.model.observed_RVs]
self.num_simulations = num_simulations

self.var_names = [v.name for v in self.model.free_RVs]
Expand Down Expand Up @@ -107,8 +107,8 @@ def _get_prior_predictive_samples(self):

def _get_posterior_samples(self, prior_predictive_draw):
"""Generate posterior samples conditioned to a prior predictive sample."""
model_do = pm.do(self.model, {self.observed_var: prior_predictive_draw})
with model_do:
new_model = pm.observe(self.model, prior_predictive_draw)
with new_model:
check = pm.sample(**self.sample_kwargs)

posterior = az.extract(check, group="posterior")
Expand All @@ -133,7 +133,11 @@ def run_simulations(self):
try:
while self._simulations_complete < self.num_simulations:
idx = self._simulations_complete
prior_predictive_draw = prior_pred[self.observed_var].sel(chain=0, draw=idx).values
prior_predictive_draw = {
var_name: prior_pred[var_name].sel(chain=0, draw=idx).values
for var_name in self.observed_vars
}

np.random.seed(seeds[idx])

posterior = self._get_posterior_samples(prior_predictive_draw)
Expand Down

0 comments on commit a189866

Please sign in to comment.