|
1 | 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/10_kalman_filter_smoother.ipynb.
|
2 | 2 |
|
3 | 3 | # %% auto 0
|
4 |
| -__all__ = ['State', 'StateCov', 'StateTransition', 'kalman', 'account_for_nans', 'smoother', 'FFBS', 'disturbance_smoother', |
5 |
| - 'smoothed_signals', 'simulation_smoother'] |
| 4 | +__all__ = ['State', 'StateCov', 'StateTransition', 'kalman', 'account_for_nans', 'smoother', 'filter_intervals', |
| 5 | + 'smoother_intervals', 'FFBS', 'disturbance_smoother', 'smoothed_signals', 'simulation_smoother'] |
6 | 6 |
|
7 | 7 | # %% ../nbs/10_kalman_filter_smoother.ipynb 1
|
8 | 8 | import jax.numpy as jnp
|
@@ -141,6 +141,26 @@ def step(carry, inputs):
|
141 | 141 | return SmootherResult(x_smooth, Xi_smooth)
|
142 | 142 |
|
143 | 143 | # %% ../nbs/10_kalman_filter_smoother.ipynb 22
|
| 144 | +from tensorflow_probability.substrates.jax.distributions import Normal |
| 145 | +def filter_intervals(result: FilterResult, alpha: Float=.05) -> Float[Array, "2 n+1 m"]: |
| 146 | + x_filt, Xi_filt, *_ = result |
| 147 | + marginal_variances = vmap(jnp.diag)(Xi_filt) |
| 148 | + dist = Normal(x_filt, marginal_variances) |
| 149 | + lower = dist.quantile(alpha / 2) |
| 150 | + upper = dist.quantile(1 - alpha / 2) |
| 151 | + |
| 152 | + return jnp.concatenate((lower[None], upper[None])) |
| 153 | + |
| 154 | +def smoother_intervals(result: SmootherResult, alpha: Float = .05) -> Float[Array, "2 n+1 m"]: |
| 155 | + x_smooth, Xi_smooth = result |
| 156 | + marginal_variances = vmap(jnp.diag)(Xi_smooth) |
| 157 | + dist = Normal(x_smooth, marginal_variances) |
| 158 | + lower = dist.quantile(alpha / 2) |
| 159 | + upper = dist.quantile(1 - alpha / 2) |
| 160 | + |
| 161 | + return jnp.concatenate((lower[None], upper[None])) |
| 162 | + |
| 163 | +# %% ../nbs/10_kalman_filter_smoother.ipynb 25 |
144 | 164 | def _simulate_smoothed_FW1994(
|
145 | 165 | x_filt: Float[Array, "n+1 m"],
|
146 | 166 | Xi_filt: Float[Array, "n+1 m m"],
|
@@ -192,7 +212,7 @@ def FFBS(
|
192 | 212 | key, subkey = jrn.split(key)
|
193 | 213 | return _simulate_smoothed_FW1994(x_filt, Xi_filt, Xi_pred, model.A, N, subkey)
|
194 | 214 |
|
195 |
| -# %% ../nbs/10_kalman_filter_smoother.ipynb 27 |
| 215 | +# %% ../nbs/10_kalman_filter_smoother.ipynb 30 |
196 | 216 | def disturbance_smoother(
|
197 | 217 | filtered: FilterResult, # filter result
|
198 | 218 | y: Observations, # observations
|
@@ -236,7 +256,7 @@ def smoothed_signals(
|
236 | 256 | eta_smooth, _ = disturbance_smoother(filtered, y, model)
|
237 | 257 | return y - eta_smooth
|
238 | 258 |
|
239 |
| -# %% ../nbs/10_kalman_filter_smoother.ipynb 32 |
| 259 | +# %% ../nbs/10_kalman_filter_smoother.ipynb 35 |
240 | 260 | from tensorflow_probability.substrates.jax.distributions import Chi2
|
241 | 261 | from .util import degenerate_cholesky
|
242 | 262 | from .util import location_antithetic, scale_antithethic
|
|
0 commit comments