FunMC: Remove the old AIS implementation.
Use the SMC-based implementation instead. One feature lost in this process: singleton particles must now have a leading singleton dimension, even if no resampling is done.

PiperOrigin-RevId: 721110130
SiegeLordEx authored and tensorflower-gardener committed Jan 29, 2025
1 parent 54613b9 commit f3147cd
Showing 2 changed files with 2 additions and 291 deletions.
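
As the commit message notes, the SMC-based replacement requires particles to carry a leading particle dimension even when no resampling is done. A minimal sketch of the shape change (illustrative values only, not code from this commit):

```python
import jax.numpy as jnp

# A single chain state that previously could be a bare vector.
state = jnp.zeros([3])

# The SMC-based AIS expects a leading particle axis, so a lone
# particle becomes shape [1, 3] rather than [3].
state = state[jnp.newaxis]
assert state.shape == (1, 3)
```
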
229 changes: 2 additions & 227 deletions spinoffs/fun_mc/fun_mc/fun_mc_lib.py
@@ -53,11 +53,6 @@
'adam_step',
'AdamExtra',
'AdamState',
'annealed_importance_sampling_init',
'annealed_importance_sampling_resample',
'annealed_importance_sampling_step',
'AnnealedImportanceSamplingExtra',
'AnnealedImportanceSamplingState',
'blanes_3_stage_step',
'blanes_4_stage_step',
'call_fn',
Expand Down Expand Up @@ -3471,203 +3466,6 @@ def clip_part(v):
)


TransitionExtra = ArrayNest
LogWeightExtra = ArrayNest
ResampleExtra = ArrayNest
Stage = IntArray


class AnnealedImportanceSamplingState(NamedTuple):
"""State of the annealed importance sampler.
Attributes:
state: The particles.
log_weight: Log weight of the particles.
stage: Current stage.
"""

state: Any
log_weight: FloatArray
stage: Stage

def ess(self) -> FloatArray:
"""Estimates the effective sample size."""
norm_weights = jax.nn.softmax(self.log_weight)
return 1.0 / jnp.sum(norm_weights**2)
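
For intuition, `ess` returns the particle count under uniform weights and approaches 1 when a single particle dominates; a quick standalone check in plain JAX (independent of this class):

```python
import jax
import jax.numpy as jnp

def ess(log_weight):
  norm_weights = jax.nn.softmax(log_weight)
  return 1.0 / jnp.sum(norm_weights**2)

print(ess(jnp.zeros([100])))                  # uniform weights -> 100.0
print(ess(jnp.array([0.0, -100.0, -100.0])))  # one dominant particle -> ~1.0
```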


class AnnealedImportanceSamplingExtra(NamedTuple):
"""Extra outputs from the annealed importance sampler.
Attributes:
stage_log_weight: Incremental log weight for this stage.
transition_extra: Extra outputs from the transition operator.
log_weight_extra: Extra outputs from log-weight computation.
"""

stage_log_weight: FloatArray
transition_extra: TransitionExtra
log_weight_extra: LogWeightExtra


@util.named_call
def annealed_importance_sampling_init(
state: State, initial_log_weight: FloatArray, initial_stage: int | Stage = 0
) -> AnnealedImportanceSamplingState:
"""Initializes the annealed importance sampler.
Args:
state: Initial state.
initial_log_weight: Initial log weight.
initial_stage: Initial stage.
Returns:
`AnnealedImportanceSamplingState`.
"""
state = util.map_tree(jnp.asarray, state)
return AnnealedImportanceSamplingState(
state=state,
log_weight=jnp.asarray(initial_log_weight),
stage=jnp.asarray(initial_stage, jnp.int32),
)


@util.named_call
def annealed_importance_sampling_step(
ais_state: AnnealedImportanceSamplingState,
transition_operator: Callable[
[State, Stage, Callable[[State], tuple[FloatArray, StateExtra]]],
tuple[State, TransitionExtra],
],
make_tlp_fn: Callable[[Stage], PotentialFn],
log_weight_fn: Optional[
Callable[[State, State, Stage, TransitionExtra], tuple[FloatArray, Any]]
] = None,
) -> tuple[AnnealedImportanceSamplingState, AnnealedImportanceSamplingExtra]:
"""Takes a step of the annealed importance sampler (AIS).
AIS is a simple Sequential Monte Carlo (SMC) sampler that generates weighted
samples from a target distribution by annealing to it along a schedule of
simpler distributions. If self-normalized, the weights are biased. The mean of
the unnormalized weights, however, is an unbiased estimator of the ratio of
the normalizing constants of the target and the initial distributions.
In addition to classic AIS, this function also allows extending it to a more
general SMC sampler by overriding `log_weight_fn` (thus allowing custom
backward kernel) and also resampling (via the `resample` method on the state
object).
#### Example
In this example we estimate the normalizing constant ratio between `tlp_1`
and `tlp_2`.
```python
def tlp_1(x):
return -x**2 / 2., ()
def tlp_2(x):
return -(x - 2)**2 / 2 / 16., ()
def kernel(ais_state, seed):
hmc_seed, resample_seed, seed = jax.random.split(seed, 3)
ais_state, _ = fun_mc.annealed_importance_sampling_resample(
ais_state,
seed=resample_seed)
def transition_operator(state, stage, tlp_fn):
f = stage / num_stages
hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, tlp_fn)
hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo_step(
hmc_state,
tlp_fn,
step_size=f * 4. + (1. - f) * 1.,
num_integrator_steps=1,
seed=hmc_seed)
return hmc_state.state, ()
ais_state, ais_extra = fun_mc.annealed_importance_sampling_step(
ais_state, transition_operator,
functools.partial(
fun_mc.geometric_annealing_path,
num_stages=num_stages,
initial_target_log_prob_fn=tlp_1,
final_target_log_prob_fn=tlp_2,
))
return (ais_state, seed), ()
num_stages = 200
num_particles = 200
init_seed, seed = jax.random.split(jax.random.PRNGKey(0))
init_state = jax.random.normal(init_seed, [num_particles])
(ais_state, _), _ = fun_mc.trace(
(fun_mc.annealed_importance_sampling_init(
init_state, jnp.zeros([num_particles])), seed),
kernel,
num_stages,
)
weights = jnp.exp(ais_state.log_weight)
# Should be close to 4.
print(estimated z2/z1, weights.mean())
# Should be close to 2.
print(estimated mean, (jax.nn.softmax(ais_state.log_weight)
* ais_state.state).sum())
```
Args:
ais_state: `AnnealedImportanceSamplingState`
transition_operator: The forward MCMC kernel. It has signature: `(state,
stage, tlp_fn) -> (state, extra)`.
make_tlp_fn: A function which, given the stage index, returns an annealed
density.
log_weight_fn: Optional function to compute the incremental log weight of a
stage. The default uses a naive implementation of the usual AIS
incremental weight computation.
Returns:
ais_state: `AnnealedImportanceSamplingState`
ais_extra: `AnnealedImportanceSamplingExtra`
"""

if log_weight_fn is None:

def _default_log_weight_fn(old_state, new_state, stage, transition_extra):
del old_state, transition_extra
tlp_denom, denom_extra = call_potential_fn(make_tlp_fn(stage), new_state)
tlp_num, num_extra = call_potential_fn(make_tlp_fn(stage + 1), new_state)
stage_log_weight = tlp_num - tlp_denom
log_weight_extra = (num_extra, denom_extra)
return stage_log_weight, log_weight_extra

log_weight_fn = _default_log_weight_fn

new_state, transition_extra = transition_operator(
ais_state.state, ais_state.stage, make_tlp_fn(ais_state.stage)
)

stage_log_weight, log_weight_extra = log_weight_fn(
ais_state.state, new_state, ais_state.stage, transition_extra
)

ais_state = ais_state._replace(
state=new_state,
log_weight=ais_state.log_weight + stage_log_weight,
stage=ais_state.stage + 1,
)
extra = AnnealedImportanceSamplingExtra(
stage_log_weight=stage_log_weight,
transition_extra=transition_extra,
log_weight_extra=log_weight_extra,
)
return ais_state, extra
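
For reference, the unbiasedness claim in the docstring corresponds to the standard AIS identity, sketched here in the usual notation: $\tilde\pi_k$ are the unnormalized densities produced by `make_tlp_fn`, $Z_k$ their normalizing constants, and $x_k$ the state after the $k$-th transition (this is a summary of the textbook result, not code from this commit):

$$
w \;=\; \prod_{k=0}^{K-1} \frac{\tilde\pi_{k+1}(x_k)}{\tilde\pi_k(x_k)},
\qquad
\mathbb{E}[w] \;=\; \frac{Z_K}{Z_0}.
$$

In the docstring example, $Z_K / Z_0 = \sqrt{2\pi \cdot 16} / \sqrt{2\pi} = 4$, which is why `weights.mean()` should be close to 4.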


@util.named_call
def systematic_resample(
particles: State,
Expand Down Expand Up @@ -3722,29 +3520,6 @@ def systematic_resample(
return (new_particles, new_log_weights), parent_idxs
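
The body of `systematic_resample` is elided above; as a minimal sketch of the standard systematic resampling scheme it is named after (an assumption based on the name, not the elided body):

```python
import jax
import jax.numpy as jnp

def systematic_resample_sketch(log_weights, seed):
  """Standard systematic resampling: one shared uniform offset."""
  n = log_weights.shape[0]
  probs = jax.nn.softmax(log_weights)           # normalized weights
  cdf = jnp.cumsum(probs)
  u = jax.random.uniform(seed, [])              # single offset in [0, 1)
  positions = (u + jnp.arange(n)) / n           # evenly spaced points
  return jnp.searchsorted(cdf, positions)       # ancestor index per slot

ancestors = systematic_resample_sketch(jnp.zeros([4]), jax.random.PRNGKey(0))
# Under uniform weights every particle is its own ancestor exactly once.
```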


@util.named_call
def annealed_importance_sampling_resample(
ais_state: AnnealedImportanceSamplingState,
resample_fn: Callable[
[State, FloatArray, Any, BooleanArray],
tuple[tuple[State, jnp.ndarray], ResampleExtra],
] = systematic_resample,
min_ess_threshold: float | FloatArray = 0.5,
seed: Any = None,
) -> tuple[AnnealedImportanceSamplingState, ResampleExtra]:
"""Resamples the particles in AnnealedImportanceSamplingState."""

log_weight = jnp.asarray(ais_state.log_weight)
do_resample = (
ais_state.ess()
< jnp.array(log_weight.shape[0], log_weight.dtype) * min_ess_threshold
)
(state, log_weight), extra = resample_fn(
ais_state.state, ais_state.log_weight, seed, do_resample
)
return ais_state._replace(state=state, log_weight=log_weight), extra
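
For concreteness, with the default `min_ess_threshold=0.5`, resampling triggers only when the estimated ESS falls below half the particle count (the numbers below are hypothetical):

```python
num_particles = 200
min_ess_threshold = 0.5
ess_estimate = 80.0  # hypothetical current ESS

do_resample = ess_estimate < num_particles * min_ess_threshold  # 80 < 100
print(do_resample)  # True -> resample this stage
```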


class GeometricAnnealingPathExtra(NamedTuple):
"""Extra outputs of `geometric_annealing_path`.
@@ -3760,8 +3535,8 @@ class GeometricAnnealingPathExtra(NamedTuple):


def geometric_annealing_path(
-    stage: Stage,
-    num_stages: Stage,
+    stage: IntArray,
+    num_stages: IntArray,
initial_target_log_prob_fn: PotentialFn,
final_target_log_prob_fn: PotentialFn,
fraction_fn: Optional[Callable[[FloatArray], jnp.ndarray]] = None,
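The body of `geometric_annealing_path` is elided above. As a hedged sketch, the standard geometric path it is named after interpolates the two log densities linearly in `stage / num_stages` (the exact handling of `fraction_fn` and the extras is an assumption based on the name and the `f = stage / num_stages` usage in the example, not the elided body):

```python
import jax.numpy as jnp

def geometric_annealing_path_sketch(
    stage, num_stages, initial_target_log_prob_fn, final_target_log_prob_fn
):
  def tlp_fn(x):
    beta = jnp.asarray(stage, jnp.float32) / num_stages
    tlp0, extra0 = initial_target_log_prob_fn(x)
    tlp1, extra1 = final_target_log_prob_fn(x)
    # log pi_beta(x) = (1 - beta) * log pi_0(x) + beta * log pi_1(x)
    return (1.0 - beta) * tlp0 + beta * tlp1, (extra0, extra1)
  return tlp_fn
```
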
64 changes: 0 additions & 64 deletions spinoffs/fun_mc/fun_mc/fun_mc_test.py
@@ -2095,70 +2095,6 @@ def testSystematicResampleAncestors(self):
self.assertAllEqual(new_log_weights, log_weights)
self.assertAllEqual(ancestors, particles)

def testAIS(self):
def tlp_1(x):
return -(x**2) / 2.0, ()

def tlp_2(x):
return -((x - 2) ** 2) / 2 / 16.0, ()

@jax.jit
def kernel(ais_state, seed):
hmc_seed, resample_seed, seed = util.split_seed(seed, 3)

ais_state, _ = fun_mc.annealed_importance_sampling_resample(
ais_state, seed=resample_seed
)

def transition_operator(state, stage, tlp_fn):
f = jnp.array(stage, state.dtype) / num_stages
hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, tlp_fn)
hmc_state, _ = fun_mc.hamiltonian_monte_carlo_step(
hmc_state,
tlp_fn,
step_size=f * 4.0 + (1.0 - f) * 1.0,
num_integrator_steps=1,
seed=hmc_seed,
)
return hmc_state.state, ()

ais_state, _ = fun_mc.annealed_importance_sampling_step(
ais_state,
transition_operator,
functools.partial(
fun_mc.geometric_annealing_path,
num_stages=num_stages,
initial_target_log_prob_fn=tlp_1,
final_target_log_prob_fn=tlp_2,
),
)

return (ais_state, seed), ()

num_stages = 100
num_particles = 400
init_seed, seed = util.split_seed(self._make_seed(_test_seed()), 2)
init_state = util.random_normal([num_particles], self._dtype, init_seed)

(ais_state, _), _ = fun_mc.trace(
(
fun_mc.annealed_importance_sampling_init(
init_state, jnp.zeros([num_particles], self._dtype)
),
seed,
),
kernel,
num_stages,
)

weights = jnp.exp(ais_state.log_weight)
self.assertAllClose(4.0, jnp.mean(weights), atol=0.7)
self.assertAllClose(
2.0,
jnp.sum(jax.nn.softmax(ais_state.log_weight) * ais_state.state),
atol=0.8,
)


@test_util.multi_backend_test(globals(), 'fun_mc_test')
class FunMCTest32(FunMCTest):
