diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index 9c83fbbbe4..34ac008615 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -1105,9 +1105,9 @@ def _compute_observation_log_weights(step, observation = tf.nest.map_structure( lambda x, step=step: tf.gather(x, observation_idx), observations) if particles_dim != 0: - observation = tf.nest.map_structure( - lambda x: tf.expand_dims(x, axis=particles_dim), observation - ) + observation = tf.nest.map_structure( + lambda x: tf.expand_dims(x, axis=particles_dim), observation + ) log_weights = observation_fn(step, particles).log_prob(observation) return tf.where(step_has_observation,