Add missing observations to Kalman filter
stefanheyder committed Aug 16, 2024
1 parent 781c7a5 commit a549771
Showing 6 changed files with 244 additions and 127 deletions.
1 change: 1 addition & 0 deletions isssm/_modidx.py
@@ -52,6 +52,7 @@
'isssm.kalman._simulate_smoothed_FW1994': ( 'kalman_filter_smoother.html#_simulate_smoothed_fw1994',
'isssm/kalman.py'),
'isssm.kalman._smooth_step': ('kalman_filter_smoother.html#_smooth_step', 'isssm/kalman.py'),
'isssm.kalman.account_for_nans': ('kalman_filter_smoother.html#account_for_nans', 'isssm/kalman.py'),
'isssm.kalman.disturbance_smoother': ('kalman_filter_smoother.html#disturbance_smoother', 'isssm/kalman.py'),
'isssm.kalman.kalman': ('kalman_filter_smoother.html#kalman', 'isssm/kalman.py'),
'isssm.kalman.simulation_smoother': ('kalman_filter_smoother.html#simulation_smoother', 'isssm/kalman.py'),
23 changes: 13 additions & 10 deletions isssm/importance_sampling.py
@@ -1,8 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/40_importance_sampling.ipynb.

# %% auto 0
__all__ = ['v_log_weights', 'log_weights_t', 'log_weights', 'pgssm_importance_sampling', 'normalize_weights', 'ess', 'ess_lw',
'ess_pct']
__all__ = ['log_weights_t', 'log_weights', 'pgssm_importance_sampling', 'normalize_weights', 'ess', 'ess_lw', 'ess_pct']

# %% ../nbs/40_importance_sampling.ipynb 4
from tensorflow_probability.substrates.jax.distributions import (
@@ -44,8 +43,9 @@ def log_weights(
) -> Float[Array, "n+1"]: # log weights
"""Log weights for all time points"""
p_ys = dist(s, xi).log_prob(y).sum()

# avoid triangular solve problems
omega = jnp.sqrt(vmap(jnp.diag)(Omega))
# avoid triangulra solve problems
g_zs = MVN_diag(s, omega).log_prob(z).sum()

return p_ys - g_zs
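
The relocated comment above refers to the triangular solve inside a full-covariance Gaussian log-density: evaluating it requires factorizing Omega, which fails or loses accuracy when Omega is singular or badly conditioned. Taking omega = sqrt(diag(Omega)) and using a diagonal normal avoids the factorization entirely, and is exact whenever Omega is diagonal. A minimal sketch with toy numbers (not the library's code) showing the two densities agree for diagonal Omega:

import jax.numpy as jnp
from jax import vmap
from tensorflow_probability.substrates.jax.distributions import (
    MultivariateNormalDiag,
    MultivariateNormalFullCovariance,
)

# two time points, p = 2, diagonal covariances
Omega = jnp.stack([jnp.diag(jnp.array([1.0, 2.0])), jnp.diag(jnp.array([0.5, 3.0]))])
s = jnp.zeros((2, 2))  # locations
z = jnp.ones((2, 2))   # evaluation points

omega = jnp.sqrt(vmap(jnp.diag)(Omega))  # elementwise standard deviations
lp_diag = MultivariateNormalDiag(s, omega).log_prob(z).sum()
lp_full = MultivariateNormalFullCovariance(s, Omega).log_prob(z).sum()
assert jnp.allclose(lp_diag, lp_full)  # same value, no Cholesky / triangular solve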
@@ -54,13 +54,12 @@ def log_weights(
from jaxtyping import Float, Array, PRNGKeyArray
from .kalman import FFBS, simulation_smoother
import jax.random as jrn
from .typing import GLSSM

v_log_weights = vmap(log_weights, (0, None, None, None, None, None))
from functools import partial
from .typing import GLSSM, PGSSM

def pgssm_importance_sampling(
y: Float[Array, "n+1 p"], # observations
model: PGSSM,
model: PGSSM, # model
z: Float[Array, "n+1 p"], # synthetic observations
Omega: Float[Array, "n+1 p p"], # covariance of synthetic observations
N: int, # number of samples
@@ -72,7 +71,9 @@
key, subkey = jrn.split(key)
s = simulation_smoother(glssm, z, N, subkey)

lw = v_log_weights(s, y, dist, xi, z, Omega)
model_log_weights = partial(log_weights, y=y, dist=dist, xi=xi, z=z, Omega=Omega)

lw = vmap(model_log_weights)(s)

return s, lw
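
The hunk above replaces the module-level v_log_weights = vmap(log_weights, (0, None, None, None, None, None)) with a partial application that is vmapped at the call site, so only the sample axis of s is mapped. A toy sketch of the same pattern, with a hypothetical f standing in for log_weights:

from functools import partial
from jax import vmap
import jax.numpy as jnp

def f(s, y, scale):  # hypothetical stand-in for log_weights
    return ((s - y) / scale).sum()

s_samples = jnp.ones((4, 3))  # N = 4 draws of a length-3 signal
f_fixed = partial(f, y=jnp.zeros(3), scale=2.0)  # pin everything except s
lw = vmap(f_fixed)(s_samples)  # maps over axis 0 of s only; shape (4,)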

@@ -97,7 +98,7 @@ def normalize_weights(


def ess(
normalized_weights: Float[Array, "N"] # the normailzed weights
normalized_weights: Float[Array, "N"] # normalized weights
) -> Float: # the effective sample size
"""Compute the effective sample size of a set of normalized weights"""
return 1 / (normalized_weights**2).sum()
@@ -109,6 +110,8 @@ def ess_lw(
"""Compute the effective sample size of a set of log weights"""
return ess(normalize_weights(log_weights))

def ess_pct(log_weights):
def ess_pct(
log_weights: Float[Array, "N"] # log weights
) -> Float: # the effective sample size in percent, also called efficiency factor
N, = log_weights.shape
return ess_lw(log_weights) / N * 100
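
Taken together, the three diagnostics are used as follows; a short usage sketch with made-up log weights (import path follows this module):

import jax.numpy as jnp
from isssm.importance_sampling import normalize_weights, ess, ess_pct

log_w = jnp.array([-1.2, -0.3, -0.9, -0.5])  # toy log weights, N = 4
w = normalize_weights(log_w)   # nonnegative, sums to one
print(ess(w))                  # effective sample size, at most N
print(ess_pct(log_w))          # the same, as a percentage of N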
51 changes: 32 additions & 19 deletions isssm/kalman.py
@@ -1,8 +1,8 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/10_kalman_filter_smoother.ipynb.

# %% auto 0
__all__ = ['State', 'StateCov', 'StateTransition', 'kalman', 'smoother', 'FFBS', 'disturbance_smoother', 'smoothed_signals',
'simulation_smoother']
__all__ = ['State', 'StateCov', 'StateTransition', 'kalman', 'account_for_nans', 'smoother', 'FFBS', 'disturbance_smoother',
'smoothed_signals', 'simulation_smoother']

# %% ../nbs/10_kalman_filter_smoother.ipynb 1
import jax.numpy as jnp
@@ -18,10 +18,10 @@

# %% ../nbs/10_kalman_filter_smoother.ipynb 7
def _predict(
    x_filt: Float[Array, "m"],  # $X_{t|t}$
    Xi_filt: Float[Array, "m m"],  # $\Xi_{t|t}$
    A: Float[Array, "m m"],  # $A_t$
    Sigma: Float[Array, "m m"],  # $\Sigma_{t + 1}$
):
"""perform a single prediction step"""
x_pred = A @ x_filt
@@ -31,7 +31,7 @@ def _predict(


def _filter(
    x_pred: Float[Array, "m"],
Xi_pred: Float[Array, "m m"],
y: Float[Array, "p"],
B: Float[Array, "p m"],
@@ -40,19 +40,21 @@ def _filter(
"""perform a single filtering step"""
y_pred = B @ x_pred
Psi_pred = B @ Xi_pred @ B.T + Omega
    K = Xi_pred @ B.T @ jnp.linalg.pinv(Psi_pred)  # jsla.solve(Psi_pred, B).T
x_filt = x_pred + K @ (y - y_pred)
Xi_filt = Xi_pred - K @ Psi_pred @ K.T

return x_filt, Xi_filt
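
The pinv here, rather than the commented-out solve, matters once account_for_nans below zeroes rows and columns of B and Omega: Psi_pred then becomes singular, and the pseudo-inverse yields a gain whose columns for missing components vanish, so they drop out of the update. A toy illustration with made-up numbers:

import jax.numpy as jnp

Xi_pred = jnp.eye(2)
B = jnp.array([[1.0, 0.0], [0.0, 0.0]])      # second observation missing: row zeroed
Omega = jnp.array([[0.1, 0.0], [0.0, 0.0]])  # its row and column zeroed as well
Psi_pred = B @ Xi_pred @ B.T + Omega         # singular: second row and column are zero
K = Xi_pred @ B.T @ jnp.linalg.pinv(Psi_pred)
# K[:, 1] == 0, so the missing component never enters x_filt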


def kalman(
    y: Observations,  # observations
    glssm: GLSSM,  # model
) -> FilterResult:  # filtered & predicted states and covariances
"""Perform the Kalman filter"""
x0, A, Sigma, B, Omega = glssm
(m,) = x0.shape

def step(carry, inputs):
x_filt, Xi_filt = carry
y, Sigma, Omega, A, B = inputs
@@ -66,11 +68,7 @@ def step(carry, inputs):
# covariance zero, transition identity
# will lead to X_0 having correct predictive distribution
# this avoids having to compute a separate filtering step beforehand

m, = x0.shape
A_ext = jnp.concatenate(
(jnp.eye(m)[jnp.newaxis], A)
)
A_ext = jnp.concatenate((jnp.eye(m)[jnp.newaxis], A))

_, (x_filt, Xi_filt, x_pred, Xi_pred) = scan(
step, (x0, jnp.zeros_like(Sigma[0])), (y, Sigma, Omega, A_ext, B)
@@ -79,6 +77,21 @@ def step(carry, inputs):
return FilterResult(x_filt, Xi_filt, x_pred, Xi_pred)
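
To see why the prepended identity transition gives X_0 the correct predictive distribution, spell out the first scan iteration: the carry starts at (x0, 0), so with A = I the prediction step reproduces the prior of X_0 exactly. A self-contained check with toy values:

import jax.numpy as jnp

m = 2
x0 = jnp.ones(m)
Sigma0 = 0.5 * jnp.eye(m)  # toy prior covariance of X_0

# first iteration: A = I, carry covariance 0
x_pred0 = jnp.eye(m) @ x0
Xi_pred0 = jnp.eye(m) @ jnp.zeros((m, m)) @ jnp.eye(m).T + Sigma0
assert jnp.allclose(x_pred0, x0) and jnp.allclose(Xi_pred0, Sigma0)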

# %% ../nbs/10_kalman_filter_smoother.ipynb 13
# y not jittable: boolean indices have to be concrete
def account_for_nans(model: GLSSM, y: Observations) -> tuple[GLSSM, Observations]:
x0, A, Sigma, B, Omega = model

missing_indices = jnp.isnan(y)

y = jnp.nan_to_num(y, nan=0.0)
B = B.at[missing_indices].set(0.0)
# set rows and columns of Omega to 0.
Omega = Omega.at[missing_indices].set(0.0)
Omega = Omega.transpose((0, 2, 1)).at[missing_indices].set(0.0)

return GLSSM(x0=x0, A=A, Sigma=Sigma, B=B, Omega=Omega), y
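
A usage sketch for account_for_nans with a hypothetical scalar random-walk model (shapes follow kalman above; per the comment, the function is called outside of jit so the boolean mask is concrete):

import jax.numpy as jnp
from isssm.typing import GLSSM
from isssm.kalman import account_for_nans, kalman

m, p, n = 1, 1, 3  # state dim, observation dim, n + 1 = 4 time points
model = GLSSM(
    x0=jnp.zeros(m),
    A=jnp.broadcast_to(jnp.eye(m), (n, m, m)),
    Sigma=jnp.broadcast_to(jnp.eye(m), (n + 1, m, m)),
    B=jnp.broadcast_to(jnp.eye(p), (n + 1, p, m)),
    Omega=jnp.broadcast_to(jnp.eye(p), (n + 1, p, p)),
)
y = jnp.array([[0.5], [jnp.nan], [0.3], [0.1]])  # time point 1 unobserved
model_nan, y_filled = account_for_nans(model, y)
filtered = kalman(y_filled, model_nan)  # the NaN entry no longer influences the fit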

# %% ../nbs/10_kalman_filter_smoother.ipynb 17
State = Float[Array, "m"]
StateCov = Float[Array, "m m"]
StateTransition = Float[Array, "m m"]
@@ -127,7 +140,7 @@

return SmootherResult(x_smooth, Xi_smooth)

# %% ../nbs/10_kalman_filter_smoother.ipynb 18
# %% ../nbs/10_kalman_filter_smoother.ipynb 22
def _simulate_smoothed_FW1994(
x_filt: Float[Array, "n+1 m"],
Xi_filt: Float[Array, "n+1 m m"],
@@ -179,7 +192,7 @@ def FFBS(
key, subkey = jrn.split(key)
return _simulate_smoothed_FW1994(x_filt, Xi_filt, Xi_pred, model.A, N, subkey)

# %% ../nbs/10_kalman_filter_smoother.ipynb 23
# %% ../nbs/10_kalman_filter_smoother.ipynb 27
def disturbance_smoother(
filtered: FilterResult, # filter result
y: Observations, # observations
@@ -223,7 +236,7 @@ def smoothed_signals(
eta_smooth, _ = disturbance_smoother(filtered, y, model)
return y - eta_smooth

# %% ../nbs/10_kalman_filter_smoother.ipynb 28
# %% ../nbs/10_kalman_filter_smoother.ipynb 32
from tensorflow_probability.substrates.jax.distributions import Chi2
from .util import degenerate_cholesky
from .util import location_antithetic, scale_antithethic
127 changes: 105 additions & 22 deletions nbs/10_kalman_filter_smoother.ipynb

Large diffs are not rendered by default.

22 changes: 13 additions & 9 deletions nbs/40_importance_sampling.ipynb
@@ -115,8 +115,9 @@
") -> Float[Array, \"n+1\"]: # log weights\n",
" \"\"\"Log weights for all time points\"\"\"\n",
" p_ys = dist(s, xi).log_prob(y).sum()\n",
"\n",
" # avoid triangular solve problems\n",
" omega = jnp.sqrt(vmap(jnp.diag)(Omega))\n",
" # avoid triangulra solve problems\n",
" g_zs = MVN_diag(s, omega).log_prob(z).sum()\n",
"\n",
" return p_ys - g_zs"
@@ -151,13 +152,12 @@
"from jaxtyping import Float, Array, PRNGKeyArray\n",
"from isssm.kalman import FFBS, simulation_smoother\n",
"import jax.random as jrn\n",
"from isssm.typing import GLSSM\n",
"\n",
"v_log_weights = vmap(log_weights, (0, None, None, None, None, None))\n",
"from functools import partial\n",
"from isssm.typing import GLSSM, PGSSM\n",
"\n",
"def pgssm_importance_sampling(\n",
" y: Float[Array, \"n+1 p\"], # observations\n",
" model: PGSSM,\n",
" model: PGSSM, # model\n",
" z: Float[Array, \"n+1 p\"], # synthetic observations\n",
" Omega: Float[Array, \"n+1 p p\"], # covariance of synthetic observations\n",
" N: int, # number of samples\n",
@@ -169,7 +169,9 @@
" key, subkey = jrn.split(key)\n",
" s = simulation_smoother(glssm, z, N, subkey)\n",
"\n",
" lw = v_log_weights(s, y, dist, xi, z, Omega)\n",
" model_log_weights = partial(log_weights, y=y, dist=dist, xi=xi, z=z, Omega=Omega)\n",
"\n",
" lw = vmap(model_log_weights)(s)\n",
"\n",
" return s, lw"
]
Expand All @@ -178,7 +180,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us perform importance sampling for our [running example](20_lcssm.ipynb#running) using the model obtained by [mode estimation](30_laplace_approximation.ipynb)."
"Let us perform importance sampling for our [running example](20_lcssm.ipynb#running) using the model obtained by [the laplace approximation](30_laplace_approximation.ipynb). Notice that we obtain four times the number of samples that we specified, which comes from the use of [antithetics](./99_util.ipynb#antithetic-variables)."
]
},
{
Expand Down Expand Up @@ -301,7 +303,7 @@
"\n",
"\n",
"def ess(\n",
" normalized_weights: Float[Array, \"N\"] # the normailzed weights\n",
" normalized_weights: Float[Array, \"N\"] # normalized weights\n",
") -> Float: # the effective sample size\n",
" \"\"\"Compute the effective sample size of a set of normalized weights\"\"\"\n",
" return 1 / (normalized_weights**2).sum()\n",
@@ -313,7 +315,9 @@
" \"\"\"Compute the effective sample size of a set of log weights\"\"\"\n",
" return ess(normalize_weights(log_weights))\n",
"\n",
"def ess_pct(log_weights):\n",
"def ess_pct(\n",
" log_weights: Float[Array, \"N\"] # log weights\n",
") -> Float: # the effective sample size in percent, also called efficiency factor\n",
" N, = log_weights.shape\n",
" return ess_lw(log_weights) / N * 100 "
]
