Skip to content

Commit

Permalink
particles dim fix
Browse files Browse the repository at this point in the history
  • Loading branch information
aleslamitz committed Jan 29, 2024
1 parent ee5d2e8 commit 87d2d24
Showing 1 changed file with 5 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,7 @@ def _particle_filter_initial_weighted_particles(observations,
# initial extra for particle filter
extra = tf.constant(0)

# initial_state is [3, 1000, 2] perche' particles_dim = 1
# Return particles weighted by the initial observation.
return smc_kernel.WeightedParticles(
particles=initial_state,
Expand Down Expand Up @@ -1103,14 +1104,12 @@ def _compute_observation_log_weights(step,
observation_idx = step // num_transitions_per_observation
observation = tf.nest.map_structure(
lambda x, step=step: tf.gather(x, observation_idx), observations)

if particles_dim != 1:
observation = tf.nest.map_structure(
lambda x: tf.expand_dims(x, axis=particles_dim), observation
)
if particles_dim != 0:
observation = tf.nest.map_structure(
lambda x: tf.expand_dims(x, axis=particles_dim), observation
)

log_weights = observation_fn(step, particles).log_prob(observation)

return tf.where(step_has_observation,
log_weights,
tf.zeros_like(log_weights))
Expand Down

0 comments on commit 87d2d24

Please sign in to comment.