From 87d2d2404cda9c4efa04282a645098e228552d61 Mon Sep 17 00:00:00 2001 From: slamitza Date: Mon, 29 Jan 2024 17:48:40 +0100 Subject: [PATCH] particles dim fix --- .../python/experimental/mcmc/particle_filter.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index c5060747ee..9c83fbbbe4 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -1005,6 +1005,7 @@ def _particle_filter_initial_weighted_particles(observations, # initial extra for particle filter extra = tf.constant(0) + # initial_state is [3, 1000, 2] perche' particles_dim = 1 # Return particles weighted by the initial observation. return smc_kernel.WeightedParticles( particles=initial_state, @@ -1103,14 +1104,12 @@ def _compute_observation_log_weights(step, observation_idx = step // num_transitions_per_observation observation = tf.nest.map_structure( lambda x, step=step: tf.gather(x, observation_idx), observations) - - if particles_dim != 1: - observation = tf.nest.map_structure( - lambda x: tf.expand_dims(x, axis=particles_dim), observation - ) + if particles_dim != 0: + observation = tf.nest.map_structure( + lambda x: tf.expand_dims(x, axis=particles_dim), observation + ) log_weights = observation_fn(step, particles).log_prob(observation) - return tf.where(step_has_observation, log_weights, tf.zeros_like(log_weights))