Commit 60b0167: improve simulation smoother speed

stefanheyder committed Aug 26, 2024
1 parent 553ff6c

Showing 8 changed files with 105 additions and 83 deletions.
15 changes: 8 additions & 7 deletions isssm/estimation.py
@@ -166,15 +166,14 @@ def f(theta):
         signal = posterior_mode(proposal)
         _, _, x_pred, Xi_pred = kalman(z, glssm_la)
 
-        negloglik = (
-            gnll(z, x_pred, Xi_pred, B, Omega)
-            - log_weights(signal, y, model.dist, model.xi, z, Omega).sum()
+        negloglik = gnll(z, x_pred, Xi_pred, B, Omega) - log_weights(
+            signal, y, model.dist, model.xi, z, Omega
         )
         # improve numerical stability by dividing by number of observations
         n_obs = y.size
         return negloglik / n_obs
 
-    result = minimize_scipy(f, theta0, method="BFGS", options=options)
+    result = minimize_scipy(f, theta0, method="BFGS", jac="3-point", options=options)
     return result
 
 # %% ../nbs/60_maximum_likelihood_estimation.ipynb 22
@@ -194,11 +193,11 @@ def mle_pgssm(
     def f(theta, key):
         model = model_fn(theta, aux)
 
-        propsal_la, _ = laplace_approximation(y, model, n_iter_la)
+        proposal_la, _ = laplace_approximation(y, model, n_iter_la)
 
         key, subkey = jrn.split(key)
         proposal_meis, _ = modified_efficient_importance_sampling(
-            y, model, propsal_la.z, propsal_la.Omega, n_iter_la, N, subkey
+            y, model, proposal_la.z, proposal_la.Omega, n_iter_la, N, subkey
        )
 
         key, subkey = jrn.split(key)
@@ -207,5 +206,7 @@ def f(theta, key):
         return pgnll(y, model, proposal_meis.z, proposal_meis.Omega, N, subkey) / n_obs
 
     key, subkey = jrn.split(key)
-    result = minimize_scipy(f, theta0, method="BFGS", options=options, args=(subkey,))
+    result = minimize_scipy(
+        f, theta0, method="BFGS", jac="3-point", options=options, args=(subkey,)
+    )
     return result
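
Two changes here: the `.sum()` is dropped because `log_weights` now returns the summed log weight itself (see isssm/importance_sampling.py below), and both optimizers now request central finite differences for the gradient. Assuming `minimize_scipy` forwards its keyword arguments to `scipy.optimize.minimize`, a minimal sketch of what `jac="3-point"` changes, with a toy objective that is not from the repository:

```python
# "3-point" asks SciPy to estimate the gradient by central differences, which
# costs roughly twice as many objective evaluations per coordinate as the
# default "2-point" forward differences but is accurate to O(h^2) instead of O(h).
import numpy as np
from scipy.optimize import minimize

def rosenbrock(theta):
    return np.sum(100.0 * (theta[1:] - theta[:-1] ** 2) ** 2 + (1.0 - theta[:-1]) ** 2)

result = minimize(rosenbrock, np.zeros(4), method="BFGS", jac="3-point")
print(result.x)  # approximately [1, 1, 1, 1]
```

For a noisy objective such as the particle estimate in `mle_pgssm`, the more accurate difference scheme can help keep BFGS's curvature updates from degenerating.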
3 changes: 1 addition & 2 deletions isssm/importance_sampling.py
@@ -6,7 +6,6 @@
 # %% ../nbs/40_importance_sampling.ipynb 4
 from tensorflow_probability.substrates.jax.distributions import (
     MultivariateNormalFullCovariance as MVN,
-    MultivariateNormalDiag as MVN_diag,
 )
 import jax.numpy as jnp
 from jaxtyping import Float, Array
@@ -40,7 +39,7 @@ def log_weights(
     xi: Float[Array, "n+1 p"],  # observation parameters
     z: Float[Array, "n+1 p"],  # synthetic observations
     Omega: Float[Array, "n+1 p p"],  # synthetic observation covariances:
-) -> Float[Array, "n+1"]:  # log weights
+) -> Float:  # log weights
     """Log weights for all time points"""
     p_ys = dist(s, xi).log_prob(y).sum()
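
The annotation change reflects that `log_weights` now sums over time and returns a single scalar per trajectory, which is the quantity the estimation code consumes directly. A toy version of that convention follows; the Poisson observation model, shapes, and values are invented for illustration and are not the library's `log_weights`:

```python
import jax.numpy as jnp
from tensorflow_probability.substrates.jax.distributions import (
    Poisson,
    MultivariateNormalFullCovariance as MVN,
)

def toy_log_weight(s, y, z, Omega):
    log_p = Poisson(log_rate=s).log_prob(y).sum()  # true observation density p(y | s)
    log_g = MVN(s, Omega).log_prob(z).sum()        # Gaussian synthetic density g(z | s)
    return log_p - log_g                           # one scalar for the whole trajectory

n, p = 5, 2
s, y, z = jnp.zeros((n, p)), jnp.ones((n, p)), jnp.zeros((n, p))
Omega = jnp.stack([jnp.eye(p)] * n)
print(toy_log_weight(s, y, z, Omega))  # scalar, not a length-(n+1) vector
```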
39 changes: 25 additions & 14 deletions isssm/kalman.py
@@ -238,6 +238,9 @@ def FFBS(
     return _simulate_smoothed_FW1994(x_filt, Xi_filt, Xi_pred, model.A, N, subkey)
 
 # %% ../nbs/10_kalman_filter_smoother.ipynb 30
+from .util import mm_time
+
+
 def disturbance_smoother(
     filtered: FilterResult,  # filter result
     y: Observations,  # observations
@@ -250,27 +253,35 @@ def disturbance_smoother(
 
     def step(carry, inputs):
         (r,) = carry
-        y_tilde, A, B, Omega, Xi_pred = inputs
-
-        Psi_pred = B @ Xi_pred @ B.T + Omega
-        Psi_pred_pinv = jnp.linalg.pinv(Psi_pred)
-        K = Xi_pred @ B.T @ Psi_pred_pinv
-
-        eta_smooth = Omega @ (Psi_pred_pinv @ y_tilde - K.T @ A.T @ r)
-        L = A @ (jnp.eye(m) - K @ B)
-
-        r_prev = B.T @ Psi_pred_pinv @ y_tilde + L.T @ r
-
-        return (r_prev,), (eta_smooth, Psi_pred_pinv, K, L)
-
-    y_tilde = y - vmap(jnp.matmul)(B, x_pred)
-
-    A_ext = jnp.concatenate((A, jnp.eye(m)[jnp.newaxis]), axis=0)
-    _, (eta_smooth, Psi_pred_pinv, K, L) = scan(
-        step, (jnp.zeros(m),), (y_tilde, A_ext, B, Omega, Xi_pred), reverse=True
-    )
-
-    return eta_smooth, (Psi_pred_pinv, K, L)
+        O_Pinv_y, O_KT_AT, BT_Pinv_y, L = inputs
+
+        eta_smooth = O_Pinv_y - O_KT_AT @ r
+        r_prev = BT_Pinv_y + L.T @ r
+
+        return (r_prev,), eta_smooth
+
+    A_ext = jnp.concatenate((A, jnp.eye(m)[jnp.newaxis]), axis=0)
+    BT = B.transpose((0, 2, 1))
+
+    # offline computation is faster
+    y_tilde = y - mm_time(B, x_pred)
+    Psi_pred = B @ Xi_pred @ BT + Omega
+    Psi_pred_pinv = jnp.linalg.pinv(Psi_pred)
+    O_Pinv_y = mm_time(Omega @ Psi_pred_pinv, y_tilde)
+    K = Xi_pred @ BT @ Psi_pred_pinv
+
+    KT = K.transpose((0, 2, 1))
+    AT = A_ext.transpose((0, 2, 1))
+    O_KT_AT = Omega @ KT @ AT
+    BT_Pinv_y = mm_time(BT @ Psi_pred_pinv, y_tilde)
+
+    L = A_ext @ (jnp.eye(m)[None] - K @ B)
+
+    _, eta_smooth = scan(
+        step, (jnp.zeros(m),), (O_Pinv_y, O_KT_AT, BT_Pinv_y, L), reverse=True
+    )
+
+    return eta_smooth
@@ -279,7 +290,7 @@ def smoothed_signals(
     model: GLSSM,  # model
 ) -> Float[Array, "n+1 m"]:  # smoothed signals
     """compute smoothed signals from filter result"""
-    eta_smooth, _ = disturbance_smoother(filtered, y, model)
+    eta_smooth = disturbance_smoother(filtered, y, model)
     return y - eta_smooth
 
 # %% ../nbs/10_kalman_filter_smoother.ipynb 35
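
For reference, old and new code implement the same backward disturbance-smoother recursion; reading it off the code, with Ψ⁺ the Moore-Penrose pseudoinverse (`jnp.linalg.pinv`) and x and Ξ the Kalman predictions:

```latex
\tilde y_t = y_t - B_t x_{t\mid t-1}, \qquad
\Psi_t = B_t \Xi_{t\mid t-1} B_t^\top + \Omega_t, \qquad
K_t = \Xi_{t\mid t-1} B_t^\top \Psi_t^{+}, \\
L_t = A_t \left( I - K_t B_t \right), \qquad
\hat{\eta}_t = \Omega_t \bigl( \Psi_t^{+} \tilde y_t - K_t^\top A_t^\top r_t \bigr), \qquad
r_{t-1} = B_t^\top \Psi_t^{+} \tilde y_t + L_t^\top r_t, \qquad r_n = 0
```

The scan runs with `reverse=True` from r_n = 0; for the final step the transition matrix is taken as the identity (the appended `jnp.eye(m)` in `A_ext`). Only the smoothed disturbances are kept, which is why the auxiliary `(Psi_pred_pinv, K, L)` outputs disappear from the return value and `smoothed_signals` drops its tuple unpacking.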
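
The speed improvement itself comes from hoisting every product that does not involve the scan carry r out of the sequential backward pass into time-batched matrix products. A minimal sketch of that pattern; `mm_time` is assumed to behave like the stand-in below (its real definition lives in `isssm.util`), and all shapes and values are invented:

```python
import jax.numpy as jnp
from jax import vmap
from jax.lax import scan

def mm_time(A, x):
    # assumed: stack of matrices (n, m, k) times stack of vectors (n, k) -> (n, m)
    return vmap(jnp.matmul)(A, x)

n, m = 50, 3
L = jnp.stack([0.9 * jnp.eye(m)] * n)  # stand-in for L_t = A_t (I - K_t B_t)
M = jnp.stack([0.5 * jnp.eye(m)] * n)  # stand-in for B_t^T Psi_t^+
y_tilde = jnp.ones((n, m))

# offline: one batched matmul for all time steps, instead of n small matmuls
# executed one after another inside the scan
BT_Pinv_y = mm_time(M, y_tilde)

def step(r, inputs):
    term, L_t = inputs
    r_prev = term + L_t.T @ r  # only carry-dependent algebra stays sequential
    return r_prev, r_prev

_, rs = scan(step, jnp.zeros(m), (BT_Pinv_y, L), reverse=True)
print(rs.shape)  # (50, 3)
```

The batched products compile to a few large XLA kernels, while the scan retains only the unavoidable sequential dependency on r.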
2 changes: 1 addition & 1 deletion isssm/laplace_approximation.py
@@ -37,7 +37,7 @@
 
 def _initial_guess(xi_ti, y_ti, dist, link=default_link):
     result = minimize(
-        lambda s_ti: dist(s_ti, xi_ti).log_prob(y_ti).sum(),
+        lambda s_ti: -dist(s_ti, xi_ti).log_prob(y_ti).sum(),
         jnp.atleast_1d(default_link(y_ti)),
         method="BFGS",
     )
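
This is a sign fix: BFGS minimizes, so the initial guess must minimize the negative log-density, while the old objective steered the optimizer away from the mode. A toy check, under the assumption that `minimize` here is `jax.scipy.optimize.minimize`; the Normal observation density and values are invented:

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from tensorflow_probability.substrates.jax.distributions import Normal

y_ti = jnp.array([3.0])
result = minimize(
    lambda s_ti: -Normal(s_ti, 1.0).log_prob(y_ti).sum(),  # negated, as in the fix
    jnp.zeros(1),  # starting value, standing in for default_link(y_ti)
    method="BFGS",
)
print(result.x)  # close to [3.0], the mode of the likelihood in s
```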
61 changes: 35 additions & 26 deletions nbs/10_kalman_filter_smoother.ipynb

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions nbs/30_laplace_approximation.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions nbs/40_importance_sampling.ipynb
@@ -79,7 +79,6 @@
 "# | export\n",
 "from tensorflow_probability.substrates.jax.distributions import (\n",
 "    MultivariateNormalFullCovariance as MVN,\n",
-"    MultivariateNormalDiag as MVN_diag,\n",
 ")\n",
 "import jax.numpy as jnp\n",
 "from jaxtyping import Float, Array\n",
@@ -113,7 +112,7 @@
 "    xi: Float[Array, \"n+1 p\"],  # observation parameters\n",
 "    z: Float[Array, \"n+1 p\"],  # synthetic observations\n",
 "    Omega: Float[Array, \"n+1 p p\"],  # synthetic observation covariances:\n",
-") -> Float[Array, \"n+1\"]:  # log weights\n",
+") -> Float:  # log weights\n",
 "    \"\"\"Log weights for all time points\"\"\"\n",
 "    p_ys = dist(s, xi).log_prob(y).sum()\n",
 "\n",
53 changes: 27 additions & 26 deletions nbs/60_maximum_likelihood_estimation.ipynb

Large diffs are not rendered by default.
