@@ -643,6 +643,47 @@ def m_step(self,
643
643
)
644
644
return params , m_step_state
645
645
646
def _check_params(self, params: ParamsLGSSM, num_timesteps: int) -> ParamsLGSSM:
    """Return a copy of ``params`` in which missing bias/input weights are zeros.

    Only the ``bias`` and ``input_weights`` fields of the dynamics and
    emissions are filled in; ``weights``, ``cov`` and the initial
    parameters are passed through untouched.

    The model is treated as inhomogeneous when the dynamics weights are
    3-dimensional. In that case the zero arrays get a leading time axis of
    length ``num_timesteps - 1`` for the dynamics and ``num_timesteps``
    for the emissions.

    Args:
        params: parameters whose bias / input-weight entries may be None.
        num_timesteps: number of timesteps in the emission sequences.

    Returns:
        A new ``ParamsLGSSM`` with no None bias or input weights.
    """
    dyn, emi = params.dynamics, params.emissions
    time_varying = dyn.weights.ndim == 3

    # Leading axes to prepend to the zero arrays (empty when homogeneous).
    dyn_lead = (num_timesteps - 1,) if time_varying else ()
    emi_lead = (num_timesteps,) if time_varying else ()

    def _or_zeros(value, shape):
        # Keep a provided value as-is; only materialise zeros for None.
        return jnp.zeros(shape) if value is None else value

    clean_dynamics = ParamsLGSSMDynamics(
        weights=dyn.weights,
        bias=_or_zeros(dyn.bias, shape=dyn_lead + (self.state_dim,)),
        input_weights=_or_zeros(
            dyn.input_weights, shape=dyn_lead + (self.state_dim, self.input_dim)
        ),
        cov=dyn.cov,
    )
    clean_emissions = ParamsLGSSMEmissions(
        weights=emi.weights,
        bias=_or_zeros(emi.bias, shape=emi_lead + (self.emission_dim,)),
        input_weights=_or_zeros(
            emi.input_weights, shape=emi_lead + (self.emission_dim, self.input_dim)
        ),
        cov=emi.cov,
    )
    return ParamsLGSSM(
        initial=params.initial,
        dynamics=clean_dynamics,
        emissions=clean_emissions,
    )
685
+
686
+
646
687
def fit_blocked_gibbs (self ,
647
688
key : PRNGKeyT ,
648
689
initial_params : ParamsLGSSM ,
@@ -654,7 +695,8 @@ def fit_blocked_gibbs(self,
654
695
655
696
Args:
656
697
key: random number key.
657
- initial_params: starting parameters.
698
+ initial_params: starting parameters. Include a leading time axis for
699
+ the dynamics and emissions parameters in inhomogeneous models.
658
700
sample_size: how many samples to draw.
659
701
emissions: set of observation sequences.
660
702
inputs: optional set of input sequences.
@@ -667,66 +709,95 @@ def fit_blocked_gibbs(self,
667
709
668
710
num_batches , num_timesteps = batch_emissions .shape [:2 ]
669
711
712
+ initial_params = self ._check_params (initial_params , num_timesteps )
670
713
if batch_inputs is None :
671
714
batch_inputs = jnp .zeros ((num_batches , num_timesteps , 0 ))
672
715
716
+ # Inhomogeneous models have a leading time dimension.
717
+ is_inhomogeneous = initial_params .dynamics .weights .ndim == 3
718
+
673
719
def sufficient_stats_from_sample(y, inputs, states):
    """Convert one sampled state trajectory into sufficient statistics.

    Returns ``(init_stats, dynamics_stats, emission_stats)``. The dynamics
    and emission statistics are kept per timestep (leading time axis)
    rather than summed over time.
    """
    def _outer(a, b):
        # Per-timestep outer product: out[t] = a[t] b[t]^T.
        return jnp.einsum('ti,tj->tij', a, b)

    # The trailing column of ones is the regressor for the bias term; it is
    # sliced off below when the model has no bias.
    inputs_joint = jnp.concatenate((inputs, jnp.ones((num_timesteps, 1))), axis=1)

    # prev_states[t] = x[t] and next_states[t] = x[t+1] for t = 0...T-2
    prev_states, next_states = states[:-1], states[1:]
    # zp[t] = [x[t], u[t+1]] for t = 0...T-2
    zp = jnp.concatenate([prev_states, inputs_joint[1:]], axis=1)
    # z[t] = [x[t], u[t]] for t = 0...T-1
    z = jnp.concatenate([states, inputs_joint], axis=-1)

    first_state = states[0]
    init_stats = (first_state, jnp.outer(first_state, first_state), 1)

    # Quantities for the dynamics distribution (leading time dimension).
    sum_zpzpT = _outer(zp, zp)
    sum_zpxnT = _outer(zp, next_states)
    if not self.has_dynamics_bias:
        # Drop the row/column belonging to the ones regressor.
        sum_zpzpT = sum_zpzpT[:, :-1, :-1]
        sum_zpxnT = sum_zpxnT[:, :-1, :]
    dynamics_stats = (
        sum_zpzpT,
        sum_zpxnT,
        _outer(next_states, next_states),
        jnp.ones(num_timesteps - 1),
    )

    # Quantities for the emissions (leading time dimension).
    sum_zzT = _outer(z, z)
    sum_zyT = _outer(z, y)
    if not self.has_emissions_bias:
        # Drop the row/column belonging to the ones regressor.
        sum_zzT = sum_zzT[:, :-1, :-1]
        sum_zyT = sum_zyT[:, :-1, :]
    emission_stats = (
        sum_zzT,
        sum_zyT,
        _outer(y, y),
        jnp.ones(num_timesteps),
    )

    return init_stats, dynamics_stats, emission_stats
702
754
703
- def lgssm_params_sample (rng , stats ):
704
- """Sample parameters of the model given sufficient statistics from observed states and emissions."""
705
- init_stats , dynamics_stats , emission_stats = stats
706
- rngs = iter (jr .split (rng , 3 ))
707
-
708
- # Sample the initial params
755
def _sample_initial_params(rng, init_stats):
    """Sample the initial-state mean and covariance given ``init_stats``.

    The posterior is built from ``self.initial_prior`` with
    ``niw_posterior_update`` and sampled once with ``rng`` as the seed.
    """
    posterior = niw_posterior_update(self.initial_prior, init_stats)
    # The posterior sample is ordered (covariance, mean).
    cov, mean = posterior.sample(seed=rng)
    return ParamsLGSSMInitial(mean=mean, cov=cov)
711
759
712
- # Sample the dynamics params
760
def _sample_dynamics_params(rng, dynamics_stats):
    """Sample dynamics weights, input weights, bias and covariance.

    The posterior is built from ``self.dynamics_prior`` with
    ``mniw_posterior_update``; the sampled joint matrix is then split
    column-wise into state weights, input weights and (optionally) a bias.
    """
    posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats)
    cov, joint = posterior.sample(seed=rng)

    n_state = self.state_dim
    weights = joint[:, :n_state]
    if self.has_dynamics_bias:
        # Last column is the bias; everything in between is input weights.
        input_weights = joint[:, n_state:-1]
        bias = joint[:, -1]
    else:
        input_weights = joint[:, n_state:]
        bias = jnp.zeros(n_state)
    return ParamsLGSSMDynamics(weights=weights, bias=bias, input_weights=input_weights, cov=cov)
718
767
719
- # Sample the emission params
768
def _sample_emission_params(rng, emission_stats):
    """Sample emission weights, input weights, bias and covariance.

    The posterior is built from ``self.emission_prior`` with
    ``mniw_posterior_update``; the sampled joint matrix is then split
    column-wise into state weights, input weights and (optionally) a bias.
    """
    posterior = mniw_posterior_update(self.emission_prior, emission_stats)
    cov, joint = posterior.sample(seed=rng)

    n_state = self.state_dim
    weights = joint[:, :n_state]
    if self.has_emissions_bias:
        # Last column is the bias; everything in between is input weights.
        input_weights = joint[:, n_state:-1]
        bias = joint[:, -1]
    else:
        input_weights = joint[:, n_state:]
        bias = jnp.zeros(self.emission_dim)
    return ParamsLGSSMEmissions(weights=weights, bias=bias, input_weights=input_weights, cov=cov)
775
+
776
def lgssm_params_sample(rng, stats):
    """Sample parameters of the model given sufficient statistics from observed states and emissions.

    ``stats`` is the ``(init_stats, dynamics_stats, emission_stats)`` triple,
    where the dynamics and emission statistics carry a leading time axis.
    For a homogeneous model they are summed over time and a single set of
    parameters is drawn; for an inhomogeneous model one set of parameters
    is drawn per timestep.
    """
    init_stats, dynamics_stats, emission_stats = stats
    # One key each for the initial, dynamics and emission parameters.
    key_init, key_dynamics, key_emissions = jr.split(rng, 3)

    initial_params = _sample_initial_params(key_init, init_stats)

    if is_inhomogeneous:
        # Independent draw per timestep: one key per slice of the statistics.
        keys_dynamics = jr.split(key_dynamics, num_timesteps - 1)
        keys_emission = jr.split(key_emissions, num_timesteps)
        dynamics_params = vmap(_sample_dynamics_params)(keys_dynamics, dynamics_stats)
        emission_params = vmap(_sample_emission_params)(keys_emission, emission_stats)
    else:
        # Aggregate summary statistics across time for the homogeneous model.
        def _sum_over_time(leaf):
            return jnp.sum(leaf, axis=0)

        dynamics_stats = tree.map(_sum_over_time, dynamics_stats)
        emission_stats = tree.map(_sum_over_time, emission_stats)
        dynamics_params = _sample_dynamics_params(key_dynamics, dynamics_stats)
        emission_params = _sample_emission_params(key_emissions, emission_stats)

    return ParamsLGSSM(
        initial=initial_params,
        dynamics=dynamics_params,
        emissions=emission_params,
    )
732
803
0 commit comments