From 0a56cacf7400449c26f3b52cec2a36d4b3bfebde Mon Sep 17 00:00:00 2001
From: siege
Date: Wed, 29 Jan 2025 14:29:49 -0800
Subject: [PATCH] FunMC: Add the AIS kernel for use with SMC.

PiperOrigin-RevId: 721109723
---
 spinoffs/fun_mc/fun_mc/smc.py      | 193 +++++++++++++++++++++++++++++
 spinoffs/fun_mc/fun_mc/smc_test.py |  67 ++++++++++
 spinoffs/fun_mc/fun_mc/types.py    |  21 +++-
 3 files changed, 279 insertions(+), 2 deletions(-)

diff --git a/spinoffs/fun_mc/fun_mc/smc.py b/spinoffs/fun_mc/fun_mc/smc.py
index 1479164c66..4b08e674b6 100644
--- a/spinoffs/fun_mc/fun_mc/smc.py
+++ b/spinoffs/fun_mc/fun_mc/smc.py
@@ -17,6 +17,7 @@
 from typing import Any, Callable, Generic, Protocol, TypeVar, runtime_checkable

 from fun_mc import backend
+from fun_mc import fun_mc_lib as fun_mc
 from fun_mc import types

 jax = backend.jax
@@ -34,11 +35,16 @@
 BoolScalar = types.BoolScalar
 IntScalar = types.IntScalar
 FloatScalar = types.FloatScalar
+PotentialFn = types.PotentialFn
+
 State = TypeVar('State')
 Extra = TypeVar('Extra')
+KernelExtra = TypeVar('KernelExtra')
 T = TypeVar('T')

 __all__ = [
+    'annealed_importance_sampling_kernel',
+    'AnnealedImportanceSamplingKernelExtra',
     'conditional_systematic_resampling',
     'effective_sample_size_predicate',
     'ParticleGatherFn',
@@ -518,6 +524,193 @@ def dont_resample(
   return smc_state, smc_extra


+@runtime_checkable
+class AnnealedImportanceSamplingMCMCKernel(Protocol[State, Extra, KernelExtra]):
+  """MCMC kernel used to advance the particles within the AIS kernel."""
+
+  def __call__(
+      self,
+      state: State,
+      step: IntScalar,
+      target_log_prob_fn: PotentialFn[Extra],
+      seed: Seed,
+  ) -> tuple[State, KernelExtra]:
+    """Advances `state` using a kernel that leaves `target_log_prob_fn` invariant.
+
+    Args:
+      state: State at step `t`.
+      step: The timestep, `t`.
+      target_log_prob_fn: Target distribution corresponding to `t`.
+      seed: PRNG seed.
+
+    Returns:
+      new_state: New state, targeting `target_log_prob_fn`.
+      extra: Extra information from the kernel.
+    """
+
+
+@util.dataclass
+class AnnealedImportanceSamplingKernelExtra(Generic[KernelExtra, Extra]):
+  """Extra outputs from the AIS kernel.
+
+  Attributes:
+    kernel_extra: Extra outputs from the inner kernel.
+    cur_state_extra: Extra output from the current step's target log prob
+      function.
+    next_state_extra: Extra output from the next step's target log prob
+      function.
+  """
+
+  kernel_extra: KernelExtra
+  cur_state_extra: Extra
+  next_state_extra: Extra
+
+
+@types.runtime_typed
+def annealed_importance_sampling_kernel(
+    state: State,
+    step: IntScalar,
+    seed: Seed,
+    kernel: AnnealedImportanceSamplingMCMCKernel[State, Extra, KernelExtra],
+    make_target_log_probability_fn: Callable[[IntScalar], PotentialFn[Extra]],
+) -> tuple[
+    State,
+    tuple[
+        Float[Array, 'num_particles'],
+        AnnealedImportanceSamplingKernelExtra[KernelExtra, Extra],
+    ],
+]:
+  """SMC kernel that implements Annealed Importance Sampling.
+
+  Annealed Importance Sampling (AIS) [1] can be interpreted as a special case
+  of SMC with a particular choice of forward and reverse kernels:
+  ```none
+  r_t = k_t(x_{t + 1} | x_t) p_t(x_t) / p_t(x_{t + 1})
+  q_t = k_{t - 1}(x_t | x_{t - 1})
+  ```
+  where `k_t` is an MCMC kernel that leaves `p_t` invariant. This makes the
+  incremental weight equation particularly simple:
+  ```none
+  iw_t = p_t(x_t) / p_{t - 1}(x_t)
+  ```
+  Unfortunately, this reverse kernel is not optimal, so the annealing schedule
+  needs to be fine-grained.
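+  For example, if the annealing path linearly interpolates the log densities
+  of the endpoints `p_0` and `p_T` over `T` steps (a geometric path, as
+  constructed with `fun_mc.geometric_annealing_path` in the example below),
+  the incremental log weight reduces to:
+  ```none
+  log(iw_t) = (1 / T) * (log p_T(x_t) - log p_0(x_t))
+  ```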
+  The original formulation from [1] does not use resampling, but enabling it
+  will usually reduce the variance of the estimator.
+
+  Args:
+    state: The previous particle state, `x_{t - 1}^{1:K}`.
+    step: The previous timestep, `t - 1`.
+    seed: PRNG seed.
+    kernel: The inner MCMC kernel. It takes the current state, the timestep,
+      the target distribution and the seed, and generates an approximate
+      sample from `p_t`, where `t` is the passed-in timestep.
+    make_target_log_probability_fn: A function that, given a timestep `t`,
+      returns the target distribution `p_t`.
+
+  Returns:
+    state: The new particles, `x_t^{1:K}`.
+    extra: A 2-tuple of:
+      incremental_log_weights: The incremental log weights at timestep `t`,
+        `log(iw_t^{1:K})`.
+      kernel_extra: Extra information returned by the kernel.
+
+  #### Example
+
+  In this example we estimate the ratio of the normalizing constants of
+  `tlp_2` and `tlp_1`.
+
+  ```python
+  def tlp_1(x):
+    return -(x**2) / 2.0, ()
+
+  def tlp_2(x):
+    return -((x - 2) ** 2) / 2.0 / 16.0, ()
+
+  @jax.jit
+  def kernel(smc_state, seed):
+    smc_seed, seed = jax.random.split(seed, 2)
+
+    def inner_kernel(state, stage, tlp_fn, seed):
+      # Increase the HMC step size as the annealing approaches the wider
+      # final distribution.
+      f = jnp.array(stage, state.dtype) / num_steps
+      hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, tlp_fn)
+      hmc_state, _ = fun_mc.hamiltonian_monte_carlo_step(
+          hmc_state,
+          tlp_fn,
+          step_size=f * 4.0 + (1.0 - f) * 1.0,
+          num_integrator_steps=1,
+          seed=seed,
+      )
+      return hmc_state.state, ()
+
+    smc_state, _ = smc.sequential_monte_carlo_step(
+        smc_state,
+        kernel=functools.partial(
+            smc.annealed_importance_sampling_kernel,
+            kernel=inner_kernel,
+            make_target_log_probability_fn=functools.partial(
+                fun_mc.geometric_annealing_path,
+                num_stages=num_steps,
+                initial_target_log_prob_fn=tlp_1,
+                final_target_log_prob_fn=tlp_2,
+            ),
+        ),
+        seed=smc_seed,
+    )
+
+    return (smc_state, seed), ()
+
+  num_steps = 100
+  num_particles = 400
+  init_seed, smc_seed = jax.random.split(jax.random.PRNGKey(0))
+  init_state = jax.random.normal(init_seed, [num_particles])
+
+  (smc_state, _), _ = fun_mc.trace(
+      (
+          smc.sequential_monte_carlo_init(
+              init_state,
+              weight_dtype=init_state.dtype,
+          ),
+          smc_seed,
+      ),
+      kernel,
+      num_steps,
+  )
+
+  weights = jnp.exp(smc_state.log_weights)
+  # Should be close to 4.
+  print('estimated z2/z1', weights.mean())
+  # Should be close to 2.
+  print(
+      'estimated mean',
+      (jax.nn.softmax(smc_state.log_weights) * smc_state.state).sum(),
+  )
+  ```
+
+  #### References
+
+  [1]: Neal, Radford M. (1998) Annealed Importance Sampling.
+       https://arxiv.org/abs/physics/9803008
+  """
+  new_state, kernel_extra = kernel(
+      state, step, make_target_log_probability_fn(step), seed
+  )
+  tlp_num, num_extra = fun_mc.call_potential_fn(
+      make_target_log_probability_fn(step + 1), new_state
+  )
+  tlp_denom, denom_extra = fun_mc.call_potential_fn(
+      make_target_log_probability_fn(step), new_state
+  )
+  extra = AnnealedImportanceSamplingKernelExtra(
+      kernel_extra=kernel_extra,
+      cur_state_extra=denom_extra,
+      next_state_extra=num_extra,
+  )
+  # The incremental log weight is log p_t(x_t) - log p_{t - 1}(x_t).
+  return new_state, (
+      tlp_num - tlp_denom,
+      extra,
+  )
+
+
 def _smart_cond(
     pred: BoolScalar,
     true_fn: Callable[..., T],
diff --git a/spinoffs/fun_mc/fun_mc/smc_test.py b/spinoffs/fun_mc/fun_mc/smc_test.py
index 1a7962e3fa..3db0efe84b 100644
--- a/spinoffs/fun_mc/fun_mc/smc_test.py
+++ b/spinoffs/fun_mc/fun_mc/smc_test.py
@@ -1184,6 +1184,73 @@ def kernel(smc_state, seed):
     self.assertAllClose(gt_log_evidence, log_evidence, rtol=0.01)
     self.assertAllClose(gt_log_evidence, log_evidence, atol=0.2)

+  def test_annealed_importance_sampling(self):
+    def tlp_1(x):
+      return -0.5 * x**2, ()
+
+    def tlp_2(x):
+      return (-0.5 * (x - 2) ** 2) / 16.0, ()
+
+    @jax.jit
+    def kernel(smc_state, seed):
+      smc_seed, seed = util.split_seed(seed, 2)
+
+      def inner_kernel(state, step, tlp_fn, seed):
+        f = jnp.array(step, state.dtype) / num_steps
+        hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, tlp_fn)
+        hmc_state, _ = fun_mc.hamiltonian_monte_carlo_step(
+            hmc_state,
+            tlp_fn,
+            step_size=f * 4.0 + (1.0 - f) * 1.0,
+            num_integrator_steps=1,
+            seed=seed,
+        )
+        return hmc_state.state, ()
+
+      smc_state, _ = smc.sequential_monte_carlo_step(
+          smc_state,
+          kernel=functools.partial(
+              smc.annealed_importance_sampling_kernel,
+              kernel=inner_kernel,
+              make_target_log_probability_fn=functools.partial(
+                  fun_mc.geometric_annealing_path,
+                  num_stages=num_steps,
+                  initial_target_log_prob_fn=tlp_1,
+                  final_target_log_prob_fn=tlp_2,
+              ),
+          ),
+          seed=smc_seed,
+      )
+
+      return (smc_state, seed), ()
+
+    num_steps = 1000
+    num_particles = 1000
+    init_seed, smc_seed = util.split_seed(_test_seed(), 2)
+    init_state = util.random_normal([num_particles], self._dtype, init_seed)
+
+    (smc_state, _), _ = fun_mc.trace(
+        (
+            smc.sequential_monte_carlo_init(
+                init_state,
+                weight_dtype=self._dtype,
+            ),
+            smc_seed,
+        ),
+        kernel,
+        num_steps,
+    )
+
+    weights = jnp.exp(smc_state.log_weights)
+    # The mean weight should be close to 4 because tlp_2 has a stddev of 4
+    # while tlp_1 has a stddev of 1.
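+    # (Both potentials equal 1 at their modes, so z_i = sqrt(2 * pi) *
+    # stddev_i, and the ratio of normalizing constants is z_2 / z_1 = 4.)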
+    self.assertAllClose(4.0, jnp.mean(weights), atol=0.1)
+
+    normed_weights = jax.nn.softmax(smc_state.log_weights)
+    mean = jnp.sum(normed_weights * smc_state.state)
+    variance = jnp.sum(normed_weights * (smc_state.state - mean) ** 2)
+    # The self-normalized estimates should recover the moments of tlp_2.
+    self.assertAllClose(2.0, mean, atol=0.3)
+    self.assertAllClose(16.0, variance, rtol=0.2)
+

 @test_util.multi_backend_test(globals(), 'smc_test')
 class SMCTest32(SMCTest):
diff --git a/spinoffs/fun_mc/fun_mc/types.py b/spinoffs/fun_mc/fun_mc/types.py
index e40eeb8ba5..22aa0e564c 100644
--- a/spinoffs/fun_mc/fun_mc/types.py
+++ b/spinoffs/fun_mc/fun_mc/types.py
@@ -14,7 +14,7 @@
 # ============================================================================
 """Various types used in FunMC."""

-from typing import Callable, TypeAlias, TypeVar
+from typing import Callable, Protocol, TypeAlias, TypeVar, runtime_checkable

 import jaxtyping

 from fun_mc import backend
@@ -29,6 +29,7 @@
     'FloatScalar',
     'Int',
     'IntScalar',
+    'PotentialFn',
     'runtime_typed',
     'Seed',
 ]
@@ -42,8 +43,24 @@
 BoolScalar: TypeAlias = bool | Bool[Array, '']
 IntScalar: TypeAlias = int | Int[Array, '']
 FloatScalar: TypeAlias = float | Float[Array, '']
-
 F = TypeVar('F', bound=Callable)
+_Extra = TypeVar('_Extra')
+
+
+@runtime_checkable
+class PotentialFn(Protocol[_Extra]):
+  """Maps state to an array of floats.
+
+  If the state has a leading dimension, the same dimension is present in the
+  returned values as well.
+  """
+
+  def __call__(
+      self,
+      *args,
+      **kwargs,
+  ) -> tuple[Float[Array, '...'], _Extra]:
+    """Potential function."""


 def runtime_typed(f: F) -> F: