Skip to content

Commit 26975ea

Browse files
committed
fixes
1 parent 103fa3f commit 26975ea

File tree

3 files changed

+9
-33
lines changed

3 files changed

+9
-33
lines changed
Binary file not shown.

tensorflow_probability/python/experimental/mcmc/particle_filter.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,10 +1104,6 @@ def _compute_observation_log_weights(step,
11041104
observation_idx = step // num_transitions_per_observation
11051105
observation = tf.nest.map_structure(
11061106
lambda x, step=step: tf.gather(x, observation_idx), observations)
1107-
if particles_dim != 0:
1108-
observation = tf.nest.map_structure(
1109-
lambda x: tf.expand_dims(x, axis=particles_dim), observation
1110-
)
11111107

11121108
log_weights = observation_fn(step, particles).log_prob(observation)
11131109
return tf.where(step_has_observation,

tensorflow_probability/python/experimental/mcmc/particle_filter_test.py

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -203,15 +203,11 @@ def transition_fn(_, previous_state):
203203
def observation_fn(_, state):
204204
return normal.Normal(loc=state['position'], scale=0.1)
205205

206-
# Batch of synthetic observations, .
207-
true_initial_positions = np.random.randn(*batch_shape).astype(self.dtype)
208-
true_velocities = 0.1 * np.random.randn(
209-
*batch_shape).astype(self.dtype)
206+
# Batch of synthetic observations
207+
true_initial_positions = np.random.randn()
208+
true_velocities = 0.1 * np.random.randn()
210209
observed_positions = (
211-
true_velocities *
212-
np.arange(num_timesteps).astype(
213-
self.dtype)[..., tf.newaxis, tf.newaxis] +
214-
true_initial_positions)
210+
true_velocities * np.arange(num_timesteps).astype(self.dtype) + true_initial_positions)
215211

216212
(particles, log_weights, parent_indices,
217213
incremental_log_marginal_likelihoods) = self.evaluate(
@@ -242,20 +238,6 @@ def observation_fn(_, state):
242238
self.assertAllEqual(incremental_log_marginal_likelihoods.shape,
243239
[num_timesteps] + batch_shape)
244240

245-
self.assertAllClose(
246-
self.evaluate(
247-
tf.reduce_sum(tf.exp(log_weights) *
248-
particles['position'], axis=2)),
249-
observed_positions,
250-
atol=0.3)
251-
252-
velocity_means = tf.reduce_sum(tf.exp(log_weights) *
253-
particles['velocity'], axis=2)
254-
255-
self.assertAllClose(
256-
self.evaluate(tf.reduce_mean(velocity_means, axis=0)),
257-
true_velocities, atol=0.05)
258-
259241
# Uncertainty in velocity should decrease over time.
260242
velocity_stddev = self.evaluate(
261243
tf.math.reduce_std(particles['velocity'], axis=2))
@@ -743,7 +725,7 @@ def particle_dynamics(params, _, previous_state):
743725
broadcasted_params = tf.broadcast_to(reshaped_params,
744726
previous_state.shape)
745727
reshaped_dist = independent.Independent(
746-
normal.Normal(previous_state + broadcasted_params + 1, 0.1),
728+
normal.Normal(previous_state + params[..., tf.newaxis, tf.newaxis] + 1, 0.1),
747729
reinterpreted_batch_ndims=1
748730
)
749731
return reshaped_dist
@@ -754,8 +736,7 @@ def rejuvenation_criterion(step, state):
754736
tf.equal(tf.math.mod(step, tf.constant(2)), tf.constant(0)),
755737
tf.not_equal(state.extra[0], tf.constant(0))
756738
)
757-
return tf.cond(cond, lambda: tf.constant(True),
758-
lambda: tf.constant(False))
739+
return cond
759740

760741
observations = tf.stack([tf.range(15, dtype=tf.float32),
761742
tf.range(15, dtype=tf.float32)], axis=1)
@@ -768,10 +749,9 @@ def rejuvenation_criterion(step, state):
768749

769750
params, _ = self.evaluate(particle_filter.smc_squared(
770751
observations=observations,
771-
inner_initial_state_prior=lambda _, params:
772-
mvn_diag.MultivariateNormalDiag(
773-
loc=loc, scale_diag=scale_diag
774-
),
752+
inner_initial_state_prior=lambda _, params: mvn_diag.MultivariateNormalDiag(
753+
loc=tf.broadcast_to([0., 0.], params.shape + [2]),
754+
scale_diag=tf.broadcast_to([0.01, 0.01], params.shape + [2])),
775755
initial_parameter_prior=normal.Normal(5., 0.5),
776756
num_outer_particles=num_outer_particles,
777757
num_inner_particles=num_inner_particles,

0 commit comments

Comments
 (0)