Skip to content

Commit

Permalink
Let MEIS return a Proposal
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanheyder committed Aug 26, 2024
1 parent 1b5df24 commit 553ff6c
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 59 deletions.
8 changes: 4 additions & 4 deletions isssm/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,17 @@ def mle_pgssm(
def f(theta, key):
model = model_fn(theta, aux)

proposal, info = laplace_approximation(y, model, n_iter_la)
propsal_la, _ = laplace_approximation(y, model, n_iter_la)

key, subkey = jrn.split(key)
z, Omega = modified_efficient_importance_sampling(
y, model, proposal.z, proposal.Omega, n_iter_la, N, subkey
proposal_meis, _ = modified_efficient_importance_sampling(
y, model, propsal_la.z, propsal_la.Omega, n_iter_la, N, subkey
)

key, subkey = jrn.split(key)
# improve numerical stability by dividing by number of observations
n_obs = y.size
return pgnll(y, model, z, Omega, N, subkey) / n_obs
return pgnll(y, model, proposal_meis.z, proposal_meis.Omega, N, subkey) / n_obs

key, subkey = jrn.split(key)
result = minimize_scipy(f, theta0, method="BFGS", options=options, args=(subkey,))
Expand Down
30 changes: 26 additions & 4 deletions isssm/modified_efficient_importance_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .util import MVN_degenerate as MVN, mm_sim

from .glssm import mm_sim
from .typing import GLSSM, PGSSM
from .typing import GLSSM, PGSSM, GLSSMProposal, ConvergenceInformation


@jit
Expand Down Expand Up @@ -67,7 +67,7 @@ def modified_efficient_importance_sampling(
N: int, # number of samples
key: PRNGKeyArray, # random key
eps: Float = 1e-5, # convergence threshold
):
) -> tuple[GLSSMProposal, ConvergenceInformation]:
z, Omega = z_init, Omega_init

np1, p, m = model.B.shape
Expand Down Expand Up @@ -124,10 +124,32 @@ def _iteration(val):

_keep_going = lambda *args: jnp.logical_not(_break(*args))

n_iters, z, Omega, _, _ = while_loop(
n_iters, z, Omega, z_old, Omega_old = while_loop(
_keep_going,
_iteration,
(0, z_init, Omega_init, jnp.empty_like(z_init), jnp.empty_like(Omega_init)),
)

return z, Omega
proposal = GLSSMProposal(
u=model.u,
A=model.A,
D=model.D,
Sigma0=model.Sigma0,
Sigma=model.Sigma,
v=model.v,
B=model.B,
Omega=Omega,
z=z,
)

delta_z = jnp.max(jnp.abs(z - z_old))
delta_Omega = jnp.max(jnp.abs(Omega - Omega_old))
information = ConvergenceInformation(
converged=jnp.logical_and(
converged(z, z_old, eps), converged(Omega, Omega_old, eps)
),
n_iter=n_iters,
delta=jnp.max(jnp.array([delta_z, delta_Omega])),
)

return proposal, information
62 changes: 40 additions & 22 deletions nbs/50_modified_efficient_importance_sampling.ipynb

Large diffs are not rendered by default.

58 changes: 29 additions & 29 deletions nbs/60_maximum_likelihood_estimation.ipynb

Large diffs are not rendered by default.

0 comments on commit 553ff6c

Please sign in to comment.