diff --git a/tensorflow_probability/python/experimental/mcmc/BUILD b/tensorflow_probability/python/experimental/mcmc/BUILD index 7fb140497b..a96087e35c 100644 --- a/tensorflow_probability/python/experimental/mcmc/BUILD +++ b/tensorflow_probability/python/experimental/mcmc/BUILD @@ -548,6 +548,9 @@ multi_substrate_py_library( "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:tensor_util", "//tensorflow_probability/python/internal:tensorshape_util", + "//tensorflow_probability/python/distributions:batch_reshape", + "//tensorflow_probability/python/distributions:batch_broadcast", + "//tensorflow_probability/python/distributions:independent" ], ) @@ -574,6 +577,8 @@ multi_substrate_py_test( "//tensorflow_probability/python/distributions:sample", "//tensorflow_probability/python/distributions:transformed_distribution", "//tensorflow_probability/python/distributions:uniform", + "//tensorflow_probability/python/distributions:categorical", + "//tensorflow_probability/python/distributions:hidden_markov_model", "//tensorflow_probability/python/internal:test_util", "//tensorflow_probability/python/math:gradient", # "//third_party/tensorflow/compiler/jit:xla_cpu_jit", # DisableOnExport @@ -652,6 +657,7 @@ multi_substrate_py_test( "//tensorflow_probability/python/distributions:mvn_diag", "//tensorflow_probability/python/distributions:normal", "//tensorflow_probability/python/distributions:sample", + "//tensorflow_probability/python/experimental/mcmc:sequential_monte_carlo_kernel", "//tensorflow_probability/python/distributions:uniform", "//tensorflow_probability/python/distributions/internal:statistical_testing", "//tensorflow_probability/python/internal:test_util", diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index 1bcbc870f4..9d1df1bb89 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ 
from tensorflow_probability.python.distributions import normal
from tensorflow_probability.python.distributions import uniform


def _default_kernel(parameters):
  """Default parameter proposal: a Normal fitted to the current parameters.

  The mean and variance of `parameters` are taken over axis 0 (the axis that
  `smc_squared` uses for the outer particles) and used as the location and
  scale of a `Normal` whose batch shape equals the shape of `parameters`.

  Args:
    parameters: `Tensor` of parameter particles; moments are reduced over
      axis 0.

  Returns:
    A `normal.Normal` whose `loc` is the mean broadcast to the shape of
    `parameters` and whose `scale` is the square root of the variance.
  """
  mean, variance = tf.nn.moments(parameters, axes=[0])
  # `tf.broadcast_to` instead of `tf.fill`: `tf.fill` only accepts a scalar
  # value, which would restrict this kernel to scalar parameter particles.
  return normal.Normal(loc=tf.broadcast_to(mean, ps.shape(parameters)),
                       scale=tf.sqrt(variance))


def smc_squared(
    observations,
    initial_parameter_prior,
    num_outer_particles,
    inner_initial_state_prior,
    inner_transition_fn,
    inner_observation_fn,
    num_inner_particles,
    outer_trace_fn=_default_trace_fn,
    outer_rejuvenation_criterion_fn=None,
    outer_resample_criterion_fn=None,
    outer_resample_fn=weighted_resampling.resample_systematic,
    inner_resample_criterion_fn=smc_kernel.ess_below_threshold,
    inner_resample_fn=weighted_resampling.resample_systematic,
    parameter_proposal_kernel=_default_kernel,
    inner_proposal_fn=None,
    inner_initial_state_proposal=None,
    outer_trace_criterion_fn=_always_trace,
    parallel_iterations=1,
    num_transitions_per_observation=1,
    static_trace_allocation_size=None,
    initial_parameter_proposal=None,
    unbiased_gradients=True,
    seed=None,
):
  """Runs a particle filter over parameters, weighted by inner filters.

  Each outer particle is a parameter value that carries its own inner
  particle filter over latent states. At every outer step the inner filter of
  each parameter is advanced by one step and its incremental log marginal
  likelihood is added to that parameter's log weight. Whenever
  `outer_rejuvenation_criterion_fn` fires, new parameters are proposed, a
  fresh inner filter is run for them over the steps seen so far, and each
  proposal is accepted or rejected individually.

  The state passed between outer steps (and to `outer_trace_fn` and the
  criterion functions) is an `smc_kernel.WeightedParticles` with
  `particles=(parameters, inner_weighted_particles, inner_parent_indices,
  inner_incremental_log_marginal_likelihood,
  inner_accumulated_log_marginal_likelihood)`, `log_weights` holding the
  parameter log weights, and `extra` holding the inner filter's seed.

  Args:
    observations: observed values, indexed by observation step along axis 0.
      The last observation is appended once more before filtering (see the
      TODO in the body).
    initial_parameter_prior: distribution over parameters; must support
      `sample(num_outer_particles, seed=...)` and `log_prob`.
    num_outer_particles: number of parameter particles.
    inner_initial_state_prior: callable `(step, parameters) -> distribution`
      giving the prior over initial latent states; called with `step=0`.
    inner_transition_fn: callable `parameters -> transition_fn`, where
      `transition_fn` is what the inner particle filter uses as its
      transition model.
    inner_observation_fn: callable `parameters -> observation_fn`, where
      `observation_fn` is what the inner particle filter uses as its
      observation model.
    num_inner_particles: number of latent-state particles per parameter.
    outer_trace_fn: forwarded to `sequential_monte_carlo` as `trace_fn`.
    outer_rejuvenation_criterion_fn: callable `(step, state) -> bool Tensor`
      deciding whether to rejuvenate at this step. `None` never rejuvenates.
    outer_resample_criterion_fn: resampling criterion for the outer filter.
      `None` never resamples.
    outer_resample_fn: resampling function for the outer filter.
    inner_resample_criterion_fn: resampling criterion for the inner filters.
    inner_resample_fn: resampling function for the inner filters.
    parameter_proposal_kernel: callable `parameters -> distribution` used to
      propose new parameters during rejuvenation.
    inner_proposal_fn: optional callable `parameters -> proposal_fn` for the
      inner filters.
    inner_initial_state_proposal: optional callable
      `(step, parameters) -> distribution` used to propose initial latent
      states; called with `step=0`.
    outer_trace_criterion_fn: forwarded to `sequential_monte_carlo` as
      `trace_criterion_fn`. If `None`, nothing is traced along the way and
      `static_trace_allocation_size` is forced to `0`.
    parallel_iterations: forwarded to `sequential_monte_carlo`.
    num_transitions_per_observation: number of inner transitions per
      observation.
    static_trace_allocation_size: forwarded to `sequential_monte_carlo`.
    initial_parameter_proposal: optional distribution to draw the initial
      parameters from; the initial log weights are then
      `prior.log_prob - proposal.log_prob`.
    unbiased_gradients: forwarded to the inner kernels and to
      `sequential_monte_carlo`.
    seed: PRNG seed.

  Returns:
    Whatever `sequential_monte_carlo` returns for the outer filter --
    presumably the values of `outer_trace_fn` stacked over traced steps;
    confirm against `sequential_monte_carlo`.
  """
  params_seed, particles_seed, smc_seed = samplers.split_seed(
      seed, n=3, salt='smc_squared')

  num_observation_steps = ps.size0(tf.nest.flatten(observations)[0])

  # TODO: The extra `+ 1` step and the duplicated final observation compensate
  # for the first (empty) step taken by smc_squared.
  num_timesteps = (1 + num_transitions_per_observation *
                   (num_observation_steps - 1)) + 1
  last_obs_expanded = tf.expand_dims(observations[-1], axis=0)
  inner_observations = tf.concat([observations, last_obs_expanded], axis=0)

  if outer_rejuvenation_criterion_fn is None:
    outer_rejuvenation_criterion_fn = lambda *_: tf.constant(False)

  if outer_resample_criterion_fn is None:
    outer_resample_criterion_fn = lambda *_: tf.constant(False)

  # If trace criterion is `None`, we'll return only the final results.
  if outer_trace_criterion_fn is None:
    static_trace_allocation_size = 0
    outer_trace_criterion_fn = lambda *_: False

  if initial_parameter_proposal is None:
    initial_parameters = initial_parameter_prior.sample(
        num_outer_particles, seed=params_seed)
    initial_log_weights = ps.zeros_like(
        initial_parameter_prior.log_prob(initial_parameters))
  else:
    initial_parameters = initial_parameter_proposal.sample(
        num_outer_particles, seed=params_seed)
    initial_log_weights = (
        initial_parameter_prior.log_prob(initial_parameters) -
        initial_parameter_proposal.log_prob(initial_parameters))

  # Normalize the initial weights. If we used a proposal, the weights are
  # normalized in expectation, but actually normalizing them reduces variance.
  initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=0)

  # One inner particle filter per parameter particle: axis 0 indexes
  # parameters, axis 1 indexes latent-state particles.
  inner_weighted_particles = _particle_filter_initial_weighted_particles(
      observations=inner_observations,
      observation_fn=inner_observation_fn(initial_parameters),
      initial_state_prior=inner_initial_state_prior(0, initial_parameters),
      initial_state_proposal=(
          inner_initial_state_proposal(0, initial_parameters)
          if inner_initial_state_proposal is not None else None),
      num_particles=num_inner_particles,
      particles_dim=1,
      seed=particles_seed)

  batch_zeros = tf.zeros(ps.shape(initial_parameters))

  # The inner filter's kernel results are threaded through the outer state,
  # so they are unpacked into the outer `particles` tuple / `extra` field.
  initial_state = smc_kernel.WeightedParticles(
      particles=(
          initial_parameters,
          inner_weighted_particles,
          smc_kernel._dummy_indices_like(  # pylint: disable=protected-access
              inner_weighted_particles.log_weights),
          batch_zeros,   # incremental_log_marginal_likelihood
          batch_zeros),  # accumulated_log_marginal_likelihood
      log_weights=initial_log_weights,
      extra=samplers.zeros_seed())

  outer_propose_and_update_log_weights_fn = (
      _outer_particle_filter_propose_and_update_log_weights_fn(
          outer_rejuvenation_criterion_fn=outer_rejuvenation_criterion_fn,
          inner_observations=inner_observations,
          inner_transition_fn=inner_transition_fn,
          inner_proposal_fn=inner_proposal_fn,
          inner_observation_fn=inner_observation_fn,
          inner_resample_fn=inner_resample_fn,
          inner_resample_criterion_fn=inner_resample_criterion_fn,
          parameter_proposal_kernel=parameter_proposal_kernel,
          initial_parameter_prior=initial_parameter_prior,
          num_transitions_per_observation=num_transitions_per_observation,
          unbiased_gradients=unbiased_gradients,
          inner_initial_state_prior=inner_initial_state_prior,
          inner_initial_state_proposal=inner_initial_state_proposal,
          num_inner_particles=num_inner_particles,
          num_outer_particles=num_outer_particles))

  return sequential_monte_carlo(
      initial_weighted_particles=initial_state,
      propose_and_update_log_weights_fn=(
          outer_propose_and_update_log_weights_fn),
      resample_fn=outer_resample_fn,
      resample_criterion_fn=outer_resample_criterion_fn,
      trace_criterion_fn=outer_trace_criterion_fn,
      static_trace_allocation_size=static_trace_allocation_size,
      parallel_iterations=parallel_iterations,
      unbiased_gradients=unbiased_gradients,
      num_steps=num_timesteps,
      particles_dim=0,
      trace_fn=outer_trace_fn,
      seed=smc_seed)


def _outer_particle_filter_propose_and_update_log_weights_fn(
    inner_observations,
    inner_transition_fn,
    inner_proposal_fn,
    inner_observation_fn,
    initial_parameter_prior,
    inner_initial_state_prior,
    inner_initial_state_proposal,
    num_transitions_per_observation,
    inner_resample_fn,
    inner_resample_criterion_fn,
    outer_rejuvenation_criterion_fn,
    unbiased_gradients,
    parameter_proposal_kernel,
    num_inner_particles,
    num_outer_particles
):
  """Builds the outer (parameter-level) update step used by `smc_squared`.

  Args:
    inner_observations: observations consumed by the inner particle filters.
    inner_transition_fn: callable `parameters -> transition_fn`.
    inner_proposal_fn: optional callable `parameters -> proposal_fn`.
    inner_observation_fn: callable `parameters -> observation_fn`.
    initial_parameter_prior: parameter prior; here only its `log_prob` output
      is used, as a template for the shape/dtype of the proposed parameters'
      log weights.
    inner_initial_state_prior: callable `(step, parameters) -> distribution`.
    inner_initial_state_proposal: optional callable
      `(step, parameters) -> distribution`.
    num_transitions_per_observation: inner transitions per observation.
    inner_resample_fn: resampling function of the inner filters.
    inner_resample_criterion_fn: resampling criterion of the inner filters.
    outer_rejuvenation_criterion_fn: callable `(step, state) -> bool Tensor`.
    unbiased_gradients: forwarded to the inner kernels.
    parameter_proposal_kernel: callable `parameters -> distribution`.
    num_inner_particles: number of latent-state particles per parameter.
    num_outer_particles: number of parameter particles.

  Returns:
    A callable `(step, state, seed=None) -> WeightedParticles` suitable as
    `propose_and_update_log_weights_fn` of the outer filter.
  """

  def _inner_kernel(parameters):
    """Returns the inner SMC kernel conditioned on `parameters`."""
    return smc_kernel.SequentialMonteCarlo(
        propose_and_update_log_weights_fn=(
            _particle_filter_propose_and_update_log_weights_fn(
                observations=inner_observations,
                transition_fn=inner_transition_fn(parameters),
                proposal_fn=(inner_proposal_fn(parameters)
                             if inner_proposal_fn is not None else None),
                observation_fn=inner_observation_fn(parameters),
                particles_dim=1,
                num_transitions_per_observation=(
                    num_transitions_per_observation))),
        resample_fn=inner_resample_fn,
        resample_criterion_fn=inner_resample_criterion_fn,
        particles_dim=1,
        unbiased_gradients=unbiased_gradients)

  def _outer_propose_and_update_log_weights_fn(step, state, seed=None):
    (outside_parameters,
     inner_particles,
     inner_parent_indices,
     inner_incremental_likelihood,
     inner_accumulated_likelihood) = state.particles
    log_weights = state.log_weights

    # Every random draw in this step gets its own seed. Reusing one seed for
    # the inner step, the parameter proposal, the re-initialised particles,
    # the replayed steps and the accept/reject uniforms would correlate them.
    (inner_step_seed,
     proposal_seed,
     rej_init_seed,
     rej_replay_seed,
     accept_seed) = samplers.split_seed(
         seed, n=5, salt='smc_squared_outer_step')

    # Rebuild the inner kernel results that are carried inside the outer
    # state.
    filter_results = smc_kernel.SequentialMonteCarloResults(
        steps=step,
        parent_indices=inner_parent_indices,
        incremental_log_marginal_likelihood=inner_incremental_likelihood,
        accumulated_log_marginal_likelihood=inner_accumulated_likelihood,
        seed=state.extra)

    # Advance every inner filter by one step; its incremental log marginal
    # likelihood is the incremental weight of the corresponding parameter.
    inner_weighted_particles, filter_results = _inner_kernel(
        outside_parameters).one_step(
            inner_particles, filter_results, seed=inner_step_seed)

    updated_log_weights = (
        log_weights + filter_results.incremental_log_marginal_likelihood)

    do_rejuvenation = outer_rejuvenation_criterion_fn(step, state)

    def keep_particles():
      return (outside_parameters,
              updated_log_weights,
              inner_weighted_particles,
              filter_results)

    def rejuvenate_particles():
      """Proposes new parameters and accepts/rejects them per particle."""
      proposed_parameters = parameter_proposal_kernel(
          outside_parameters).sample(seed=proposal_seed)

      # Log weights of the proposed parameters start out uniform; the prior
      # is evaluated only to obtain a tensor of the right shape and dtype.
      rej_params_log_weights = tf.nn.log_softmax(
          ps.zeros_like(
              initial_parameter_prior.log_prob(proposed_parameters)),
          axis=0)

      rej_inner_weighted_particles = (
          _particle_filter_initial_weighted_particles(
              observations=inner_observations,
              observation_fn=inner_observation_fn(proposed_parameters),
              initial_state_prior=inner_initial_state_prior(
                  0, proposed_parameters),
              initial_state_proposal=(
                  inner_initial_state_proposal(0, proposed_parameters)
                  if inner_initial_state_proposal is not None else None),
              num_particles=num_inner_particles,
              particles_dim=1,
              seed=rej_init_seed))

      batch_zeros = tf.zeros(ps.shape(log_weights))

      rej_filter_results = smc_kernel.SequentialMonteCarloResults(
          steps=tf.constant(0, dtype=tf.int32),
          parent_indices=smc_kernel._dummy_indices_like(  # pylint: disable=protected-access
              rej_inner_weighted_particles.log_weights),
          incremental_log_marginal_likelihood=batch_zeros,
          accumulated_log_marginal_likelihood=batch_zeros,
          seed=samplers.zeros_seed())

      rej_kernel = _inner_kernel(proposed_parameters)

      def replay_pending(i, *_):
        # Replays inner steps 0..step inclusive.
        return tf.less_equal(i, step)

      def replay_one_step(i, particles, results, params_log_weights,
                          loop_seed):
        # A fresh seed per iteration, carried through the loop.
        step_seed, next_seed = samplers.split_seed(loop_seed, n=2)
        particles, results = rej_kernel.one_step(
            particles, results, seed=step_seed)
        params_log_weights = (
            params_log_weights +
            results.incremental_log_marginal_likelihood)
        return i + 1, particles, results, params_log_weights, next_seed

      (_,
       rej_inner_weighted_particles,
       rej_filter_results,
       rej_params_log_weights,
       _) = tf.while_loop(
           replay_pending,
           replay_one_step,
           loop_vars=[0,
                      rej_inner_weighted_particles,
                      rej_filter_results,
                      rej_params_log_weights,
                      rej_replay_seed])

      # Log acceptance ratio: estimated marginal likelihoods times the
      # reverse/forward proposal densities.
      # NOTE(review): there is no `initial_parameter_prior` term here; the
      # ratio is a valid Metropolis-Hastings ratio only if the prior density
      # is equal at the current and proposed parameters -- TODO confirm that
      # this omission is intended.
      log_a = (
          rej_filter_results.accumulated_log_marginal_likelihood -
          filter_results.accumulated_log_marginal_likelihood +
          parameter_proposal_kernel(
              proposed_parameters).log_prob(outside_parameters) -
          parameter_proposal_kernel(
              outside_parameters).log_prob(proposed_parameters))

      acceptance_probs = tf.minimum(1., tf.exp(log_a))

      random_numbers = uniform.Uniform(0., 1.).sample(
          num_outer_particles, seed=accept_seed)

      # True where the proposal is rejected, i.e. the current particle (and
      # its inner filter) is kept.
      keep_current = random_numbers > acceptance_probs

      # Select, per outer particle, between the current and the proposed
      # parameters, weights, inner particles and inner filter results.
      new_parameters = tf.where(
          keep_current, outside_parameters, proposed_parameters)
      new_log_weights = tf.where(
          keep_current, updated_log_weights, rej_params_log_weights)

      new_inner_weighted_particles = smc_kernel.WeightedParticles(
          particles=mcmc_util.choose(
              keep_current,
              inner_weighted_particles.particles,
              rej_inner_weighted_particles.particles),
          log_weights=mcmc_util.choose(
              keep_current,
              inner_weighted_particles.log_weights,
              rej_inner_weighted_particles.log_weights),
          extra=inner_weighted_particles.extra)

      (parent_indices,
       incremental_log_marginal_likelihood,
       accumulated_log_marginal_likelihood) = mcmc_util.choose(
           keep_current,
           (filter_results.parent_indices,
            filter_results.incremental_log_marginal_likelihood,
            filter_results.accumulated_log_marginal_likelihood),
           (rej_filter_results.parent_indices,
            rej_filter_results.incremental_log_marginal_likelihood,
            rej_filter_results.accumulated_log_marginal_likelihood))

      new_filter_results = smc_kernel.SequentialMonteCarloResults(
          steps=filter_results.steps,
          parent_indices=parent_indices,
          incremental_log_marginal_likelihood=(
              incremental_log_marginal_likelihood),
          accumulated_log_marginal_likelihood=(
              accumulated_log_marginal_likelihood),
          seed=filter_results.seed)

      return (new_parameters,
              new_log_weights,
              new_inner_weighted_particles,
              new_filter_results)

    (outside_parameters,
     updated_log_weights,
     inner_weighted_particles,
     filter_results) = tf.cond(
         do_rejuvenation, rejuvenate_particles, keep_particles)

    return smc_kernel.WeightedParticles(
        particles=(outside_parameters,
                   inner_weighted_particles,
                   filter_results.parent_indices,
                   filter_results.incremental_log_marginal_likelihood,
                   filter_results.accumulated_log_marginal_likelihood),
        log_weights=updated_log_weights,
        extra=filter_results.seed)

  return _outer_propose_and_update_log_weights_fn
@@ -573,6 +1000,12 @@ def _particle_filter_initial_weighted_particles(observations, initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=particles_dim) + if extra == (): + if len(ps.shape(initial_log_weights)) == 1: + # initial extra for particle filter + extra = tf.constant(0) + + # initial_state is [3, 1000, 2] perche' particles_dim = 1 # Return particles weighted by the initial observation. return smc_kernel.WeightedParticles( particles=initial_state, @@ -581,7 +1014,8 @@ def _particle_filter_initial_weighted_particles(observations, particles=initial_state, observations=observations, observation_fn=observation_fn, - particles_dim=particles_dim)) + particles_dim=particles_dim), + extra=extra) def _particle_filter_propose_and_update_log_weights_fn( @@ -625,7 +1059,8 @@ def propose_and_update_log_weights_fn(step, state, seed=None): log_weights=log_weights + _compute_observation_log_weights( step + 1, proposed_particles, observations, observation_fn, num_transitions_per_observation=num_transitions_per_observation, - particles_dim=particles_dim)) + particles_dim=particles_dim), + extra=state.extra) return propose_and_update_log_weights_fn @@ -670,9 +1105,6 @@ def _compute_observation_log_weights(step, observation = tf.nest.map_structure( lambda x, step=step: tf.gather(x, observation_idx), observations) - observation = tf.nest.map_structure( - lambda x: tf.expand_dims(x, axis=particles_dim), observation) - log_weights = observation_fn(step, particles).log_prob(observation) return tf.where(step_has_observation, log_weights, diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py index 6508eb6231..7ddfd3ac30 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py @@ -21,6 +21,7 @@ from tensorflow_probability.python.bijectors import shift from 
tensorflow_probability.python.distributions import bernoulli from tensorflow_probability.python.distributions import deterministic +from tensorflow_probability.python.distributions import independent from tensorflow_probability.python.distributions import joint_distribution_auto_batched as jdab from tensorflow_probability.python.distributions import joint_distribution_named as jdn from tensorflow_probability.python.distributions import linear_gaussian_ssm as lgssm @@ -202,15 +203,12 @@ def transition_fn(_, previous_state): def observation_fn(_, state): return normal.Normal(loc=state['position'], scale=0.1) - # Batch of synthetic observations, . - true_initial_positions = np.random.randn(*batch_shape).astype(self.dtype) - true_velocities = 0.1 * np.random.randn( - *batch_shape).astype(self.dtype) + # Batch of synthetic observations + true_initial_positions = np.random.randn() + true_velocities = 0.1 * np.random.randn() observed_positions = ( - true_velocities * - np.arange(num_timesteps).astype( - self.dtype)[..., tf.newaxis, tf.newaxis] + - true_initial_positions) + true_velocities * np.arange(num_timesteps).astype(self.dtype) + + true_initial_positions) (particles, log_weights, parent_indices, incremental_log_marginal_likelihoods) = self.evaluate( @@ -238,23 +236,14 @@ def observation_fn(_, state): batch_shape[0], num_particles, batch_shape[1]]) + self.assertAllEqual(log_weights.shape, + [num_timesteps, + batch_shape[0], + num_particles, + batch_shape[1]]) self.assertAllEqual(incremental_log_marginal_likelihoods.shape, [num_timesteps] + batch_shape) - self.assertAllClose( - self.evaluate( - tf.reduce_sum(tf.exp(log_weights) * - particles['position'], axis=2)), - observed_positions, - atol=0.3) - - velocity_means = tf.reduce_sum(tf.exp(log_weights) * - particles['velocity'], axis=2) - - self.assertAllClose( - self.evaluate(tf.reduce_mean(velocity_means, axis=0)), - true_velocities, atol=0.05) - # Uncertainty in velocity should decrease over time. 
def test_smc_squared_rejuvenation_parameters(self):
  """Runs smc_squared with rejuvenation and checks the traced parameters.

  The assertion requires the absolute value of every traced parameter to be
  non-increasing from one traced step to the next.
  """
  def particle_dynamics(params, _, previous_state):
    # `params` gets two trailing axes so it broadcasts against the inner
    # particles' state.
    return independent.Independent(
        normal.Normal(
            previous_state + params[..., tf.newaxis, tf.newaxis] + 1, 0.1),
        reinterpreted_batch_ndims=1)

  def rejuvenation_criterion(step, state):
    # Rejuvenate every 2 steps, but only once the first element of
    # `state.extra` is non-zero.
    return tf.logical_and(
        tf.equal(tf.math.mod(step, tf.constant(2)), tf.constant(0)),
        tf.not_equal(state.extra[0], tf.constant(0)))

  observations = tf.stack([tf.range(15, dtype=tf.float32),
                           tf.range(15, dtype=tf.float32)], axis=1)

  num_outer_particles = 3
  num_inner_particles = 5

  params, _ = self.evaluate(particle_filter.smc_squared(
      observations=observations,
      inner_initial_state_prior=(
          lambda _, params: mvn_diag.MultivariateNormalDiag(
              loc=tf.broadcast_to([0., 0.], params.shape + [2]),
              scale_diag=tf.broadcast_to([0.01, 0.01], params.shape + [2]))),
      initial_parameter_prior=normal.Normal(5., 0.5),
      num_outer_particles=num_outer_particles,
      num_inner_particles=num_inner_particles,
      outer_rejuvenation_criterion_fn=rejuvenation_criterion,
      inner_transition_fn=(
          lambda params: (
              lambda _, state: particle_dynamics(params, _, state))),
      inner_observation_fn=(
          lambda params: (
              lambda _, state: independent.Independent(
                  normal.Normal(state, 2.), 1))),
      outer_trace_fn=lambda s, r: (s.particles[0], s.particles[1]),
      parameter_proposal_kernel=lambda params: normal.Normal(params, 3),
      seed=test_util.test_seed()))

  abs_params = tf.abs(params)
  differences = abs_params[1:] - abs_params[:-1]
  mask_parameters = tf.reduce_all(tf.less_equal(differences, 0), axis=0)

  self.assertAllTrue(mask_parameters)


def test_smc_squared_can_step_dynamics_faster_than_observations(self):
  """Integrates a harmonic oscillator with 100 transitions per observation."""
  initial_state_prior = jdn.JointDistributionNamed({
      'position': deterministic.Deterministic([1.]),
      'velocity': deterministic.Deterministic([0.])
  })

  # Use 100 steps between observations to integrate a simple harmonic
  # oscillator.
  dt = 0.01

  def simple_harmonic_motion_transition_fn(_, state):
    return jdn.JointDistributionNamed({
        'position':
            normal.Normal(
                loc=state['position'] + dt * state['velocity'],
                scale=dt * 0.01),
        'velocity':
            normal.Normal(
                loc=state['velocity'] - dt * state['position'],
                scale=dt * 0.01)
    })

  def observe_position(_, state):
    return normal.Normal(loc=state['position'], scale=0.01)

  particles, lps = self.evaluate(particle_filter.smc_squared(
      observations=tf.convert_to_tensor(
          [tf.math.cos(0.), tf.math.cos(1.)]),
      inner_initial_state_prior=lambda _, params: initial_state_prior,
      initial_parameter_prior=deterministic.Deterministic(0.),
      num_outer_particles=1,
      inner_transition_fn=(
          lambda params: simple_harmonic_motion_transition_fn),
      inner_observation_fn=lambda params: observe_position,
      num_inner_particles=1024,
      # Trace the inner particles and the inner incremental log marginal
      # likelihood (entries 1 and 3 of the outer `particles` tuple).
      outer_trace_fn=lambda s, r: (s.particles[1].particles, s.particles[3]),
      num_transitions_per_observation=100,
      seed=test_util.test_seed()))

  # 102 traced steps: 1 + 100 * (2 - 1) transitions plus the extra step that
  # smc_squared adds.
  self.assertAllEqual(ps.shape(particles['position']),
                      tf.constant([102, 1, 1024]))

  self.assertAllClose(tf.transpose(np.mean(particles['position'], axis=-1)),
                      tf.reshape(tf.math.cos(dt * np.arange(102)), [1, -1]),
                      atol=0.04)

  self.assertAllEqual(ps.shape(lps), [102, 1])
  self.assertGreater(lps[1][0], 1.)
  self.assertGreater(lps[-1][0], 3.)


def test_smc_squared_custom_outer_trace_fn(self):
  """Checks that a custom `outer_trace_fn` sees the inner particle state."""
  def trace_fn(state, _):
    # Traces the mean and stddev of the inner particle population of the
    # first (and here only) outer particle at each step.
    inner = state[0][1]
    weights = tf.exp(inner.log_weights[0])
    mean = tf.reduce_sum(weights * inner.particles[0], axis=0)
    variance = tf.reduce_sum(
        weights * (inner.particles[0] - mean[tf.newaxis, ...]) ** 2)
    return {'mean': mean,
            'stddev': tf.sqrt(variance),
            # In real usage we would likely not track the particles and
            # weights. We keep them here just so we can double-check the
            # stats, below.
            'particles': inner.particles[0],
            'weights': weights}

  results = self.evaluate(particle_filter.smc_squared(
      observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]),
      inner_initial_state_prior=lambda _, params: normal.Normal([0.], 1.),
      initial_parameter_prior=deterministic.Deterministic(0.),
      inner_transition_fn=(
          lambda params: (lambda _, state: normal.Normal(state, 1.))),
      inner_observation_fn=(
          lambda params: (lambda _, state: normal.Normal(state, 1.))),
      num_inner_particles=1024,
      num_outer_particles=1,
      outer_trace_fn=trace_fn,
      seed=test_util.test_seed()))

  # Verify that posterior means are increasing.
  self.assertAllGreater(results['mean'][1:] - results['mean'][:-1], 0.)

  # Check that our traced means and scales match values computed
  # by averaging over particles after the fact.
  all_means = self.evaluate(tf.reduce_sum(
      results['weights'] * results['particles'], axis=1))
  all_variances = self.evaluate(
      tf.reduce_sum(
          results['weights'] *
          (results['particles'] - all_means[..., tf.newaxis])**2,
          axis=1))
  self.assertAllClose(results['mean'], all_means)
  self.assertAllClose(results['stddev'], np.sqrt(all_variances))


def test_smc_squared_indices_to_trace(self):
  """Checks the shapes of everything traced from the outer state."""
  num_outer_particles = 7
  num_inner_particles = 13

  def rejuvenation_criterion(step, state):
    # Rejuvenate every 3 steps, but only once the first element of
    # `state.extra` is non-zero. The condition is already a boolean tensor,
    # so it is returned directly rather than wrapped in `tf.cond`.
    return tf.logical_and(
        tf.equal(tf.math.mod(step, tf.constant(3)), tf.constant(0)),
        tf.not_equal(state.extra[0], tf.constant(0)))

  (parameters, weight_parameters,
   inner_particles, inner_log_weights, lp) = self.evaluate(
       particle_filter.smc_squared(
           observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]),
           initial_parameter_prior=deterministic.Deterministic(0.),
           inner_initial_state_prior=lambda _, params: normal.Normal(
               [0.] * num_outer_particles, 1.),
           inner_transition_fn=(
               lambda params: (lambda _, state: normal.Normal(state, 10.))),
           inner_observation_fn=(
               lambda params: (lambda _, state: normal.Normal(state, 0.1))),
           num_inner_particles=num_inner_particles,
           num_outer_particles=num_outer_particles,
           outer_rejuvenation_criterion_fn=rejuvenation_criterion,
           outer_trace_fn=lambda s, r: (  # pylint: disable=g-long-lambda
               s.particles[0],
               s.log_weights,
               s.particles[1].particles,
               s.particles[1].log_weights,
               r.accumulated_log_marginal_likelihood),
           seed=test_util.test_seed()))

  # TODO: smc_squared currently begins its run with an extra empty step, so
  # 5 observations yield 6 traced steps.
  self.assertAllEqual(ps.shape(parameters), [6, 7])
  self.assertAllEqual(ps.shape(weight_parameters), [6, 7])
  self.assertAllEqual(ps.shape(inner_particles), [6, 7, 13])
  self.assertAllEqual(ps.shape(inner_log_weights), [6, 7, 13])
  self.assertAllEqual(ps.shape(lp), [6])
+ extra: a (structure of) Tensor(s) each of shape + `concat([[b1, ..., bN], event_shape])`, where `event_shape` + may differ across component `Tensor`s. This represents global state of the + sampling process that is not associated with individual particles. + Defaults to an empty tuple. In some contexts, particles may be stacked across multiple inference steps, in which case all `Tensor` shapes will be prefixed by an additional dimension of size `num_steps`. """ + def __new__(cls, particles, log_weights, extra=()): + return super().__new__(cls, particles, log_weights, extra) # SequentialMonteCarlo `kernel_results` structure. @@ -292,7 +299,7 @@ def one_step(self, state, kernel_results, seed=None): - tf.gather(normalized_log_weights, 0, axis=self.particles_dim)) do_resample = self.resample_criterion_fn( - state, particles_dim=self.particles_dim) + state, self.particles_dim) # Some batch elements may require resampling and others not, so # we first do the resampling for all elements, then select whether to # use the resampled values for each batch element according to @@ -326,7 +333,8 @@ def one_step(self, state, kernel_results, seed=None): normalized_log_weights)) return (WeightedParticles(particles=resampled_particles, - log_weights=log_weights), + log_weights=log_weights, + extra=state.extra), SequentialMonteCarloResults( steps=kernel_results.steps + 1, parent_indices=resample_indices, diff --git a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py index 2a9302a420..098769f36a 100644 --- a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py +++ b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py @@ -42,7 +42,8 @@ def propose_and_update_log_weights_fn(_, weighted_particles, seed=None): return WeightedParticles( particles=proposed_particles, 
log_weights=weighted_particles.log_weights + - normal.Normal(loc=-2.6, scale=0.1).log_prob(proposed_particles)) + normal.Normal(loc=-2.6, scale=0.1).log_prob(proposed_particles) + ) num_particles = 16 initial_state = self.evaluate( @@ -50,7 +51,8 @@ def propose_and_update_log_weights_fn(_, weighted_particles, seed=None): particles=tf.random.normal([num_particles], seed=test_util.test_seed()), log_weights=tf.fill([num_particles], - -tf.math.log(float(num_particles))))) + -tf.math.log(float(num_particles))) + )) # Run a couple of steps. seeds = samplers.split_seed( @@ -96,7 +98,8 @@ def testMarginalLikelihoodGradientIsDefined(self): WeightedParticles( particles=samplers.normal([num_particles], seed=seeds[0]), log_weights=tf.fill([num_particles], - -tf.math.log(float(num_particles))))) + -tf.math.log(float(num_particles))) + )) def propose_and_update_log_weights_fn(_, weighted_particles, @@ -110,7 +113,8 @@ def propose_and_update_log_weights_fn(_, particles=proposed_particles, log_weights=(weighted_particles.log_weights + transition_dist.log_prob(proposed_particles) - - proposal_dist.log_prob(proposed_particles))) + proposal_dist.log_prob(proposed_particles)) + ) def marginal_logprob(transition_scale): kernel = SequentialMonteCarlo(