diff --git a/discussion/probabilistic_bundle_adjustment/ProbBundle.ipynb b/discussion/probabilistic_bundle_adjustment/ProbBundle.ipynb index fdbe3b8a7c..f2431f026a 100644 --- a/discussion/probabilistic_bundle_adjustment/ProbBundle.ipynb +++ b/discussion/probabilistic_bundle_adjustment/ProbBundle.ipynb @@ -3377,7 +3377,7 @@ " if (self.auto_resample.value or self.resample) or not jnp.all(\n", " jnp.isfinite(extra.target_log_prob)\n", " ):\n", - " (_, _), ancestor_idx = fun_mc.systematic_resample(\n", + " (_, _), ancestor_idx = fun_mc.resample(\n", " (),\n", " resample_strength * extra.target_log_prob,\n", " jax.random.key(self.step),\n", diff --git a/spinoffs/fun_mc/fun_mc/fun_mc_lib.py b/spinoffs/fun_mc/fun_mc/fun_mc_lib.py index b41fb784fc..9d74e33e16 100644 --- a/spinoffs/fun_mc/fun_mc/fun_mc_lib.py +++ b/spinoffs/fun_mc/fun_mc/fun_mc_lib.py @@ -128,7 +128,6 @@ 'SimpleDualAveragesState', 'splitting_integrator_step', 'State', - 'systematic_resample', 'trace', 'transform_log_prob_fn', 'TransitionOperator', @@ -3466,60 +3465,6 @@ def clip_part(v): ) -@util.named_call -def systematic_resample( - particles: State, - log_weights: FloatArray, - seed: Any, - do_resample: Optional[BooleanArray] = None, -) -> tuple[tuple[State, FloatArray], IntArray]: - """Systematically resamples particles in proportion to their weights. - - This uses the algorithm from [1]. - - Args: - particles: The particles. - log_weights: Un-normalized weights. - seed: PRNG seed. - do_resample: Whether to perform the resample. If None, resampling is - performed unconditionally. - - Returns: - particles_and_weights: tuple of resampled particles and weights. - ancestor_idx: Indices from which the returned particles were sampled from. - - #### References - - [1] Maskell, S., Alun-Jones, B., & Macleod, M. (2006). A Single Instruction - Multiple Data Particle Filter. 2006 IEEE Nonlinear Statistical Signal - Processing Workshop. https://doi.org/10.1109/NSSPW.2006.4378818 - """ - log_weights = jnp.asarray(log_weights) - log_weights = jnp.where( - jnp.isnan(log_weights), - jnp.array(-float('inf'), log_weights.dtype), - log_weights, - ) - probs = jax.nn.softmax(log_weights) - num_particles = probs.shape[0] - - shift = util.random_uniform([], log_weights.dtype, seed) - pie = jnp.cumsum(probs) * num_particles + shift - repeats = jnp.array(util.diff(jnp.floor(pie), prepend=0), jnp.int32) - parent_idxs = util.repeat( - jnp.arange(num_particles), repeats, total_repeat_length=num_particles - ) - if do_resample is not None: - parent_idxs = jnp.where(do_resample, parent_idxs, jnp.arange(num_particles)) - new_particles = util.map_tree(lambda x: x[parent_idxs], particles) - new_log_weights = jnp.full( - log_weights.shape, tfp.math.reduce_logmeanexp(log_weights) - ) - if do_resample is not None: - new_log_weights = jnp.where(do_resample, new_log_weights, log_weights) - return (new_particles, new_log_weights), parent_idxs - - class GeometricAnnealingPathExtra(NamedTuple): """Extra outputs of `geometric_annealing_path`. 
diff --git a/spinoffs/fun_mc/fun_mc/fun_mc_test.py b/spinoffs/fun_mc/fun_mc/fun_mc_test.py index a89b6c397c..e9ffd9d0cf 100644 --- a/spinoffs/fun_mc/fun_mc/fun_mc_test.py +++ b/spinoffs/fun_mc/fun_mc/fun_mc_test.py @@ -2039,62 +2039,6 @@ def eval_fn(x): self.assertAllCloseNested(value, fn(x)) self.assertAllCloseNested(expected_grad, grad) - def testSystematicResample(self): - probs = self._constant([0.0, 0.5, 0.2, 0.3, 0.0]) - log_weights = jnp.log(probs) - particles = jnp.arange(probs.shape[0]) - - @jax.jit - def body(seed): - seed, resample_seed = util.split_seed(seed, 2) - (new_particles, new_log_weights), _ = fun_mc.systematic_resample( - particles, log_weights, resample_seed - ) - return seed, (new_particles, new_log_weights) - - _, (new_particles, new_log_weights) = fun_mc.trace( - self._make_seed(_test_seed()), body, 1000, trace_mask=(True, False) - ) - - new_particles_probs = jnp.mean( - jnp.array(new_particles[..., jnp.newaxis] == particles, jnp.float32), - (0, 1), - ) - - self.assertAllClose(new_particles_probs, probs, atol=0.05) - self.assertEqual(new_particles_probs[0], 0.0) - self.assertEqual(new_particles_probs[-1], 0.0) - self.assertAllClose( - new_log_weights, - jnp.full(probs.shape, tfp.math.reduce_logmeanexp(log_weights)), - ) - - def testSystematicResampleAncestors(self): - log_weights = self._constant([-float('inf'), 0.0]) - particles = jnp.arange(log_weights.shape[0]) - seed = self._make_seed(_test_seed()) - - (new_particles, new_log_weights), ancestors = fun_mc.systematic_resample( - particles, log_weights, seed=seed - ) - self.assertAllEqual(new_particles, jnp.ones_like(particles)) - self.assertAllEqual(new_log_weights, jnp.log(self._constant([0.5, 0.5]))) - self.assertAllEqual(ancestors, jnp.ones_like(particles)) - - (new_particles, new_log_weights), ancestors = fun_mc.systematic_resample( - particles, log_weights, do_resample=True, seed=seed - ) - self.assertAllEqual(new_particles, jnp.ones_like(particles)) - self.assertAllEqual(new_log_weights, jnp.log(self._constant([0.5, 0.5]))) - self.assertAllEqual(ancestors, jnp.ones_like(particles)) - - (new_particles, new_log_weights), ancestors = fun_mc.systematic_resample( - particles, log_weights, do_resample=False, seed=seed - ) - self.assertAllEqual(new_particles, particles) - self.assertAllEqual(new_log_weights, log_weights) - self.assertAllEqual(ancestors, particles) - @test_util.multi_backend_test(globals(), 'fun_mc_test') class FunMCTest32(FunMCTest): diff --git a/spinoffs/fun_mc/fun_mc/smc.py b/spinoffs/fun_mc/fun_mc/smc.py index 4b08e674b6..8c809cad02 100644 --- a/spinoffs/fun_mc/fun_mc/smc.py +++ b/spinoffs/fun_mc/fun_mc/smc.py @@ -48,6 +48,7 @@ 'conditional_systematic_resampling', 'effective_sample_size_predicate', 'ParticleGatherFn', + 'resample', 'ResamplingPredicate', 'SampleAncestorsFn', 'sequential_monte_carlo_init', @@ -78,6 +79,8 @@ def systematic_resampling( ) -> Int[Array, 'num_particles']: """Generate parent indices via systematic resampling. + This uses the algorithm from [1]. + Args: log_weights: Unnormalized log-scale weights. seed: PRNG seed. @@ -87,13 +90,14 @@ def systematic_resampling( Returns: parent_idxs: parent indices such that the marginal probability that a randomly chosen element will be `i` is equal to `softmax(log_weights)[i]`. + + #### References + + [1] Maskell, S., Alun-Jones, B., & Macleod, M. (2006). A Single Instruction + Multiple Data Particle Filter. 2006 IEEE Nonlinear Statistical Signal + Processing Workshop. 
https://doi.org/10.1109/NSSPW.2006.4378818 """ shift_seed, permute_seed = util.split_seed(seed, 2) - log_weights = jnp.where( - jnp.isnan(log_weights), - jnp.array(-float('inf'), log_weights.dtype), - log_weights, - ) probs = jax.nn.softmax(log_weights) # A common situation is all -inf log_weights that creats a NaN vector. probs = jnp.where( @@ -146,11 +150,6 @@ def conditional_systematic_resampling( https://www.jstor.org/stable/43590414 """ mixture_seed, shift_seed, permute_seed = util.split_seed(seed, 3) - log_weights = jnp.where( - jnp.isnan(log_weights), - jnp.array(-float('inf'), log_weights.dtype), - log_weights, - ) probs = jax.nn.softmax(log_weights) num_particles = log_weights.shape[0] @@ -377,7 +376,7 @@ def __call__( @types.runtime_typed -def _defalt_pytree_gather( +def _default_pytree_gather( state: State, indices: Int[Array, 'num_particles'], ) -> State: @@ -395,6 +394,75 @@ def _defalt_pytree_gather( return util.map_tree(lambda x: x[indices], state) +@types.runtime_typed +def resample( + state: State, + log_weights: Float[Array, 'num_particles'], + seed: Seed, + do_resample: BoolScalar = True, + sample_ancestors_fn: SampleAncestorsFn = systematic_resampling, + state_gather_fn: ParticleGatherFn[State] = _default_pytree_gather, +) -> tuple[ + tuple[State, Float[Array, 'num_particles']], Int[Array, 'num_particles'] +]: + """Possibly resamples state according to the log_weights. + + The state should represent the same number of particles as implied by the + length of `log_weights`. If resampling occurs, the new log weights are + log-mean-exp of the incoming log weights. Otherwise, they are unchanged. By + default, this function performs systematic resampling. + + Args: + state: The particles. + log_weights: Un-normalized log weights. NaN log weights are treated as -inf. + seed: Random seed. + do_resample: Whether to resample. + sample_ancestors_fn: Ancestor index sampling function. + state_gather_fn: State gather function. + + Returns: + state_and_log_weights: tuple of the resampled state and log weights. + ancestor_idx: Indices that indicate which elements of the original state the + returned state particles were sampled from. 
+ """ + + def do_resample_fn( + state, + log_weights, + seed, + ): + log_weights = jnp.where( + jnp.isnan(log_weights), + jnp.array(-float('inf'), log_weights.dtype), + log_weights, + ) + ancestor_idxs = sample_ancestors_fn(log_weights, seed) + new_state = state_gather_fn(state, ancestor_idxs) + num_particles = log_weights.shape[0] + new_log_weights = jnp.full( + (num_particles,), tfp.math.reduce_logmeanexp(log_weights) + ) + return (new_state, new_log_weights), ancestor_idxs + + def dont_resample_fn( + state, + log_weights, + seed, + ): + del seed + num_particles = log_weights.shape[0] + return (state, log_weights), jnp.arange(num_particles) + + return _smart_cond( + do_resample, + do_resample_fn, + dont_resample_fn, + state, + log_weights, + seed, + ) + + @types.runtime_typed def sequential_monte_carlo_init( state: State, @@ -430,7 +498,7 @@ def sequential_monte_carlo_step( seed: Seed, resampling_pred: ResamplingPredicate = effective_sample_size_predicate, sample_ancestors_fn: SampleAncestorsFn = systematic_resampling, - state_gather_fn: ParticleGatherFn[State] = _defalt_pytree_gather, + state_gather_fn: ParticleGatherFn[State] = _default_pytree_gather, ) -> tuple[ SequentialMonteCarloState[State], SequentialMonteCarloExtra[State, Extra] ]: @@ -461,43 +529,21 @@ def sequential_monte_carlo_step( """ resample_seed, kernel_seed = util.split_seed(seed, 2) - def do_resample( - state, - log_weights, - seed, - ): - ancestor_idxs = sample_ancestors_fn(log_weights, seed) - new_state = state_gather_fn(state, ancestor_idxs) - num_particles = log_weights.shape[0] - new_log_weights = jnp.full( - (num_particles,), tfp.math.reduce_logmeanexp(log_weights) - ) - return (new_state, ancestor_idxs, new_log_weights) - - def dont_resample( - state, - log_weights, - seed, - ): - del seed - num_particles = log_weights.shape[0] - return state, jnp.arange(num_particles), log_weights - # NOTE: We don't explicitly disable resampling at the first step. However, if # we initialize the log weights to zeros, either of # 1. resampling according to the effective sample size criterion and # 2. using systematic resampling effectively disables resampling at the first # step. # First-step resampling can always be forced via the `resampling_pred`. 
- should_resample = resampling_pred(smc_state) - state_after_resampling, ancestor_idxs, log_weights_after_resampling = ( - _smart_cond( - should_resample, - do_resample, - dont_resample, - smc_state.state, - smc_state.log_weights, - resample_seed, + do_resample = resampling_pred(smc_state) + (state_after_resampling, log_weights_after_resampling), ancestor_idxs = ( + resample( + state=smc_state.state, + log_weights=smc_state.log_weights, + do_resample=do_resample, + seed=resample_seed, + sample_ancestors_fn=sample_ancestors_fn, + state_gather_fn=state_gather_fn, ) ) @@ -516,7 +562,7 @@ def dont_resample( smc_extra = SequentialMonteCarloExtra( incremental_log_weights=incremental_log_weights, kernel_extra=kernel_extra, - resampled=should_resample, + resampled=do_resample, ancestor_idxs=ancestor_idxs, state_after_resampling=state_after_resampling, log_weights_after_resampling=log_weights_after_resampling, @@ -711,6 +757,7 @@ def inner_kernel(state, stage, tlp_fn, seed): ) +@types.runtime_typed def _smart_cond( pred: BoolScalar, true_fn: Callable[..., T], diff --git a/spinoffs/fun_mc/fun_mc/smc_test.py b/spinoffs/fun_mc/fun_mc/smc_test.py index 3db0efe84b..0b5084ddb2 100644 --- a/spinoffs/fun_mc/fun_mc/smc_test.py +++ b/spinoffs/fun_mc/fun_mc/smc_test.py @@ -272,6 +272,33 @@ def kernel(seed): ) self.assertAllClose(rejection_freqs, conditional_freqs, atol=0.05) + def test_resample(self): + state = jnp.array([3, 2, 1, 0]) + log_weights = jnp.array([-jnp.inf, float('NaN'), 1.0, 1.0], self._dtype) + seed = _test_seed() + + (new_state, new_log_weights), ancestor_idxs = smc.resample( + state=state, log_weights=log_weights, seed=seed + ) + + self.assertAllTrue(new_state != 3) + self.assertAllTrue(new_state != 2) + self.assertAllTrue(~jnp.isnan(new_log_weights)) + self.assertAllEqual(3 - new_state, ancestor_idxs) + + def test_resample_but_dont(self): + state = jnp.array([3, 2, 1, 0]) + log_weights = jnp.array([-jnp.inf, float('NaN'), 1.0, 1.0], self._dtype) + seed = _test_seed() + + (new_state, new_log_weights), ancestor_idxs = smc.resample( + state=state, log_weights=log_weights, do_resample=False, seed=seed + ) + + self.assertAllEqual(new_state, state) + self.assertAllEqual(new_log_weights, log_weights) + self.assertAllEqual(ancestor_idxs, jnp.arange(state.shape[0])) + def test_smc_runs_and_shapes_correct(self): num_particles = 3 num_timesteps = 20
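
A minimal usage sketch of the new `resample` entry point, for context. This is a sketch, not part of the change: it assumes the JAX backend with the README-style `from fun_mc import using_jax as fun_mc` import and that `resample` is re-exported at the top level of the package, as the notebook update above relies on. The particle values, weights, and seed are illustrative only.

import jax
import jax.numpy as jnp

# Assumed import path (JAX backend); the notebook update above relies on
# `resample` being re-exported at the top level of the package.
from fun_mc import using_jax as fun_mc

num_particles = 4
# A pytree state: each leaf indexes particles along its leading dimension.
state = {'x': jnp.arange(num_particles, dtype=jnp.float32)}
# Un-normalized log weights; NaN log weights are treated as -inf inside
# `resample`, so only the last two particles can be selected here.
log_weights = jnp.array([-jnp.inf, 0.0, 1.0, 1.0])

(new_state, new_log_weights), ancestor_idxs = fun_mc.resample(
    state=state,
    log_weights=log_weights,
    seed=jax.random.key(0),
)
# `new_log_weights` is flat (log-mean-exp of the inputs) and
# `ancestor_idxs[i]` is the input particle that produced output particle `i`.

Passing `do_resample=False` returns the state and log weights unchanged, with identity ancestor indices; this is how `sequential_monte_carlo_step` now skips resampling when `resampling_pred` evaluates to False.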