Skip to content

Commit

Permalink
Add prediction intervals for filter/smoother
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanheyder committed Aug 16, 2024
1 parent a549771 commit 096e9ac
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 9 deletions.
4 changes: 3 additions & 1 deletion isssm/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@
'isssm.kalman._smooth_step': ('kalman_filter_smoother.html#_smooth_step', 'isssm/kalman.py'),
'isssm.kalman.account_for_nans': ('kalman_filter_smoother.html#account_for_nans', 'isssm/kalman.py'),
'isssm.kalman.disturbance_smoother': ('kalman_filter_smoother.html#disturbance_smoother', 'isssm/kalman.py'),
'isssm.kalman.filter_intervals': ('kalman_filter_smoother.html#filter_intervals', 'isssm/kalman.py'),
'isssm.kalman.kalman': ('kalman_filter_smoother.html#kalman', 'isssm/kalman.py'),
'isssm.kalman.simulation_smoother': ('kalman_filter_smoother.html#simulation_smoother', 'isssm/kalman.py'),
'isssm.kalman.smoothed_signals': ('kalman_filter_smoother.html#smoothed_signals', 'isssm/kalman.py'),
'isssm.kalman.smoother': ('kalman_filter_smoother.html#smoother', 'isssm/kalman.py')},
'isssm.kalman.smoother': ('kalman_filter_smoother.html#smoother', 'isssm/kalman.py'),
'isssm.kalman.smoother_intervals': ('kalman_filter_smoother.html#smoother_intervals', 'isssm/kalman.py')},
'isssm.laplace_approximation': { 'isssm.laplace_approximation._initial_guess': ( 'laplace_approximation.html#_initial_guess',
'isssm/laplace_approximation.py'),
'isssm.laplace_approximation.laplace_approximation': ( 'laplace_approximation.html#laplace_approximation',
Expand Down
28 changes: 24 additions & 4 deletions isssm/kalman.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/10_kalman_filter_smoother.ipynb.

# %% auto 0
__all__ = ['State', 'StateCov', 'StateTransition', 'kalman', 'account_for_nans', 'smoother', 'FFBS', 'disturbance_smoother',
'smoothed_signals', 'simulation_smoother']
__all__ = ['State', 'StateCov', 'StateTransition', 'kalman', 'account_for_nans', 'smoother', 'filter_intervals',
'smoother_intervals', 'FFBS', 'disturbance_smoother', 'smoothed_signals', 'simulation_smoother']

# %% ../nbs/10_kalman_filter_smoother.ipynb 1
import jax.numpy as jnp
Expand Down Expand Up @@ -141,6 +141,26 @@ def step(carry, inputs):
return SmootherResult(x_smooth, Xi_smooth)

# %% ../nbs/10_kalman_filter_smoother.ipynb 22
from tensorflow_probability.substrates.jax.distributions import Normal
def filter_intervals(result: FilterResult, alpha: Float=.05) -> Float[Array, "2 n+1 m"]:
x_filt, Xi_filt, *_ = result
marginal_variances = vmap(jnp.diag)(Xi_filt)
dist = Normal(x_filt, marginal_variances)
lower = dist.quantile(alpha / 2)
upper = dist.quantile(1 - alpha / 2)

return jnp.concatenate((lower[None], upper[None]))

def smoother_intervals(result: SmootherResult, alpha: Float = .05) -> Float[Array, "2 n+1 m"]:
x_smooth, Xi_smooth = result
marginal_variances = vmap(jnp.diag)(Xi_smooth)
dist = Normal(x_smooth, marginal_variances)
lower = dist.quantile(alpha / 2)
upper = dist.quantile(1 - alpha / 2)

return jnp.concatenate((lower[None], upper[None]))

# %% ../nbs/10_kalman_filter_smoother.ipynb 25
def _simulate_smoothed_FW1994(
x_filt: Float[Array, "n+1 m"],
Xi_filt: Float[Array, "n+1 m m"],
Expand Down Expand Up @@ -192,7 +212,7 @@ def FFBS(
key, subkey = jrn.split(key)
return _simulate_smoothed_FW1994(x_filt, Xi_filt, Xi_pred, model.A, N, subkey)

# %% ../nbs/10_kalman_filter_smoother.ipynb 27
# %% ../nbs/10_kalman_filter_smoother.ipynb 30
def disturbance_smoother(
filtered: FilterResult, # filter result
y: Observations, # observations
Expand Down Expand Up @@ -236,7 +256,7 @@ def smoothed_signals(
eta_smooth, _ = disturbance_smoother(filtered, y, model)
return y - eta_smooth

# %% ../nbs/10_kalman_filter_smoother.ipynb 32
# %% ../nbs/10_kalman_filter_smoother.ipynb 35
from tensorflow_probability.substrates.jax.distributions import Chi2
from .util import degenerate_cholesky
from .util import location_antithetic, scale_antithethic
Expand Down
82 changes: 78 additions & 4 deletions nbs/10_kalman_filter_smoother.ipynb

Large diffs are not rendered by default.

0 comments on commit 096e9ac

Please sign in to comment.