diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py index 63f139a63e..920c9705d3 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py @@ -725,7 +725,9 @@ def particle_dynamics(params, _, previous_state): broadcasted_params = tf.broadcast_to(reshaped_params, previous_state.shape) reshaped_dist = independent.Independent( - normal.Normal(previous_state + params[..., tf.newaxis, tf.newaxis] + 1, 0.1), + normal.Normal( + previous_state + params[..., tf.newaxis, tf.newaxis] + 1, 0.1 + ), reinterpreted_batch_ndims=1 ) return reshaped_dist