Skip to content

Commit 553ff6c

Browse files
committed
Let MEIS return a Proposal
1 parent 1b5df24 commit 553ff6c

File tree

4 files changed

+99
-59
lines changed

4 files changed

+99
-59
lines changed

isssm/estimation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,17 +194,17 @@ def mle_pgssm(
194194
def f(theta, key):
195195
model = model_fn(theta, aux)
196196

197-
proposal, info = laplace_approximation(y, model, n_iter_la)
197+
propsal_la, _ = laplace_approximation(y, model, n_iter_la)
198198

199199
key, subkey = jrn.split(key)
200-
z, Omega = modified_efficient_importance_sampling(
201-
y, model, proposal.z, proposal.Omega, n_iter_la, N, subkey
200+
proposal_meis, _ = modified_efficient_importance_sampling(
201+
y, model, propsal_la.z, propsal_la.Omega, n_iter_la, N, subkey
202202
)
203203

204204
key, subkey = jrn.split(key)
205205
# improve numerical stability by dividing by number of observations
206206
n_obs = y.size
207-
return pgnll(y, model, z, Omega, N, subkey) / n_obs
207+
return pgnll(y, model, proposal_meis.z, proposal_meis.Omega, N, subkey) / n_obs
208208

209209
key, subkey = jrn.split(key)
210210
result = minimize_scipy(f, theta0, method="BFGS", options=options, args=(subkey,))

isssm/modified_efficient_importance_sampling.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .util import MVN_degenerate as MVN, mm_sim
1818

1919
from .glssm import mm_sim
20-
from .typing import GLSSM, PGSSM
20+
from .typing import GLSSM, PGSSM, GLSSMProposal, ConvergenceInformation
2121

2222

2323
@jit
@@ -67,7 +67,7 @@ def modified_efficient_importance_sampling(
6767
N: int, # number of samples
6868
key: PRNGKeyArray, # random key
6969
eps: Float = 1e-5, # convergence threshold
70-
):
70+
) -> tuple[GLSSMProposal, ConvergenceInformation]:
7171
z, Omega = z_init, Omega_init
7272

7373
np1, p, m = model.B.shape
@@ -124,10 +124,32 @@ def _iteration(val):
124124

125125
_keep_going = lambda *args: jnp.logical_not(_break(*args))
126126

127-
n_iters, z, Omega, _, _ = while_loop(
127+
n_iters, z, Omega, z_old, Omega_old = while_loop(
128128
_keep_going,
129129
_iteration,
130130
(0, z_init, Omega_init, jnp.empty_like(z_init), jnp.empty_like(Omega_init)),
131131
)
132132

133-
return z, Omega
133+
proposal = GLSSMProposal(
134+
u=model.u,
135+
A=model.A,
136+
D=model.D,
137+
Sigma0=model.Sigma0,
138+
Sigma=model.Sigma,
139+
v=model.v,
140+
B=model.B,
141+
Omega=Omega,
142+
z=z,
143+
)
144+
145+
delta_z = jnp.max(jnp.abs(z - z_old))
146+
delta_Omega = jnp.max(jnp.abs(Omega - Omega_old))
147+
information = ConvergenceInformation(
148+
converged=jnp.logical_and(
149+
converged(z, z_old, eps), converged(Omega, Omega_old, eps)
150+
),
151+
n_iter=n_iters,
152+
delta=jnp.max(jnp.array([delta_z, delta_Omega])),
153+
)
154+
155+
return proposal, information

nbs/50_modified_efficient_importance_sampling.ipynb

Lines changed: 40 additions & 22 deletions
Large diffs are not rendered by default.

nbs/60_maximum_likelihood_estimation.ipynb

Lines changed: 29 additions & 29 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)