Skip to content

Commit

Permalink
Add predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanheyder committed Aug 28, 2024
1 parent 60b0167 commit eb70fd5
Show file tree
Hide file tree
Showing 7 changed files with 434 additions and 81 deletions.
18 changes: 11 additions & 7 deletions isssm/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
'isssm.importance_sampling.normalize_weights': ( 'importance_sampling.html#normalize_weights',
'isssm/importance_sampling.py'),
'isssm.importance_sampling.pgssm_importance_sampling': ( 'importance_sampling.html#pgssm_importance_sampling',
'isssm/importance_sampling.py')},
'isssm/importance_sampling.py'),
'isssm.importance_sampling.predict': ( 'importance_sampling.html#predict',
'isssm/importance_sampling.py'),
'isssm.importance_sampling.prediction_interval': ( 'importance_sampling.html#prediction_interval',
'isssm/importance_sampling.py')},
'isssm.kalman': { 'isssm.kalman.FFBS': ('kalman_filter_smoother.html#ffbs', 'isssm/kalman.py'),
'isssm.kalman._filter': ('kalman_filter_smoother.html#_filter', 'isssm/kalman.py'),
'isssm.kalman._predict': ('kalman_filter_smoother.html#_predict', 'isssm/kalman.py'),
Expand All @@ -66,12 +70,12 @@
'isssm/laplace_approximation.py'),
'isssm.laplace_approximation.posterior_mode': ( 'laplace_approximation.html#posterior_mode',
'isssm/laplace_approximation.py')},
'isssm.models.glssm': { 'isssm.models.glssm.ar1': ('Models/gaussian_models.html#ar1', 'isssm/models/glssm.py'),
'isssm.models.glssm.lcm': ('Models/gaussian_models.html#lcm', 'isssm/models/glssm.py'),
'isssm.models.glssm.mv_ar1': ('Models/gaussian_models.html#mv_ar1', 'isssm/models/glssm.py')},
'isssm.models.pgssm': { 'isssm.models.pgssm.nb_pgssm': ('Models/pgssm.html#nb_pgssm', 'isssm/models/pgssm.py'),
'isssm.models.pgssm.poisson_pgssm': ('Models/pgssm.html#poisson_pgssm', 'isssm/models/pgssm.py')},
'isssm.models.stsm': {'isssm.models.stsm.stsm': ('Models/stsm.html#stsm', 'isssm/models/stsm.py')},
'isssm.models.glssm': { 'isssm.models.glssm.ar1': ('models/gaussian_models.html#ar1', 'isssm/models/glssm.py'),
'isssm.models.glssm.lcm': ('models/gaussian_models.html#lcm', 'isssm/models/glssm.py'),
'isssm.models.glssm.mv_ar1': ('models/gaussian_models.html#mv_ar1', 'isssm/models/glssm.py')},
'isssm.models.pgssm': { 'isssm.models.pgssm.nb_pgssm': ('models/pgssm.html#nb_pgssm', 'isssm/models/pgssm.py'),
'isssm.models.pgssm.poisson_pgssm': ('models/pgssm.html#poisson_pgssm', 'isssm/models/pgssm.py')},
'isssm.models.stsm': {'isssm.models.stsm.stsm': ('models/stsm.html#stsm', 'isssm/models/stsm.py')},
'isssm.modified_efficient_importance_sampling': { 'isssm.modified_efficient_importance_sampling.modified_efficient_importance_sampling': ( 'modified_efficient_importance_sampling.html#modified_efficient_importance_sampling',
'isssm/modified_efficient_importance_sampling.py'),
'isssm.modified_efficient_importance_sampling.optimal_parameters': ( 'modified_efficient_importance_sampling.html#optimal_parameters',
Expand Down
105 changes: 104 additions & 1 deletion isssm/importance_sampling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/40_importance_sampling.ipynb.

# %% auto 0
__all__ = ['log_weights_t', 'log_weights', 'pgssm_importance_sampling', 'normalize_weights', 'ess', 'ess_lw', 'ess_pct']
__all__ = ['log_weights_t', 'log_weights', 'pgssm_importance_sampling', 'normalize_weights', 'ess', 'ess_lw', 'ess_pct',
'prediction_interval', 'predict']

# %% ../nbs/40_importance_sampling.ipynb 4
from tensorflow_probability.substrates.jax.distributions import (
Expand Down Expand Up @@ -119,3 +120,105 @@ def ess_pct(
) -> Float: # the effective sample size in percent, also called efficiency factor
(N,) = log_weights.shape
return ess_lw(log_weights) / N * 100

# %% ../nbs/40_importance_sampling.ipynb 19
from .typing import GLSSMProposal, GLSSMState
from .kalman import kalman
from .glssm import simulate_states
from .util import mm_time_sim
from jax import jit


def prediction_interval(Y, weights, alpha):
probs = jnp.array([alpha / 2, 1 - alpha / 2])

Y_sorted = jnp.sort(Y)
weights_sorted = weights[jnp.argsort(Y)]
cumsum = jnp.cumsum(weights_sorted)

# find indices of cumulative sum closest to prediction_probs
# take corresponding Y_sorted values
# with linear interpolation if necessary

indices = jnp.searchsorted(cumsum, probs)
indices = jnp.clip(indices, 1, len(Y_sorted) - 1)
left_indices = indices - 1
right_indices = indices
left_cumsum = cumsum[left_indices]
right_cumsum = cumsum[right_indices]
left_Y = Y_sorted[left_indices]
right_Y = Y_sorted[right_indices]
# linear interpolation
quantiles = left_Y + (probs - left_cumsum) / (right_cumsum - left_cumsum) * (
right_Y - left_Y
)
return quantiles


def predict(
model: PGSSM,
y: Float[Array, "n+1 p"],
proposal: GLSSMProposal,
future_model: PGSSM,
N: int,
key: PRNGKeyArray,
):

key, subkey = jrn.split(key)
signal_samples, log_weights = pgssm_importance_sampling(
y, model, proposal.z, proposal.Omega, N, subkey
)
(N,) = log_weights.shape

signal_model = GLSSM(
proposal.u,
proposal.A,
proposal.D,
proposal.Sigma0,
proposal.Sigma,
proposal.v,
proposal.B,
proposal.Omega,
)

@jit
def last_state_sample(signal_sample, key):
x_filt, Xi_filt, _, _ = kalman(signal_sample, signal_model)
return MVN(x_filt[-1], Xi_filt[-1]).sample(seed=key)

key, *subkeys = jrn.split(key, N + 1)
subkeys = jnp.array(subkeys)
samples = vmap(last_state_sample)(signal_samples, subkeys)

@jit
def future_sample(sample, key):
state = GLSSMState(
future_model.u.at[0].set(sample),
future_model.A,
future_model.D,
future_model.Sigma0,
future_model.Sigma,
)

(x,) = simulate_states(state, 1, key)
return x

key, *subkeys = jrn.split(key, N + 1)
subkeys = jnp.array(subkeys)

future_states = vmap(future_sample)(samples, subkeys)
future_signals = mm_time_sim(future_model.B, future_states)

sample_cond_expectation = future_model.dist(future_signals, future_model.xi).mean()

mean_prediction = (
sample_cond_expectation * normalize_weights(log_weights)[:, None, None]
).sum(axis=0)

pi = vmap(vmap(prediction_interval, (0, None, None)), (0, None, None))(
sample_cond_expectation.transpose((2, 1, 0)),
normalize_weights(log_weights),
0.05,
)

return mean_prediction, pi
8 changes: 4 additions & 4 deletions isssm/models/glssm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/Models/00_gaussian_models.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models/00_gaussian_models.ipynb.

# %% auto 0
__all__ = ['lcm', 'ar1', 'mv_ar1']

# %% ../../nbs/Models/00_gaussian_models.ipynb 4
# %% ../../nbs/models/00_gaussian_models.ipynb 4
import jax.numpy as jnp
from jaxtyping import Float, Array
from ..typing import GLSSM
Expand All @@ -28,7 +28,7 @@ def lcm(
v = jnp.zeros((n + 1, 1))
return GLSSM(u, A, D, Sigma0, Sigma, v, B, Omega)

# %% ../../nbs/Models/00_gaussian_models.ipynb 7
# %% ../../nbs/models/00_gaussian_models.ipynb 7
def ar1(
mu: Float, # stationary mean
tau2: Float, # stationary variance
Expand All @@ -52,7 +52,7 @@ def ar1(

return GLSSM(u, A, D, Sigma0, Sigma, v, B, Omega)

# %% ../../nbs/Models/00_gaussian_models.ipynb 9
# %% ../../nbs/models/00_gaussian_models.ipynb 9
def mv_ar1(
mu: Float[Array, "m"], # stationary mean
Tau: Float[Array, "m m"], # stationary covariance
Expand Down
4 changes: 2 additions & 2 deletions isssm/models/pgssm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/Models/20_pgssm.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models/20_pgssm.ipynb.

# %% auto 0
__all__ = ['nb_pgssm', 'poisson_pgssm']

# %% ../../nbs/Models/20_pgssm.ipynb 2
# %% ../../nbs/models/20_pgssm.ipynb 2
import jax.numpy as jnp
from jaxtyping import Float
from ..typing import GLSSM, PGSSM
Expand Down
6 changes: 3 additions & 3 deletions isssm/models/stsm.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/Models/10_stsm.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models/10_stsm.ipynb.

# %% auto 0
__all__ = ['stsm']

# %% ../../nbs/Models/10_stsm.ipynb 1
# %% ../../nbs/models/10_stsm.ipynb 1
import jax
import jax.numpy as jnp
from jaxtyping import Float, Array
from ..typing import GLSSM
import jax.scipy as jsp

# %% ../../nbs/Models/10_stsm.ipynb 5
# %% ../../nbs/models/10_stsm.ipynb 5
def stsm(
x0: Float[Array, "m"], # initial state
s2_mu: Float, # variance of trend innovations
Expand Down
243 changes: 219 additions & 24 deletions nbs/40_importance_sampling.ipynb

Large diffs are not rendered by default.

131 changes: 91 additions & 40 deletions nbs/60_maximum_likelihood_estimation.ipynb

Large diffs are not rendered by default.

0 comments on commit eb70fd5

Please sign in to comment.