From 222c197fd8942536c5320f3fd692fd15b6e5003b Mon Sep 17 00:00:00 2001 From: slamitza Date: Tue, 13 Feb 2024 19:44:10 +0100 Subject: [PATCH] fixed errors --- .../experimental/mcmc/particle_filter_test.py | 144 +----------------- 1 file changed, 8 insertions(+), 136 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py index 96c7243896..7ddfd3ac30 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py @@ -205,10 +205,10 @@ def observation_fn(_, state): # Batch of synthetic observations true_initial_positions = np.random.randn() - true_velocities = 0.1 * np.random.randn() observed_positions = ( - true_velocities * np.arange(num_timesteps).astype(self.dtype) + true_initial_positions) + true_velocities * np.arange(num_timesteps).astype(self.dtype) + + true_initial_positions) (particles, log_weights, parent_indices, incremental_log_marginal_likelihoods) = self.evaluate( @@ -288,128 +288,6 @@ def observation_fn(_, state): self.assertAllEqual(incremental_log_marginal_likelihoods.shape, [num_timesteps] + batch_shape) - def test_batch_of_filters_particles_dim_1(self): - - batch_shape = [3, 2] - num_particles = 1000 - num_timesteps = 40 - - # Batch of priors on object 1D positions and velocities. - initial_state_prior = jdn.JointDistributionNamed({ - 'position': normal.Normal(loc=0., scale=tf.ones(batch_shape)), - 'velocity': normal.Normal(loc=0., scale=tf.ones(batch_shape) * 0.1) - }) - - def transition_fn(_, previous_state): - return jdn.JointDistributionNamed({ - 'position': - normal.Normal( - loc=previous_state['position'] + previous_state['velocity'], - scale=0.1), - 'velocity': - normal.Normal(loc=previous_state['velocity'], scale=0.01) - }) - - def observation_fn(_, state): - return normal.Normal(loc=state['position'], scale=0.1) - - # Batch of synthetic observations, . - true_initial_positions = np.random.randn(*batch_shape).astype(self.dtype) - true_velocities = 0.1 * np.random.randn( - *batch_shape).astype(self.dtype) - observed_positions = ( - true_velocities * - np.arange(num_timesteps).astype( - self.dtype)[..., tf.newaxis, tf.newaxis] + - true_initial_positions) - - (particles, log_weights, parent_indices, - incremental_log_marginal_likelihoods) = self.evaluate( - particle_filter.particle_filter( - observations=observed_positions, - initial_state_prior=initial_state_prior, - transition_fn=transition_fn, - observation_fn=observation_fn, - num_particles=num_particles, - seed=test_util.test_seed(), - particles_dim=1)) - - self.assertAllEqual(particles['position'].shape, - [num_timesteps, - batch_shape[0], - num_particles, - batch_shape[1]]) - self.assertAllEqual(particles['velocity'].shape, - [num_timesteps, - batch_shape[0], - num_particles, - batch_shape[1]]) - self.assertAllEqual(parent_indices.shape, - [num_timesteps, - batch_shape[0], - num_particles, - batch_shape[1]]) - self.assertAllEqual(incremental_log_marginal_likelihoods.shape, - [num_timesteps] + batch_shape) - - self.assertAllClose( - self.evaluate( - tf.reduce_sum(tf.exp(log_weights) * - particles['position'], axis=2)), - observed_positions, - atol=0.3) - - velocity_means = tf.reduce_sum(tf.exp(log_weights) * - particles['velocity'], axis=2) - - self.assertAllClose( - self.evaluate(tf.reduce_mean(velocity_means, axis=0)), - true_velocities, atol=0.05) - - # Uncertainty in velocity should decrease over time. - velocity_stddev = self.evaluate( - tf.math.reduce_std(particles['velocity'], axis=2)) - self.assertAllLess((velocity_stddev[-1] - velocity_stddev[0]), 0.) - - trajectories = self.evaluate( - particle_filter.reconstruct_trajectories(particles, - parent_indices, - particles_dim=1)) - self.assertAllEqual([num_timesteps, - batch_shape[0], - num_particles, - batch_shape[1]], - trajectories['position'].shape) - self.assertAllEqual([num_timesteps, - batch_shape[0], - num_particles, - batch_shape[1]], - trajectories['velocity'].shape) - - # Verify that `infer_trajectories` also works on batches. - trajectories, incremental_log_marginal_likelihoods = self.evaluate( - particle_filter.infer_trajectories( - observations=observed_positions, - initial_state_prior=initial_state_prior, - transition_fn=transition_fn, - observation_fn=observation_fn, - num_particles=num_particles, - particles_dim=1, - seed=test_util.test_seed())) - - self.assertAllEqual([num_timesteps, - batch_shape[0], - num_particles, - batch_shape[1]], - trajectories['position'].shape) - self.assertAllEqual([num_timesteps, - batch_shape[0], - num_particles, - batch_shape[1]], - trajectories['velocity'].shape) - self.assertAllEqual(incremental_log_marginal_likelihoods.shape, - [num_timesteps] + batch_shape) - def test_reconstruct_trajectories_toy_example(self): particles = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6,], [7, 8, 9]]) # 1 -- 4 -- 7 @@ -847,13 +725,10 @@ def marginal_log_likelihood(level_scale, noise_scale): def test_smc_squared_rejuvenation_parameters(self): def particle_dynamics(params, _, previous_state): - reshaped_params = tf.reshape(params, - [params.shape[0]] + - [1] * (previous_state.shape.rank - 1)) - broadcasted_params = tf.broadcast_to(reshaped_params, - previous_state.shape) reshaped_dist = independent.Independent( - normal.Normal(previous_state + broadcasted_params + 1, 0.1), + normal.Normal( + previous_state + params[..., tf.newaxis, tf.newaxis] + 1, 0.1 + ), reinterpreted_batch_ndims=1 ) return reshaped_dist @@ -864,8 +739,7 @@ def rejuvenation_criterion(step, state): tf.equal(tf.math.mod(step, tf.constant(2)), tf.constant(0)), tf.not_equal(state.extra[0], tf.constant(0)) ) - return tf.cond(cond, lambda: tf.constant(True), - lambda: tf.constant(False)) + return cond observations = tf.stack([tf.range(15, dtype=tf.float32), tf.range(15, dtype=tf.float32)], axis=1) @@ -873,14 +747,12 @@ def rejuvenation_criterion(step, state): num_outer_particles = 3 num_inner_particles = 5 - loc = tf.broadcast_to([0., 0.], [num_outer_particles, 2]) - scale_diag = tf.broadcast_to([0.01, 0.01], [num_outer_particles, 2]) - params, _ = self.evaluate(particle_filter.smc_squared( observations=observations, inner_initial_state_prior=lambda _, params: mvn_diag.MultivariateNormalDiag( - loc=loc, scale_diag=scale_diag + loc=tf.broadcast_to([0., 0.], params.shape + [2]), + scale_diag=tf.broadcast_to([0.01, 0.01], params.shape + [2]) ), initial_parameter_prior=normal.Normal(5., 0.5), num_outer_particles=num_outer_particles,