Skip to content

Commit

Permalink
fixed one test
Browse files Browse the repository at this point in the history
  • Loading branch information
aleslamitz committed Jan 21, 2024
1 parent 4c5f86e commit ee5d2e8
Showing 1 changed file with 5 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,6 @@ def particle_dynamics(params, _, previous_state):
normal.Normal(previous_state + broadcasted_params + 1, 0.1),
reinterpreted_batch_ndims=1
)

return reshaped_dist

def rejuvenation_criterion(step, state):
Expand All @@ -758,22 +757,22 @@ def rejuvenation_criterion(step, state):
return tf.cond(cond, lambda: tf.constant(True),
lambda: tf.constant(False))

observations = tf.stack([tf.range(30, dtype=tf.float32),
tf.range(30, dtype=tf.float32)], axis=1)
observations = tf.stack([tf.range(15, dtype=tf.float32),
tf.range(15, dtype=tf.float32)], axis=1)

num_outer_particles = 3
num_inner_particles = 7
num_inner_particles = 5

loc = tf.broadcast_to([0., 0.], [num_outer_particles, 2])
scale_diag = tf.broadcast_to([0.05, 0.05], [num_outer_particles, 2])
scale_diag = tf.broadcast_to([0.01, 0.01], [num_outer_particles, 2])

params, _ = self.evaluate(particle_filter.smc_squared(
observations=observations,
inner_initial_state_prior=lambda _, params:
mvn_diag.MultivariateNormalDiag(
loc=loc, scale_diag=scale_diag
),
initial_parameter_prior=normal.Normal(3., 1.),
initial_parameter_prior=normal.Normal(5., 0.5),
num_outer_particles=num_outer_particles,
num_inner_particles=num_inner_particles,
outer_rejuvenation_criterion_fn=rejuvenation_criterion,
Expand Down

0 comments on commit ee5d2e8

Please sign in to comment.