diff --git a/spinoffs/fun_mc/fun_mc/fun_mc_lib.py b/spinoffs/fun_mc/fun_mc/fun_mc_lib.py index 8d406ea43c..b41fb784fc 100644 --- a/spinoffs/fun_mc/fun_mc/fun_mc_lib.py +++ b/spinoffs/fun_mc/fun_mc/fun_mc_lib.py @@ -53,11 +53,6 @@ 'adam_step', 'AdamExtra', 'AdamState', - 'annealed_importance_sampling_init', - 'annealed_importance_sampling_resample', - 'annealed_importance_sampling_step', - 'AnnealedImportanceSamplingExtra', - 'AnnealedImportanceSamplingState', 'blanes_3_stage_step', 'blanes_4_stage_step', 'call_fn', @@ -3471,203 +3466,6 @@ def clip_part(v): ) -TransitionExtra = ArrayNest -LogWeightExtra = ArrayNest -ResampleExtra = ArrayNest -Stage = IntArray - - -class AnnealedImportanceSamplingState(NamedTuple): - """State of the annealed importance sampler. - - Attributes: - state: The particles. - log_weight: Log weight of the particles. - stage: Current stage. - """ - - state: Any - log_weight: FloatArray - stage: Stage - - def ess(self) -> FloatArray: - """Estimates the effective sample size.""" - norm_weights = jax.nn.softmax(self.log_weight) - return 1.0 / jnp.sum(norm_weights**2) - - -class AnnealedImportanceSamplingExtra(NamedTuple): - """Extra outputs from the annealed importance sampler. - - Attributes: - stage_log_weight: Incremental log weight for this stage. - transition_extra: Extra outputs from the transition operator. - log_weight_extra: Extra outputs from log-weight computation. - """ - - stage_log_weight: FloatArray - transition_extra: TransitionExtra - log_weight_extra: LogWeightExtra - - -@util.named_call -def annealed_importance_sampling_init( - state: State, initial_log_weight: FloatArray, initial_stage: int | Stage = 0 -) -> AnnealedImportanceSamplingState: - """Initializes the annealed importance sampler. - - Args: - state: Initial state. - initial_log_weight: Initial log weight. - initial_stage: Initial stage. - - Returns: - `AnnealedImportanceSamplingState`. - """ - state = util.map_tree(jnp.asarray, state) - return AnnealedImportanceSamplingState( - state=state, - log_weight=jnp.asarray(initial_log_weight), - stage=jnp.asarray(initial_stage, jnp.int32), - ) - - -@util.named_call -def annealed_importance_sampling_step( - ais_state: AnnealedImportanceSamplingState, - transition_operator: Callable[ - [State, Stage, Callable[[State], tuple[FloatArray, StateExtra]]], - tuple[State, TransitionExtra], - ], - make_tlp_fn: Callable[[Stage], PotentialFn], - log_weight_fn: Optional[ - Callable[[State, State, Stage, TransitionExtra], tuple[FloatArray, Any]] - ] = None, -) -> tuple[AnnealedImportanceSamplingState, AnnealedImportanceSamplingExtra]: - """Takes a step of the annealed importance sampler (AIS). - - AIS is a simple Sequential Monte Carlo (SMC) sampler that generates weighted - samples from a target distribution by annealing to it along a schedule of - simpler distributions. If self-normalized, the weights are biased. The mean of - the unnormalized weights, however, is an unbiased estimator of the ratio of - the normalizing constants of the target and the initial distributions. - - In addition to classic AIS, this function also allows extending it to a more - general SMC sampler by overriding `log_weight_fn` (thus allowing custom - backward kernel) and also resampling (via the `resample` method on the state - object). - - #### Example - - In this example we estimate the normalizing constant ratio between `tlp_1` - and `tlp_2`. - - ```python - def tlp_1(x): - return -x**2 / 2., () - - def tlp_2(x): - return -(x - 2)**2 / 2 / 16., () - - def kernel(ais_state, seed): - - hmc_seed, resample_seed, seed = jax.random.split(seed, 3) - - ais_state, _ = fun_mc.annealed_importance_sampling_resample( - ais_state, - seed=resample_seed) - - def transition_operator(state, stage, tlp_fn): - f = stage / num_stages - hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, tlp_fn) - hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo_step( - hmc_state, - tlp_fn, - step_size=f * 4. + (1. - f) * 1., - num_integrator_steps=1, - seed=hmc_seed) - return hmc_state.state, () - - ais_state, ais_extra = fun_mc.annealed_importance_sampling_step( - ais_state, transition_operator, - functools.partial( - fun_mc.geometric_annealing_path, - num_stages=num_stages, - initial_target_log_prob_fn=tlp_1, - final_target_log_prob_fn=tlp_2, - )) - - return (ais_state, seed), () - - - num_stages = 200 - num_particles = 200 - init_seed, seed = jax.random.split(jax.random.PRNGKey(0)) - init_state = jax.random.normal(init_seed, [num_particles]) - - (ais_state, _), _ = fun_mc.trace( - (fun_mc.annealed_importance_sampling_init( - init_state, jnp.zeros([num_particles])), seed), - kernel, - num_stages, - ) - - weights = jnp.exp(ais_state.log_weight) - # Should be close to 4. - print(estimated z2/z1, weights.mean()) - # Should be close to 2. - print(estimated mean, (jax.nn.softmax(ais_state.log_weight) - * ais_state.state).sum()) - ``` - - Args: - ais_state: `AnnealedImportanceSamplingState` - transition_operator: The forward MCMC kernel. It has signature: `(state, - stage, tlp_fn) -> (state, extra)`. - make_tlp_fn: A function which, given the stage index, returns an annealed - density. - log_weight_fn: Optional function to compute the incremental log weight of a - stage. The default uses a naive implementation of the usual AIS - incremental weight computation. - - Returns: - ais_state: `AnnealedImportanceSamplingState` - ais_extra: `AnnealedImportanceSamplingExtra` - """ - - if log_weight_fn is None: - - def _default_log_weight_fn(old_state, new_state, stage, transition_extra): - del old_state, transition_extra - tlp_denom, denom_extra = call_potential_fn(make_tlp_fn(stage), new_state) - tlp_num, num_extra = call_potential_fn(make_tlp_fn(stage + 1), new_state) - stage_log_weight = tlp_num - tlp_denom - log_weight_extra = (num_extra, denom_extra) - return stage_log_weight, log_weight_extra - - log_weight_fn = _default_log_weight_fn - - new_state, transition_extra = transition_operator( - ais_state.state, ais_state.stage, make_tlp_fn(ais_state.stage) - ) - - stage_log_weight, log_weight_extra = log_weight_fn( - ais_state.state, new_state, ais_state.stage, transition_extra - ) - - ais_state = ais_state._replace( - state=new_state, - log_weight=ais_state.log_weight + stage_log_weight, - stage=ais_state.stage + 1, - ) - extra = AnnealedImportanceSamplingExtra( - stage_log_weight=stage_log_weight, - transition_extra=transition_extra, - log_weight_extra=log_weight_extra, - ) - return ais_state, extra - - @util.named_call def systematic_resample( particles: State, @@ -3722,29 +3520,6 @@ def systematic_resample( return (new_particles, new_log_weights), parent_idxs -@util.named_call -def annealed_importance_sampling_resample( - ais_state: AnnealedImportanceSamplingState, - resample_fn: Callable[ - [State, FloatArray, Any, BooleanArray], - tuple[tuple[State, jnp.ndarray], ResampleExtra], - ] = systematic_resample, - min_ess_threshold: float | FloatArray = 0.5, - seed: Any = None, -) -> tuple[AnnealedImportanceSamplingState, ResampleExtra]: - """Resamples the particles in AnnealedImportanceSamplingState.""" - - log_weight = jnp.asarray(ais_state.log_weight) - do_resample = ( - ais_state.ess() - < jnp.array(log_weight.shape[0], log_weight.dtype) * min_ess_threshold - ) - (state, log_weight), extra = resample_fn( - ais_state.state, ais_state.log_weight, seed, do_resample - ) - return ais_state._replace(state=state, log_weight=log_weight), extra - - class GeometricAnnealingPathExtra(NamedTuple): """Extra outputs of `geometric_annealing_path`. @@ -3760,8 +3535,8 @@ class GeometricAnnealingPathExtra(NamedTuple): def geometric_annealing_path( - stage: Stage, - num_stages: Stage, + stage: IntArray, + num_stages: IntArray, initial_target_log_prob_fn: PotentialFn, final_target_log_prob_fn: PotentialFn, fraction_fn: Optional[Callable[[FloatArray], jnp.ndarray]] = None, diff --git a/spinoffs/fun_mc/fun_mc/fun_mc_test.py b/spinoffs/fun_mc/fun_mc/fun_mc_test.py index 41d2dbe32a..a89b6c397c 100644 --- a/spinoffs/fun_mc/fun_mc/fun_mc_test.py +++ b/spinoffs/fun_mc/fun_mc/fun_mc_test.py @@ -2095,70 +2095,6 @@ def testSystematicResampleAncestors(self): self.assertAllEqual(new_log_weights, log_weights) self.assertAllEqual(ancestors, particles) - def testAIS(self): - def tlp_1(x): - return -(x**2) / 2.0, () - - def tlp_2(x): - return -((x - 2) ** 2) / 2 / 16.0, () - - @jax.jit - def kernel(ais_state, seed): - hmc_seed, resample_seed, seed = util.split_seed(seed, 3) - - ais_state, _ = fun_mc.annealed_importance_sampling_resample( - ais_state, seed=resample_seed - ) - - def transition_operator(state, stage, tlp_fn): - f = jnp.array(stage, state.dtype) / num_stages - hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, tlp_fn) - hmc_state, _ = fun_mc.hamiltonian_monte_carlo_step( - hmc_state, - tlp_fn, - step_size=f * 4.0 + (1.0 - f) * 1.0, - num_integrator_steps=1, - seed=hmc_seed, - ) - return hmc_state.state, () - - ais_state, _ = fun_mc.annealed_importance_sampling_step( - ais_state, - transition_operator, - functools.partial( - fun_mc.geometric_annealing_path, - num_stages=num_stages, - initial_target_log_prob_fn=tlp_1, - final_target_log_prob_fn=tlp_2, - ), - ) - - return (ais_state, seed), () - - num_stages = 100 - num_particles = 400 - init_seed, seed = util.split_seed(self._make_seed(_test_seed()), 2) - init_state = util.random_normal([num_particles], self._dtype, init_seed) - - (ais_state, _), _ = fun_mc.trace( - ( - fun_mc.annealed_importance_sampling_init( - init_state, jnp.zeros([num_particles], self._dtype) - ), - seed, - ), - kernel, - num_stages, - ) - - weights = jnp.exp(ais_state.log_weight) - self.assertAllClose(4.0, jnp.mean(weights), atol=0.7) - self.assertAllClose( - 2.0, - jnp.sum(jax.nn.softmax(ais_state.log_weight) * ais_state.state), - atol=0.8, - ) - @test_util.multi_backend_test(globals(), 'fun_mc_test') class FunMCTest32(FunMCTest):