Skip to content

Commit

Permalink
fixed errors
Browse files Browse the repository at this point in the history
  • Loading branch information
aleslamitz committed Feb 13, 2024
1 parent 6ce8495 commit 222c197
Showing 1 changed file with 8 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ def observation_fn(_, state):

# Batch of synthetic observations
true_initial_positions = np.random.randn()

true_velocities = 0.1 * np.random.randn()
observed_positions = (
true_velocities * np.arange(num_timesteps).astype(self.dtype) + true_initial_positions)
true_velocities * np.arange(num_timesteps).astype(self.dtype) +
true_initial_positions)

(particles, log_weights, parent_indices,
incremental_log_marginal_likelihoods) = self.evaluate(
Expand Down Expand Up @@ -288,128 +288,6 @@ def observation_fn(_, state):
self.assertAllEqual(incremental_log_marginal_likelihoods.shape,
[num_timesteps] + batch_shape)

def test_batch_of_filters_particles_dim_1(self):
  """Runs a batch of particle filters with the particle axis at dim 1.

  Builds a [3, 2] batch of 1-D position/velocity tracking problems,
  filters them with `particles_dim=1`, and checks that:
    * all outputs carry the particle axis *between* the two batch axes,
      i.e. shape [num_timesteps, batch[0], num_particles, batch[1]];
    * posterior means recover the (noiseless) observed positions and the
      true velocities;
    * posterior velocity uncertainty shrinks over time;
    * `reconstruct_trajectories` and `infer_trajectories` accept the same
      `particles_dim` and produce consistently shaped results.
  """
  batch_shape = [3, 2]
  num_particles = 1000
  num_timesteps = 40

  # Batch of priors on object 1D positions and velocities.
  initial_state_prior = jdn.JointDistributionNamed({
      'position': normal.Normal(loc=0., scale=tf.ones(batch_shape)),
      'velocity': normal.Normal(loc=0., scale=tf.ones(batch_shape) * 0.1)
  })

  def transition_fn(_, previous_state):
    # Near-constant-velocity dynamics: position advances by the current
    # velocity (noise scale 0.1); velocity drifts slowly (noise scale 0.01).
    return jdn.JointDistributionNamed({
        'position':
            normal.Normal(
                loc=previous_state['position'] + previous_state['velocity'],
                scale=0.1),
        'velocity':
            normal.Normal(loc=previous_state['velocity'], scale=0.01)
    })

  def observation_fn(_, state):
    # Noisy observation of the latent position only.
    return normal.Normal(loc=state['position'], scale=0.1)

  # Batch of synthetic observations: noiseless linear trajectories
  # position(t) = velocity * t + initial_position. The two trailing
  # `tf.newaxis` insertions broadcast the [num_timesteps] time grid against
  # the [3, 2] batch, giving shape [num_timesteps, 3, 2].
  true_initial_positions = np.random.randn(*batch_shape).astype(self.dtype)
  true_velocities = 0.1 * np.random.randn(
      *batch_shape).astype(self.dtype)
  observed_positions = (
      true_velocities *
      np.arange(num_timesteps).astype(
          self.dtype)[..., tf.newaxis, tf.newaxis] +
      true_initial_positions)

  (particles, log_weights, parent_indices,
   incremental_log_marginal_likelihoods) = self.evaluate(
       particle_filter.particle_filter(
           observations=observed_positions,
           initial_state_prior=initial_state_prior,
           transition_fn=transition_fn,
           observation_fn=observation_fn,
           num_particles=num_particles,
           seed=test_util.test_seed(),
           particles_dim=1))

  # With `particles_dim=1` the particle axis sits between the two batch
  # axes: [num_timesteps, batch_shape[0], num_particles, batch_shape[1]].
  self.assertAllEqual(particles['position'].shape,
                      [num_timesteps,
                       batch_shape[0],
                       num_particles,
                       batch_shape[1]])
  self.assertAllEqual(particles['velocity'].shape,
                      [num_timesteps,
                       batch_shape[0],
                       num_particles,
                       batch_shape[1]])
  self.assertAllEqual(parent_indices.shape,
                      [num_timesteps,
                       batch_shape[0],
                       num_particles,
                       batch_shape[1]])
  # Log-marginal-likelihood increments have no particle axis.
  self.assertAllEqual(incremental_log_marginal_likelihoods.shape,
                      [num_timesteps] + batch_shape)

  # Posterior mean position (weighted sum over the particle axis, axis=2)
  # should track the observations closely.
  self.assertAllClose(
      self.evaluate(
          tf.reduce_sum(tf.exp(log_weights) *
                        particles['position'], axis=2)),
      observed_positions,
      atol=0.3)

  # Posterior mean velocity, averaged over time, should recover the true
  # (constant) velocities.
  velocity_means = tf.reduce_sum(tf.exp(log_weights) *
                                 particles['velocity'], axis=2)

  self.assertAllClose(
      self.evaluate(tf.reduce_mean(velocity_means, axis=0)),
      true_velocities, atol=0.05)

  # Uncertainty in velocity should decrease over time.
  velocity_stddev = self.evaluate(
      tf.math.reduce_std(particles['velocity'], axis=2))
  self.assertAllLess((velocity_stddev[-1] - velocity_stddev[0]), 0.)

  # `reconstruct_trajectories` must honor the same particles_dim and keep
  # the output shapes aligned with the filter outputs.
  trajectories = self.evaluate(
      particle_filter.reconstruct_trajectories(particles,
                                               parent_indices,
                                               particles_dim=1))
  self.assertAllEqual([num_timesteps,
                       batch_shape[0],
                       num_particles,
                       batch_shape[1]],
                      trajectories['position'].shape)
  self.assertAllEqual([num_timesteps,
                       batch_shape[0],
                       num_particles,
                       batch_shape[1]],
                      trajectories['velocity'].shape)

  # Verify that `infer_trajectories` also works on batches.
  trajectories, incremental_log_marginal_likelihoods = self.evaluate(
      particle_filter.infer_trajectories(
          observations=observed_positions,
          initial_state_prior=initial_state_prior,
          transition_fn=transition_fn,
          observation_fn=observation_fn,
          num_particles=num_particles,
          particles_dim=1,
          seed=test_util.test_seed()))

  self.assertAllEqual([num_timesteps,
                       batch_shape[0],
                       num_particles,
                       batch_shape[1]],
                      trajectories['position'].shape)
  self.assertAllEqual([num_timesteps,
                       batch_shape[0],
                       num_particles,
                       batch_shape[1]],
                      trajectories['velocity'].shape)
  self.assertAllEqual(incremental_log_marginal_likelihoods.shape,
                      [num_timesteps] + batch_shape)

def test_reconstruct_trajectories_toy_example(self):
particles = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6,], [7, 8, 9]])
# 1 -- 4 -- 7
Expand Down Expand Up @@ -847,13 +725,10 @@ def marginal_log_likelihood(level_scale, noise_scale):

def test_smc_squared_rejuvenation_parameters(self):
def particle_dynamics(params, _, previous_state):
reshaped_params = tf.reshape(params,
[params.shape[0]] +
[1] * (previous_state.shape.rank - 1))
broadcasted_params = tf.broadcast_to(reshaped_params,
previous_state.shape)
reshaped_dist = independent.Independent(
normal.Normal(previous_state + broadcasted_params + 1, 0.1),
normal.Normal(
previous_state + params[..., tf.newaxis, tf.newaxis] + 1, 0.1
),
reinterpreted_batch_ndims=1
)
return reshaped_dist
Expand All @@ -864,23 +739,20 @@ def rejuvenation_criterion(step, state):
tf.equal(tf.math.mod(step, tf.constant(2)), tf.constant(0)),
tf.not_equal(state.extra[0], tf.constant(0))
)
return tf.cond(cond, lambda: tf.constant(True),
lambda: tf.constant(False))
return cond

observations = tf.stack([tf.range(15, dtype=tf.float32),
tf.range(15, dtype=tf.float32)], axis=1)

num_outer_particles = 3
num_inner_particles = 5

loc = tf.broadcast_to([0., 0.], [num_outer_particles, 2])
scale_diag = tf.broadcast_to([0.01, 0.01], [num_outer_particles, 2])

params, _ = self.evaluate(particle_filter.smc_squared(
observations=observations,
inner_initial_state_prior=lambda _, params:
mvn_diag.MultivariateNormalDiag(
loc=loc, scale_diag=scale_diag
loc=tf.broadcast_to([0., 0.], params.shape + [2]),
scale_diag=tf.broadcast_to([0.01, 0.01], params.shape + [2])
),
initial_parameter_prior=normal.Normal(5., 0.5),
num_outer_particles=num_outer_particles,
Expand Down

0 comments on commit 222c197

Please sign in to comment.