diff --git a/tensorflow_probability/python/experimental/mcmc/BUILD b/tensorflow_probability/python/experimental/mcmc/BUILD index 7fb140497b..4056b28d6f 100644 --- a/tensorflow_probability/python/experimental/mcmc/BUILD +++ b/tensorflow_probability/python/experimental/mcmc/BUILD @@ -559,11 +559,13 @@ multi_substrate_py_test( shard_count = 3, deps = [ ":particle_filter", + ":sequential_monte_carlo_kernel", # numpy dep, # tensorflow dep, "//tensorflow_probability/python/bijectors:shift", "//tensorflow_probability/python/distributions:bernoulli", "//tensorflow_probability/python/distributions:deterministic", + "//tensorflow_probability/python/distributions:independent", "//tensorflow_probability/python/distributions:joint_distribution_auto_batched", "//tensorflow_probability/python/distributions:joint_distribution_named", "//tensorflow_probability/python/distributions:linear_gaussian_ssm", diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index 1bcbc870f4..50608ffa20 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -16,6 +16,7 @@ import numpy as np import tensorflow.compat.v2 as tf +from tensorflow_probability.python.distributions import uniform from tensorflow_probability.python.experimental.mcmc import sequential_monte_carlo_kernel as smc_kernel from tensorflow_probability.python.experimental.mcmc import weighted_resampling from tensorflow_probability.python.internal import assert_util @@ -549,6 +550,7 @@ def _particle_filter_initial_weighted_particles(observations, initial_state_proposal, num_particles, particles_dim=0, + extra=(), seed=None): """Initialize a set of weighted particles including the first observation.""" # Propose an initial state. @@ -574,14 +576,16 @@ def _particle_filter_initial_weighted_particles(observations, axis=particles_dim) # Return particles weighted by the initial observation. + if observations is not None: + initial_log_weights += _compute_observation_log_weights( + step=0, + particles=initial_state, + observations=observations, + observation_fn=observation_fn, + particles_dim=particles_dim) + return smc_kernel.WeightedParticles( - particles=initial_state, - log_weights=initial_log_weights + _compute_observation_log_weights( - step=0, - particles=initial_state, - observations=observations, - observation_fn=observation_fn, - particles_dim=particles_dim)) + particles=initial_state, log_weights=initial_log_weights, extra=extra) def _particle_filter_propose_and_update_log_weights_fn( @@ -625,7 +629,8 @@ def propose_and_update_log_weights_fn(step, state, seed=None): log_weights=log_weights + _compute_observation_log_weights( step + 1, proposed_particles, observations, observation_fn, num_transitions_per_observation=num_transitions_per_observation, - particles_dim=particles_dim)) + particles_dim=particles_dim), + extra=state.extra) return propose_and_update_log_weights_fn @@ -670,8 +675,18 @@ def _compute_observation_log_weights(step, observation = tf.nest.map_structure( lambda x, step=step: tf.gather(x, observation_idx), observations) - observation = tf.nest.map_structure( - lambda x: tf.expand_dims(x, axis=particles_dim), observation) + # For now, when particles_dim > 0, we do not support the observations + # having batch shape. (This is not needed for SMC^2.) 
+  #
+  # In JAX, particles_dim > 0 can be handled like:
+  #   vmap(lambda p: observation_fn(step, p).log_prob(observations),
+  #        in_axes=particles_dim, out_axes=particles_dim)
+  #
+  # In TF, we could re-arrange dimensions here. Or we could left-pad the
+  # observations with additional dimensions until they have rank one less
+  # than the batch-and-event rank of observation_fn(step, particles), and
+  # then we could expand_dims at dimension particles_dim.
+  del particles_dim
   log_weights = observation_fn(step, particles).log_prob(observation)
   return tf.where(step_has_observation,
@@ -741,3 +756,272 @@ def _assert_batch_shape_matches_weights(distribution, weights_shape, diststr):
   assertions = [assert_util.assert_equal(a, b, message=msg)
                 for a, b in zip(shapes[1:], shapes[:-1])]
   return assertions
+
+
+def _default_rejuvenation_criterion_fn(step, weighted_particles):
+  """Default criterion: rejuvenate when outer ESS drops below threshold."""
+  del step
+  return smc_kernel.ess_below_threshold(weighted_particles, particles_dim=0)[0]
+
+
+def smc_squared(
+    observations,
+    initial_parameter_prior,
+    inner_initial_state_prior,
+    inner_transition_fn,
+    observation_fn,
+    num_outer_particles,
+    num_inner_particles,
+    initial_parameter_proposal=None,
+    parameter_proposal_kernel=None,
+    inner_initial_state_proposal=None,
+    inner_proposal_fn=None,
+    outer_rejuvenation_criterion_fn=_default_rejuvenation_criterion_fn,
+    inner_resample_criterion_fn=smc_kernel.ess_below_threshold,
+    inner_resample_fn=weighted_resampling.resample_systematic,
+    outer_trace_fn=_default_trace_fn,
+    outer_trace_criterion_fn=_always_trace,
+    parallel_iterations=1,
+    num_transitions_per_observation=1,
+    static_trace_allocation_size=None,
+    unbiased_gradients=True,
+    seed=None):
+  """Runs SMC^2, an SMC sampler over parameters.
+
+  Each outer particle is a parameter value carrying its own inner particle
+  filter over latent states; outer log-weights are updated with the inner
+  filters' incremental log marginal likelihoods.
+  """
+  init_seed, smc_seed = samplers.split_seed(seed, salt='smc_squared')
+
+  num_observation_steps = ps.size0(tf.nest.flatten(observations)[0])
+  num_timesteps = (
+      1 + num_transitions_per_observation * (num_observation_steps - 1))
+
+  initial_state = _smc_squared_initial_weighted_particles(
+      observations, observation_fn, initial_parameter_prior,
+      initial_parameter_proposal, num_outer_particles,
+      inner_initial_state_prior, inner_initial_state_proposal,
+      num_inner_particles, seed=init_seed)
+
+  outer_propose_and_update_log_weights_fn = (
+      _smc_squared_propose_and_update_log_weights_fn(
+          outer_rejuvenation_criterion_fn=outer_rejuvenation_criterion_fn,
+          observations=observations,
+          inner_transition_fn=inner_transition_fn,
+          inner_proposal_fn=inner_proposal_fn,
+          observation_fn=observation_fn,
+          inner_resample_fn=inner_resample_fn,
+          inner_resample_criterion_fn=inner_resample_criterion_fn,
+          parameter_proposal_kernel=parameter_proposal_kernel,
+          initial_parameter_prior=initial_parameter_prior,
+          num_transitions_per_observation=num_transitions_per_observation,
+          unbiased_gradients=unbiased_gradients,
+          inner_initial_state_prior=inner_initial_state_prior,
+          inner_initial_state_proposal=inner_initial_state_proposal,
+          num_inner_particles=num_inner_particles,
+          num_outer_particles=num_outer_particles))
+
+  return sequential_monte_carlo(
+      initial_weighted_particles=initial_state,
+      propose_and_update_log_weights_fn=(
+          outer_propose_and_update_log_weights_fn),
+      resample_fn=None,
+      resample_criterion_fn=None,
+      trace_criterion_fn=outer_trace_criterion_fn,
+      static_trace_allocation_size=static_trace_allocation_size,
+      parallel_iterations=parallel_iterations,
+      unbiased_gradients=unbiased_gradients,
+      num_steps=num_timesteps,
+      particles_dim=0,
+      trace_fn=outer_trace_fn,
+      seed=smc_seed)
+
+
+def _smc_squared_initial_weighted_particles(
+    observations,
+    observation_fn,
+    initial_parameter_prior,
+    initial_parameter_proposal,
+    num_outer_particles,
+    inner_initial_state_prior,
+    inner_initial_state_proposal,
+    num_inner_particles,
+    seed=None):
+  """Initialize particles for SMC^2, including the first observation."""
+  params_seed, particles_seed = samplers.split_seed(
+      seed, n=2, salt='smc_squared_init_particles')
+
+  initial_params, initial_log_weights, _ = (
+      _particle_filter_initial_weighted_particles(
+          observations=None,
+          observation_fn=None,
+          initial_state_prior=initial_parameter_prior,
+          initial_state_proposal=initial_parameter_proposal,
+          num_particles=num_outer_particles,
+          seed=params_seed))
+
+  inner_weighted_particles = _particle_filter_initial_weighted_particles(
+      observations=observations,
+      observation_fn=observation_fn(initial_params),
+      initial_state_prior=inner_initial_state_prior(0, initial_params),
+      initial_state_proposal=(inner_initial_state_proposal(0, initial_params)
+                              if inner_initial_state_proposal is not None
+                              else None),
+      num_particles=num_inner_particles,
+      particles_dim=1,
+      seed=particles_seed)
+
+  inner_filter_results = smc_kernel.SequentialMonteCarlo(
+      None, None, particles_dim=1).bootstrap_results(inner_weighted_particles)
+
+  return smc_kernel.WeightedParticles(
+      particles=(initial_params,
+                 inner_weighted_particles,
+                 inner_filter_results.parent_indices,
+                 inner_filter_results.incremental_log_marginal_likelihood,
+                 inner_filter_results.accumulated_log_marginal_likelihood),
+      log_weights=initial_log_weights,
+      extra=inner_filter_results.seed)
+
+
+def _smc_squared_propose_and_update_log_weights_fn(
+    observations,
+    inner_transition_fn,
+    inner_proposal_fn,
+    observation_fn,
+    initial_parameter_prior,
+    inner_initial_state_prior,
+    inner_initial_state_proposal,
+    num_transitions_per_observation,
+    inner_resample_fn,
+    inner_resample_criterion_fn,
+    outer_rejuvenation_criterion_fn,
+    unbiased_gradients,
+    parameter_proposal_kernel,
+    num_inner_particles,
+    num_outer_particles):
+  """Build a function specifying an SMC^2 update step."""
+  def _rejuvenate_particles(
+      outer_params, log_weights, inner_particles, filter_results, seed=None):
+    seeds = samplers.split_seed(seed, n=4)
+    step = filter_results.steps
+
+    proposal_kernel = parameter_proposal_kernel(outer_params, log_weights)
+    rej_outer_params = proposal_kernel(outer_params).sample(seed=seeds[0])
+
+    rej_inner_particles = _particle_filter_initial_weighted_particles(
+        observations=observations,
+        observation_fn=observation_fn(rej_outer_params),
+        initial_state_prior=inner_initial_state_prior(0, rej_outer_params),
+        initial_state_proposal=(
+            inner_initial_state_proposal(0, rej_outer_params)
+            if inner_initial_state_proposal is not None else None),
+        num_particles=num_inner_particles,
+        particles_dim=1,
+        seed=seeds[1])
+
+    rej_kernel = smc_kernel.SequentialMonteCarlo(
+        propose_and_update_log_weights_fn=(
+            _particle_filter_propose_and_update_log_weights_fn(
+                observations=observations,
+                transition_fn=inner_transition_fn(rej_outer_params),
+                proposal_fn=(inner_proposal_fn(rej_outer_params)
+                             if inner_proposal_fn is not None else None),
+                observation_fn=observation_fn(rej_outer_params),
+                particles_dim=1,
+                num_transitions_per_observation=(
+                    num_transitions_per_observation))),
+        resample_fn=inner_resample_fn,
+        resample_criterion_fn=inner_resample_criterion_fn,
+        particles_dim=1,
+        unbiased_gradients=unbiased_gradients)
+
+    rej_inner_filter_results = (
rej_kernel.bootstrap_results(rej_inner_particles)) + + def body(i, state, results): + state, results = rej_kernel.one_step( + state, results, seed=samplers.fold_in(seeds[2], i)) + return (i + 1, state, results) + + (_, rej_inner_particles, rej_filter_results) = tf.while_loop( + lambda i, *_: tf.less_equal(i, step), + body, + [0, rej_inner_particles, rej_inner_filter_results]) + + log_a = (rej_filter_results.accumulated_log_marginal_likelihood + - filter_results.accumulated_log_marginal_likelihood + + initial_parameter_prior.log_prob(rej_outer_params) + - initial_parameter_prior.log_prob(outer_params) + + proposal_kernel(rej_outer_params).log_prob(outer_params) + - proposal_kernel(outer_params).log_prob(rej_outer_params)) + u = uniform.Uniform(0., 1.).sample(num_outer_particles, seed=seeds[3]) + accept = tf.math.log(u) <= log_a + + def _choose(a, b): + if len(a.shape) >= 1 and a.shape[0] == accept.shape[0]: + return mcmc_util.choose(accept, a, b) + return b + outer_params, inner_particles, filter_results = ( + tf.nest.map_structure( + _choose, + (rej_outer_params, rej_inner_particles, rej_filter_results), + (outer_params, inner_particles, filter_results))) + + return (outer_params, tf.zeros_like(log_weights), + inner_particles, filter_results) + + def _outer_propose_and_update_log_weights_fn(step, state, seed=None): + step_seed, rejuvenation_seed = samplers.split_seed(seed, 2) + + outer_params, inner_particles, *_ = state.particles + filter_results = smc_kernel.SequentialMonteCarloResults( + steps=step, + parent_indices=state.particles[2], + incremental_log_marginal_likelihood=state.particles[3], + accumulated_log_marginal_likelihood=state.particles[4], + seed=state.extra) + + kernel = smc_kernel.SequentialMonteCarlo( + propose_and_update_log_weights_fn=( + _particle_filter_propose_and_update_log_weights_fn( + observations=observations, + transition_fn=inner_transition_fn(outer_params), + proposal_fn=(inner_proposal_fn(outer_params) + if inner_proposal_fn is not None else None), + observation_fn=observation_fn(outer_params), + particles_dim=1, + num_transitions_per_observation=( + num_transitions_per_observation))), + resample_fn=inner_resample_fn, + resample_criterion_fn=inner_resample_criterion_fn, + particles_dim=1, + unbiased_gradients=unbiased_gradients) + + inner_particles, filter_results = kernel.one_step( + inner_particles, filter_results, seed=step_seed) + log_weights = ( + state.log_weights + filter_results.incremental_log_marginal_likelihood) + + do_rejuvenation = outer_rejuvenation_criterion_fn( + step, smc_kernel.WeightedParticles( + particles=(outer_params, + inner_particles, + filter_results.parent_indices, + filter_results.incremental_log_marginal_likelihood, + filter_results.accumulated_log_marginal_likelihood), + log_weights=log_weights, + extra=filter_results.seed)) + + (outer_params, log_weights, inner_particles, filter_results) = tf.cond( + do_rejuvenation, + lambda: _rejuvenate_particles( + outer_params, log_weights, inner_particles, filter_results, + seed=rejuvenation_seed), + lambda: (outer_params, log_weights, inner_particles, filter_results)) + + return smc_kernel.WeightedParticles( + particles=(outer_params, + inner_particles, + filter_results.parent_indices, + filter_results.incremental_log_marginal_likelihood, + filter_results.accumulated_log_marginal_likelihood), + log_weights=log_weights, + extra=filter_results.seed) + + return _outer_propose_and_update_log_weights_fn diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py 
b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py index 6508eb6231..9cb7bbb6d4 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py @@ -21,6 +21,7 @@ from tensorflow_probability.python.bijectors import shift from tensorflow_probability.python.distributions import bernoulli from tensorflow_probability.python.distributions import deterministic +from tensorflow_probability.python.distributions import independent from tensorflow_probability.python.distributions import joint_distribution_auto_batched as jdab from tensorflow_probability.python.distributions import joint_distribution_named as jdn from tensorflow_probability.python.distributions import linear_gaussian_ssm as lgssm @@ -32,6 +33,7 @@ from tensorflow_probability.python.distributions import transformed_distribution from tensorflow_probability.python.distributions import uniform from tensorflow_probability.python.experimental.mcmc import particle_filter +from tensorflow_probability.python.experimental.mcmc import sequential_monte_carlo_kernel as smc_kernel from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import test_util from tensorflow_probability.python.math import gradient @@ -202,15 +204,12 @@ def transition_fn(_, previous_state): def observation_fn(_, state): return normal.Normal(loc=state['position'], scale=0.1) - # Batch of synthetic observations, . - true_initial_positions = np.random.randn(*batch_shape).astype(self.dtype) - true_velocities = 0.1 * np.random.randn( - *batch_shape).astype(self.dtype) + # Batch of synthetic observations + true_initial_positions = np.random.randn() + true_velocities = 0.1 * np.random.randn() observed_positions = ( - true_velocities * - np.arange(num_timesteps).astype( - self.dtype)[..., tf.newaxis, tf.newaxis] + - true_initial_positions) + true_velocities * np.arange(num_timesteps).astype(self.dtype) + + true_initial_positions) (particles, log_weights, parent_indices, incremental_log_marginal_likelihoods) = self.evaluate( @@ -238,6 +237,11 @@ def observation_fn(_, state): batch_shape[0], num_particles, batch_shape[1]]) + self.assertAllEqual(log_weights.shape, + [num_timesteps, + batch_shape[0], + num_particles, + batch_shape[1]]) self.assertAllEqual(incremental_log_marginal_likelihoods.shape, [num_timesteps] + batch_shape) @@ -245,14 +249,15 @@ def observation_fn(_, state): self.evaluate( tf.reduce_sum(tf.exp(log_weights) * particles['position'], axis=2)), - observed_positions, + tf.broadcast_to(observed_positions[..., tf.newaxis, tf.newaxis], + [num_timesteps, batch_shape[0], batch_shape[1]]), atol=0.3) velocity_means = tf.reduce_sum(tf.exp(log_weights) * particles['velocity'], axis=2) self.assertAllClose( - self.evaluate(tf.reduce_mean(velocity_means, axis=0)), + self.evaluate(tf.reduce_mean(velocity_means)), true_velocities, atol=0.05) # Uncertainty in velocity should decrease over time. 
@@ -734,6 +739,55 @@ def marginal_log_likelihood(level_scale, noise_scale): self.assertAllNotNone(grads) self.assertAllAssertsNested(self.assertNotAllZero, grads) + def test_smc_squared_rejuvenation_parameters(self): + def particle_dynamics(params, previous_state): + return independent.Independent( + normal.Normal( + previous_state + params[..., tf.newaxis, tf.newaxis] + 1, 0.1), + reinterpreted_batch_ndims=1) + + num_outer_particles = 10 + num_inner_particles = 16 + num_steps = 45 + observations = tf.stack([tf.range(num_steps, dtype=tf.float32), + tf.range(num_steps, dtype=tf.float32)], axis=1) + + @tf.function(jit_compile=True) + def _run(observations): + return particle_filter.smc_squared( + observations=observations, + inner_initial_state_prior=(lambda _, params: sample_dist_lib.Sample( + normal.Normal(loc=tf.zeros_like(params), + scale=0.01 * tf.ones_like(params)), + sample_shape=[2])), + initial_parameter_prior=normal.Normal(5., 0.5), + num_outer_particles=num_outer_particles, + num_inner_particles=num_inner_particles, + inner_transition_fn=lambda params: ( + lambda _, state: particle_dynamics(params, state)), + observation_fn=lambda params: ( + lambda _, state: independent.Independent( + normal.Normal(state, 2.), 1)), + parameter_proposal_kernel=( + lambda *_: lambda p: normal.Normal(p, scale=1.)), + outer_trace_fn=lambda s, _: (s.particles[0], s.log_weights), + seed=test_util.test_seed()) + + params, log_weights = self.evaluate(_run(observations)) + + self.assertAllEqual([num_steps, num_outer_particles], params.shape) + self.assertAllEqual([num_steps, num_outer_particles], log_weights.shape) + + # Without rejuvenation, we could only estimate our single outer parameter + # from the computed weights of the 10 samples from the prior N(5.0, 0.5). + # But rejuvenation should allow us to correctly estimate that the parameter + # is close to zero. + self.assertAllClose( + 0.0, tf.reduce_sum(tf.exp(log_weights[-1]) * params[-1]), atol=0.1) + self.assertAllGreater( + tf.exp(smc_kernel.log_ess_from_log_weights(log_weights[-1])), + 0.5 * num_outer_particles) + # TODO(b/186068104): add tests with dynamic shapes. class ParticleFilterTestFloat32(_ParticleFilterTest): diff --git a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py index 73cb0f8414..dfc9bbd17a 100644 --- a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py +++ b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py @@ -34,7 +34,7 @@ # SequentialMonteCarlo `state` structure. class WeightedParticles(collections.namedtuple( - 'WeightedParticles', ['particles', 'log_weights'])): + 'WeightedParticles', ['particles', 'log_weights', 'extra'])): """Particles with corresponding log weights. This structure serves as the `state` for the `SequentialMonteCarlo` transition @@ -50,12 +50,20 @@ class WeightedParticles(collections.namedtuple( `exp(reduce_logsumexp(log_weights, axis=0)) == 1.`. These must be used in conjunction with `particles` to compute expectations under the target distribution. + extra: a (structure of) Tensor(s) each of shape + `concat([[b1, ..., bN], event_shape])`, where `event_shape` + may differ across component `Tensor`s. This represents global state of the + sampling process that is not associated with individual particles. + Defaults to an empty tuple. 
  In some contexts, particles may be stacked across multiple inference steps,
  in which case all `Tensor` shapes will be prefixed by an additional
  dimension of size `num_steps`.
  """
 
+  def __new__(cls, particles, log_weights, extra=()):
+    return super().__new__(cls, particles, log_weights, extra)
+
 
 # SequentialMonteCarlo `kernel_results` structure.
 class SequentialMonteCarloResults(collections.namedtuple(
@@ -291,42 +299,49 @@ def one_step(self, state, kernel_results, seed=None):
           tf.gather(state.log_weights, 0, axis=self.particles_dim) -
           tf.gather(normalized_log_weights, 0, axis=self.particles_dim))
 
-      do_resample = self.resample_criterion_fn(
-          state, particles_dim=self.particles_dim)
-      # Some batch elements may require resampling and others not, so
-      # we first do the resampling for all elements, then select whether to
-      # use the resampled values for each batch element according to
-      # `do_resample`. If there were no batching, we might prefer to use
-      # `tf.cond` to avoid the resampling computation on steps where it's not
-      # needed---but we're ultimately interested in adaptive resampling
-      # for statistical (not computational) purposes, so this isn't a
-      # dealbreaker.
-      [
-          resampled_particles,
-          resample_indices,
-          weights_after_resampling
-      ] = weighted_resampling.resample(
-          particles=state.particles,
-          # The `stop_gradient` here does not affect discrete resampling
-          # (which is nondifferentiable anyway), but avoids canceling out
-          # the gradient signal from the 'target' log weights, as described in
-          # Scibior, Masrani, and Wood (2021).
-          log_weights=tf.stop_gradient(state.log_weights),
-          resample_fn=self.resample_fn,
-          target_log_weights=(normalized_log_weights
-                              if self.unbiased_gradients else None),
-          particles_dim=self.particles_dim,
-          seed=resample_seed)
-      (resampled_particles,
-       resample_indices,
-       log_weights) = tf.nest.map_structure(
-           lambda r, p: mcmc_util.choose(do_resample, r, p),
-           (resampled_particles, resample_indices, weights_after_resampling),
-           (state.particles, _dummy_indices_like(resample_indices),
-            normalized_log_weights))
+      if self.resample_criterion_fn is not None:
+        do_resample = self.resample_criterion_fn(
+            state, particles_dim=self.particles_dim)
+        # Some batch elements may require resampling and others not, so
+        # we first do the resampling for all elements, then select whether to
+        # use the resampled values for each batch element according to
+        # `do_resample`. If there were no batching, we might prefer to use
+        # `tf.cond` to avoid the resampling computation on steps where it's
+        # not needed -- but we're ultimately interested in adaptive resampling
+        # for statistical (not computational) purposes, so this isn't a
+        # dealbreaker.
+        [
+            resampled_particles,
+            resample_indices,
+            weights_after_resampling
+        ] = weighted_resampling.resample(
+            particles=state.particles,
+            # The `stop_gradient` here does not affect discrete resampling
+            # (which is nondifferentiable anyway), but avoids canceling out
+            # the gradient signal from the 'target' log weights, as described
+            # in Scibior, Masrani, and Wood (2021).
+ log_weights=tf.stop_gradient(state.log_weights), + resample_fn=self.resample_fn, + target_log_weights=(normalized_log_weights + if self.unbiased_gradients else None), + particles_dim=self.particles_dim, + seed=resample_seed) + (resampled_particles, + resample_indices, + log_weights) = tf.nest.map_structure( + lambda r, p: mcmc_util.choose(do_resample, r, p), + (resampled_particles, resample_indices, + weights_after_resampling), + (state.particles, _dummy_indices_like(resample_indices), + normalized_log_weights)) + else: + resampled_particles = state.particles + resample_indices = _dummy_indices_like(normalized_log_weights) + log_weights = normalized_log_weights return (WeightedParticles(particles=resampled_particles, - log_weights=log_weights), + log_weights=log_weights, + extra=state.extra), SequentialMonteCarloResults( steps=kernel_results.steps + 1, parent_indices=resample_indices,
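
The `extra` field added to `WeightedParticles` is backwards compatible: the `__new__` override defaults it to an empty tuple, so existing two-field call sites are unaffected. A minimal sketch of both constructions, assuming this patch is applied (the uniform initial weights are illustrative):

import tensorflow.compat.v2 as tf
from tensorflow_probability.python.experimental.mcmc import sequential_monte_carlo_kernel as smc_kernel

num_particles = 128
# Pre-patch two-field construction still works; `extra` defaults to ().
state = smc_kernel.WeightedParticles(
    particles=tf.zeros([num_particles]),
    log_weights=tf.fill([num_particles],
                        -tf.math.log(float(num_particles))))
assert state.extra == ()

# The new field carries global (non-per-particle) sampler state; smc_squared
# uses it to thread the inner filter's seed through the outer SMC loop. An
# illustrative counter stands in for that here.
state = state._replace(extra=tf.constant(0))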
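`smc_squared` relies on the new `resample_criterion_fn=None` branch of `SequentialMonteCarlo.one_step`: the outer level never resamples (the rejuvenation move plays that role) and only self-normalizes the weights. A sketch of driving the kernel directly in this mode, with a toy random-walk proposal that leaves the log weights unchanged:

import tensorflow.compat.v2 as tf
from tensorflow_probability.python.distributions import normal
from tensorflow_probability.python.experimental.mcmc import sequential_monte_carlo_kernel as smc_kernel

def propose_and_update(step, state, seed=None):
  del step
  # Toy proposal: random-walk the particles, keep the weights as-is.
  proposed = normal.Normal(state.particles, 1.).sample(seed=seed)
  return smc_kernel.WeightedParticles(
      particles=proposed, log_weights=state.log_weights, extra=state.extra)

kernel = smc_kernel.SequentialMonteCarlo(
    propose_and_update_log_weights_fn=propose_and_update,
    resample_fn=None,
    resample_criterion_fn=None)  # Never resample; only normalize weights.

state = smc_kernel.WeightedParticles(
    particles=tf.zeros([64]),
    log_weights=tf.fill([64], -tf.math.log(64.)))
state, results = kernel.one_step(
    state, kernel.bootstrap_results(state), seed=[4, 2])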
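`outer_rejuvenation_criterion_fn` gates a `tf.cond` in the outer update step, so it must return a scalar boolean; the default above fires when the outer effective sample size drops below `ess_below_threshold`'s default of half the particle count. A sketch of a hypothetical alternative with a stricter 30% fraction (the fraction is arbitrary):

import tensorflow.compat.v2 as tf
from tensorflow_probability.python.experimental.mcmc import sequential_monte_carlo_kernel as smc_kernel

def ess_fraction_criterion(step, weighted_particles):
  del step  # As in the default, the criterion ignores the step number.
  log_weights = weighted_particles.log_weights
  ess = tf.exp(smc_kernel.log_ess_from_log_weights(log_weights))
  num_outer = tf.cast(tf.shape(log_weights)[0], ess.dtype)
  # Reshape to a scalar boolean, as required by the tf.cond gating
  # rejuvenation (the ESS reduction may keep the reduced dimension).
  return tf.reshape(ess < 0.3 * num_outer, [])

This would be passed as `outer_rejuvenation_criterion_fn=ess_fraction_criterion` to `smc_squared`.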
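End to end, the new entry point is exercised as in `test_smc_squared_rejuvenation_parameters` above; the following pared-down sketch of the same toy model (constants as in the test, stateless seed chosen arbitrarily) traces the outer parameter particles and their log weights:

import tensorflow.compat.v2 as tf
from tensorflow_probability.python.distributions import independent
from tensorflow_probability.python.distributions import normal
from tensorflow_probability.python.distributions import sample as sample_dist_lib
from tensorflow_probability.python.experimental.mcmc import particle_filter

observations = tf.stack([tf.range(45, dtype=tf.float32),
                         tf.range(45, dtype=tf.float32)], axis=1)

params, log_weights = particle_filter.smc_squared(
    observations=observations,
    initial_parameter_prior=normal.Normal(5., 0.5),
    inner_initial_state_prior=lambda _, params: sample_dist_lib.Sample(
        normal.Normal(tf.zeros_like(params), 0.01), sample_shape=[2]),
    inner_transition_fn=lambda params: lambda _, state: (
        independent.Independent(
            normal.Normal(state + params[..., tf.newaxis, tf.newaxis] + 1.,
                          0.1),
            reinterpreted_batch_ndims=1)),
    observation_fn=lambda params: lambda _, state: independent.Independent(
        normal.Normal(state, 2.), 1),
    parameter_proposal_kernel=lambda *_: lambda p: normal.Normal(p, 1.),
    num_outer_particles=10,
    num_inner_particles=16,
    outer_trace_fn=lambda s, _: (s.particles[0], s.log_weights),
    seed=[1, 7])
# params: [num_steps, num_outer_particles]; with rejuvenation the weighted
# parameter estimate should concentrate near 0, as the test asserts.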
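On the `particles_dim > 0` limitation documented in `_compute_observation_log_weights`: one possible TF realization of the re-arrangement mentioned in that comment is to move the particle axis to the front, map the observation log-prob over it, and move the result back. This is a hypothetical sketch only (not code used by this patch), and `tf.vectorized_map` support depends on the ops inside `observation_fn`:

import tensorflow.compat.v2 as tf
from tensorflow_probability.python.internal import distribution_util as dist_util

def observation_log_weights_batched(step, particles, observation,
                                    observation_fn, particles_dim):
  # Move each state part's particle axis to the front so we can map over it.
  particles_front = tf.nest.map_structure(
      lambda x: dist_util.move_dimension(x, particles_dim, 0), particles)
  # Score the (possibly batched) observation once per particle slice.
  log_weights = tf.vectorized_map(
      lambda p: observation_fn(step, p).log_prob(observation),
      particles_front)
  # Restore the particle axis to its original position.
  return dist_util.move_dimension(log_weights, 0, particles_dim)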