Skip to content

Commit 096e9ac

Browse files
committed
Add prediction intervals for filter/smoother
1 parent a549771 commit 096e9ac

File tree

3 files changed

+105
-9
lines changed

3 files changed

+105
-9
lines changed

isssm/_modidx.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,12 @@
5454
'isssm.kalman._smooth_step': ('kalman_filter_smoother.html#_smooth_step', 'isssm/kalman.py'),
5555
'isssm.kalman.account_for_nans': ('kalman_filter_smoother.html#account_for_nans', 'isssm/kalman.py'),
5656
'isssm.kalman.disturbance_smoother': ('kalman_filter_smoother.html#disturbance_smoother', 'isssm/kalman.py'),
57+
'isssm.kalman.filter_intervals': ('kalman_filter_smoother.html#filter_intervals', 'isssm/kalman.py'),
5758
'isssm.kalman.kalman': ('kalman_filter_smoother.html#kalman', 'isssm/kalman.py'),
5859
'isssm.kalman.simulation_smoother': ('kalman_filter_smoother.html#simulation_smoother', 'isssm/kalman.py'),
5960
'isssm.kalman.smoothed_signals': ('kalman_filter_smoother.html#smoothed_signals', 'isssm/kalman.py'),
60-
'isssm.kalman.smoother': ('kalman_filter_smoother.html#smoother', 'isssm/kalman.py')},
61+
'isssm.kalman.smoother': ('kalman_filter_smoother.html#smoother', 'isssm/kalman.py'),
62+
'isssm.kalman.smoother_intervals': ('kalman_filter_smoother.html#smoother_intervals', 'isssm/kalman.py')},
6163
'isssm.laplace_approximation': { 'isssm.laplace_approximation._initial_guess': ( 'laplace_approximation.html#_initial_guess',
6264
'isssm/laplace_approximation.py'),
6365
'isssm.laplace_approximation.laplace_approximation': ( 'laplace_approximation.html#laplace_approximation',

isssm/kalman.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/10_kalman_filter_smoother.ipynb.
22

33
# %% auto 0
4-
__all__ = ['State', 'StateCov', 'StateTransition', 'kalman', 'account_for_nans', 'smoother', 'FFBS', 'disturbance_smoother',
5-
'smoothed_signals', 'simulation_smoother']
4+
__all__ = ['State', 'StateCov', 'StateTransition', 'kalman', 'account_for_nans', 'smoother', 'filter_intervals',
5+
'smoother_intervals', 'FFBS', 'disturbance_smoother', 'smoothed_signals', 'simulation_smoother']
66

77
# %% ../nbs/10_kalman_filter_smoother.ipynb 1
88
import jax.numpy as jnp
@@ -141,6 +141,26 @@ def step(carry, inputs):
141141
return SmootherResult(x_smooth, Xi_smooth)
142142

143143
# %% ../nbs/10_kalman_filter_smoother.ipynb 22
144+
from tensorflow_probability.substrates.jax.distributions import Normal
145+
def filter_intervals(result: FilterResult, alpha: Float=.05) -> Float[Array, "2 n+1 m"]:
146+
x_filt, Xi_filt, *_ = result
147+
marginal_variances = vmap(jnp.diag)(Xi_filt)
148+
dist = Normal(x_filt, marginal_variances)
149+
lower = dist.quantile(alpha / 2)
150+
upper = dist.quantile(1 - alpha / 2)
151+
152+
return jnp.concatenate((lower[None], upper[None]))
153+
154+
def smoother_intervals(result: SmootherResult, alpha: Float = .05) -> Float[Array, "2 n+1 m"]:
155+
x_smooth, Xi_smooth = result
156+
marginal_variances = vmap(jnp.diag)(Xi_smooth)
157+
dist = Normal(x_smooth, marginal_variances)
158+
lower = dist.quantile(alpha / 2)
159+
upper = dist.quantile(1 - alpha / 2)
160+
161+
return jnp.concatenate((lower[None], upper[None]))
162+
163+
# %% ../nbs/10_kalman_filter_smoother.ipynb 25
144164
def _simulate_smoothed_FW1994(
145165
x_filt: Float[Array, "n+1 m"],
146166
Xi_filt: Float[Array, "n+1 m m"],
@@ -192,7 +212,7 @@ def FFBS(
192212
key, subkey = jrn.split(key)
193213
return _simulate_smoothed_FW1994(x_filt, Xi_filt, Xi_pred, model.A, N, subkey)
194214

195-
# %% ../nbs/10_kalman_filter_smoother.ipynb 27
215+
# %% ../nbs/10_kalman_filter_smoother.ipynb 30
196216
def disturbance_smoother(
197217
filtered: FilterResult, # filter result
198218
y: Observations, # observations
@@ -236,7 +256,7 @@ def smoothed_signals(
236256
eta_smooth, _ = disturbance_smoother(filtered, y, model)
237257
return y - eta_smooth
238258

239-
# %% ../nbs/10_kalman_filter_smoother.ipynb 32
259+
# %% ../nbs/10_kalman_filter_smoother.ipynb 35
240260
from tensorflow_probability.substrates.jax.distributions import Chi2
241261
from .util import degenerate_cholesky
242262
from .util import location_antithetic, scale_antithethic

nbs/10_kalman_filter_smoother.ipynb

Lines changed: 78 additions & 4 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)