Skip to content

Commit e4f556a

Browse files
author
Hylke Donker
committed
Add support for inhomogeneous parameters
1 parent 8e9ee5f commit e4f556a

File tree

2 files changed

+155
-33
lines changed

2 files changed

+155
-33
lines changed

dynamax/linear_gaussian_ssm/models.py

Lines changed: 101 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,47 @@ def m_step(self,
643643
)
644644
return params, m_step_state
645645

646+
def _check_params(self, params: ParamsLGSSM, num_timesteps: int) -> ParamsLGSSM:
647+
"""Replace None parameters with zeros."""
648+
dynamics, emissions = params.dynamics, params.emissions
649+
is_inhomogeneous = dynamics.weights.ndim == 3
650+
651+
def _zeros_if_none(x, shape):
652+
if x is None:
653+
return jnp.zeros(shape)
654+
return x
655+
656+
shape_prefix = ()
657+
if is_inhomogeneous:
658+
shape_prefix = (num_timesteps - 1,)
659+
660+
clean_dynamics = ParamsLGSSMDynamics(
661+
weights=dynamics.weights,
662+
bias=_zeros_if_none(dynamics.bias, shape=shape_prefix + (self.state_dim,)),
663+
input_weights=_zeros_if_none(
664+
dynamics.input_weights, shape=shape_prefix + (self.state_dim, self.input_dim)
665+
),
666+
cov=dynamics.cov
667+
)
668+
shape_prefix = ()
669+
if is_inhomogeneous:
670+
shape_prefix = (num_timesteps,)
671+
672+
clean_emissions = ParamsLGSSMEmissions(
673+
weights=emissions.weights,
674+
bias=_zeros_if_none(emissions.bias, shape=shape_prefix + (self.emission_dim,)),
675+
input_weights=_zeros_if_none(
676+
emissions.input_weights, shape=shape_prefix + (self.emission_dim, self.input_dim)
677+
),
678+
cov=emissions.cov
679+
)
680+
return ParamsLGSSM(
681+
initial=params.initial,
682+
dynamics=clean_dynamics,
683+
emissions=clean_emissions,
684+
)
685+
686+
646687
def fit_blocked_gibbs(self,
647688
key: PRNGKeyT,
648689
initial_params: ParamsLGSSM,
@@ -654,7 +695,8 @@ def fit_blocked_gibbs(self,
654695
655696
Args:
656697
key: random number key.
657-
initial_params: starting parameters.
698+
initial_params: starting parameters. Include a leading time axis for
699+
the dynamics and emissions parameters in inhomogeneous models.
658700
sample_size: how many samples to draw.
659701
emissions: set of observation sequences.
660702
inputs: optional set of input sequences.
@@ -667,66 +709,95 @@ def fit_blocked_gibbs(self,
667709

668710
num_batches, num_timesteps = batch_emissions.shape[:2]
669711

712+
initial_params = self._check_params(initial_params, num_timesteps)
670713
if batch_inputs is None:
671714
batch_inputs = jnp.zeros((num_batches, num_timesteps, 0))
672715

716+
# Inhomogeneous models have a leading time dimension.
717+
is_inhomogeneous = initial_params.dynamics.weights.ndim == 3
718+
673719
def sufficient_stats_from_sample(y, inputs, states):
674720
"""Convert samples of states to sufficient statistics."""
675721
inputs_joint = jnp.concatenate((inputs, jnp.ones((num_timesteps, 1))), axis=1)
676722
# Let xn[t] = x[t+1] for t = 0...T-2
677-
x, xp, xn = states, states[:-1], states[1:]
678-
u, up = inputs_joint, inputs_joint[:-1]
723+
x, xn, xp = states, states[1:], states[:-1]
724+
u, un = inputs_joint, inputs_joint[1:]
725+
# Let zp[t] = [x[t], u[t+1]] for t = 0...T-2
726+
zp = jnp.concatenate([xp, un], axis=1)
727+
# Let z[t] = [x[t], u[t]] for t = 0...T-1
728+
z = jnp.concatenate([x, u], axis=-1)
679729

680730
init_stats = (x[0], jnp.outer(x[0], x[0]), 1)
681731

682732
# Quantities for the dynamics distribution
683-
# Let zp[t] = [x[t], u[t]] for t = 0...T-2
684-
sum_zpzpT = jnp.block([[xp.T @ xp, xp.T @ up], [up.T @ xp, up.T @ up]])
685-
sum_zpxnT = jnp.block([[xp.T @ xn], [up.T @ xn]])
686-
sum_xnxnT = xn.T @ xn
687-
dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, num_timesteps - 1)
733+
sum_zpzpT = jnp.einsum('ti,tj->tij', zp, zp)
734+
sum_zpxnT = jnp.einsum('ti,tj->tij', zp, xn)
735+
sum_xnxnT = jnp.einsum('ti,tj->tij', xn, xn)
736+
n_t_dynamics = jnp.ones(num_timesteps - 1)
737+
# The dynamics stats have a leading time dimension.
738+
dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, n_t_dynamics)
688739
if not self.has_dynamics_bias:
689-
dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT,
690-
num_timesteps - 1)
740+
dynamics_stats = (sum_zpzpT[:, :-1, :-1], sum_zpxnT[:, :-1, :], sum_xnxnT,
741+
n_t_dynamics)
691742

692743
# Quantities for the emissions
693-
# Let z[t] = [x[t], u[t]] for t = 0...T-1
694-
sum_zzT = jnp.block([[x.T @ x, x.T @ u], [u.T @ x, u.T @ u]])
695-
sum_zyT = jnp.block([[x.T @ y], [u.T @ y]])
696-
sum_yyT = y.T @ y
697-
emission_stats = (sum_zzT, sum_zyT, sum_yyT, num_timesteps)
744+
sum_zzT = jnp.einsum('ti,tj->tij', z, z)
745+
sum_zyT = jnp.einsum('ti,tj->tij', z, y)
746+
sum_yyT = jnp.einsum('ti,tj->tij', y, y)
747+
n_t_emissions = jnp.ones(num_timesteps)
748+
# The emissions stats have a leading time dimension.
749+
emission_stats = (sum_zzT, sum_zyT, sum_yyT, n_t_emissions)
698750
if not self.has_emissions_bias:
699-
emission_stats = (sum_zzT[:-1, :-1], sum_zyT[:-1, :], sum_yyT, num_timesteps)
751+
emission_stats = (sum_zzT[:, :-1, :-1], sum_zyT[:, :-1, :], sum_yyT, n_t_emissions)
700752

701753
return init_stats, dynamics_stats, emission_stats
702754

703-
def lgssm_params_sample(rng, stats):
704-
"""Sample parameters of the model given sufficient statistics from observed states and emissions."""
705-
init_stats, dynamics_stats, emission_stats = stats
706-
rngs = iter(jr.split(rng, 3))
707-
708-
# Sample the initial params
755+
def _sample_initial_params(rng, init_stats):
709756
initial_posterior = niw_posterior_update(self.initial_prior, init_stats)
710-
S, m = initial_posterior.sample(seed=next(rngs))
757+
S, m = initial_posterior.sample(seed=rng)
758+
return ParamsLGSSMInitial(mean=m, cov=S)
711759

712-
# Sample the dynamics params
760+
def _sample_dynamics_params(rng, dynamics_stats):
713761
dynamics_posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats)
714-
Q, FB = dynamics_posterior.sample(seed=next(rngs))
762+
Q, FB = dynamics_posterior.sample(seed=rng)
715763
F = FB[:, :self.state_dim]
716764
B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \
717765
else (FB[:, self.state_dim:], jnp.zeros(self.state_dim))
766+
return ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q)
718767

719-
# Sample the emission params
768+
def _sample_emission_params(rng, emission_stats):
720769
emission_posterior = mniw_posterior_update(self.emission_prior, emission_stats)
721-
R, HD = emission_posterior.sample(seed=next(rngs))
770+
R, HD = emission_posterior.sample(seed=rng)
722771
H = HD[:, :self.state_dim]
723772
D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \
724773
else (HD[:, self.state_dim:], jnp.zeros(self.emission_dim))
774+
return ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R)
775+
776+
def lgssm_params_sample(rng, stats):
777+
"""Sample parameters of the model given sufficient statistics from observed states and emissions."""
778+
init_stats, dynamics_stats, emission_stats = stats
779+
rngs = iter(jr.split(rng, 3))
780+
781+
# Sample the initial params
782+
initial_params = _sample_initial_params(next(rngs), init_stats)
783+
784+
# Sample the dynamics and emission params.
785+
if not is_inhomogeneous:
786+
# Aggregate summary statistics across time for homogeneous model.
787+
dynamics_stats = tree.map(lambda x: jnp.sum(x, axis=0), dynamics_stats)
788+
emission_stats = tree.map(lambda x: jnp.sum(x, axis=0), emission_stats)
789+
dynamics_params = _sample_dynamics_params(next(rngs), dynamics_stats)
790+
emission_params = _sample_emission_params(next(rngs), emission_stats)
791+
else:
792+
keys_dynamics = jr.split(next(rngs), num_timesteps - 1)
793+
keys_emission = jr.split(next(rngs), num_timesteps)
794+
dynamics_params = vmap(_sample_dynamics_params)(keys_dynamics, dynamics_stats)
795+
emission_params = vmap(_sample_emission_params)(keys_emission, emission_stats)
725796

726797
params = ParamsLGSSM(
727-
initial=ParamsLGSSMInitial(mean=m, cov=S),
728-
dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q),
729-
emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R)
798+
initial=initial_params,
799+
dynamics=dynamics_params,
800+
emissions=emission_params,
730801
)
731802
return params
732803

dynamax/linear_gaussian_ssm/models_test.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22
Tests for the linear Gaussian SSM models.
33
"""
44
from functools import partial
5-
from itertools import count
5+
from itertools import count, product
66

7-
import pytest
87
from jax import vmap
98
import jax.numpy as jnp
109
import jax.random as jr
10+
from jax import tree
11+
import pytest
1112

1213
from dynamax.linear_gaussian_ssm import LinearGaussianSSM
1314
from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM
15+
from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM
1416
from dynamax.utils.utils import monotonically_increasing
1517

1618
NUM_TIMESTEPS = 100
@@ -49,4 +51,53 @@ def test_fit_blocked_gibbs_batched():
4951
params, _ = model.initialize(next(keys))
5052
_, y_obs = vmap(partial(model.sample, params, num_timesteps=num_timesteps))(m_keys)
5153

52-
model.fit_blocked_gibbs(next(keys), params, sample_size=6, emissions=y_obs)
54+
model.fit_blocked_gibbs(next(keys), params, sample_size=6, emissions=y_obs)
55+
56+
@pytest.mark.parametrize(["has_dynamics_bias", "has_emissions_bias"], product([True, False], repeat=2))
57+
def test_inhomogeneous_lgcssm(has_dynamics_bias, has_emissions_bias):
58+
"""
59+
Test a LinearGaussianConjugateSSM with time-varying dynamics and emission model.
60+
"""
61+
state_dim = 2
62+
emission_dim = 3
63+
num_timesteps = 4
64+
keys = map(jr.PRNGKey, count())
65+
kwargs = {
66+
"state_dim": state_dim,
67+
"emission_dim": emission_dim,
68+
"has_dynamics_bias": has_dynamics_bias,
69+
"has_emissions_bias": has_emissions_bias,
70+
}
71+
model = LinearGaussianConjugateSSM(**kwargs)
72+
params, param_props = model.initialize(jr.PRNGKey(0))
73+
# Repeat the parameters for each timestep.
74+
inhomogeneous_dynamics = tree.map(
75+
lambda x: jnp.repeat(x[None], num_timesteps - 1, axis=0), params.dynamics,
76+
)
77+
inhomogeneous_emissions = tree.map(
78+
lambda x: jnp.repeat(x[None], num_timesteps, axis=0), params.emissions,
79+
)
80+
81+
_, emissions = model.sample(params, next(keys), num_timesteps=num_timesteps)
82+
inhomogeneous_params = ParamsLGSSM(
83+
initial=params.initial,
84+
dynamics=inhomogeneous_dynamics,
85+
emissions=inhomogeneous_emissions,
86+
)
87+
params_trace = model.fit_blocked_gibbs(
88+
next(keys),
89+
inhomogeneous_params,
90+
sample_size=5,
91+
emissions=emissions,
92+
)
93+
94+
# Arbitrarily check the last set of parameters from the Markov chain.
95+
last_params = tree.map(lambda x: x[-1], params_trace)
96+
assert last_params.initial.mean.shape == (state_dim,)
97+
assert last_params.initial.cov.shape == (state_dim, state_dim)
98+
assert last_params.dynamics.weights.shape == (num_timesteps - 1, state_dim, state_dim)
99+
assert last_params.emissions.weights.shape == (num_timesteps, emission_dim, state_dim)
100+
assert last_params.dynamics.bias.shape == (num_timesteps - 1, state_dim)
101+
assert last_params.emissions.bias.shape == (num_timesteps, emission_dim)
102+
assert last_params.dynamics.cov.shape == (num_timesteps - 1, state_dim, state_dim)
103+
assert last_params.emissions.cov.shape == (num_timesteps, emission_dim, emission_dim)

0 commit comments

Comments
 (0)