Skip to content

Implement Ensemble MCMC #176

Open
Open
@rlouf

Description

@rlouf

We should implement Ensemble MCMC (https://arxiv.org/abs/1801.09065), as it has the potential to speed up inference on GPUs. @AdrienCorenflos shared the following single-file implementation of a multiple-try Metropolis (`mtm`) kernel, which whoever implements this in BlackJAX can reuse:

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def mtm(log_pi, log_p, p, N):
    """Build a multiple-try Metropolis (MTM) transition kernel.

    Each step draws ``N`` candidates around the current state, selects one
    with probability proportional to its importance weight
    ``exp(log_pi(y) - log_p(y, x))``, and accepts it with the MTM ratio
    computed against a reference set built around the selected candidate.

    Args:
        log_pi: Log target density evaluated on a single state.
        log_p: ``log_p(x_t, x)`` -- log proposal density of ``x_t`` given ``x``.
        p: ``p(key, x, n)`` -- draws ``n`` proposals around ``x``, stacked
            along axis 0.
        N: Number of candidates drawn per step.

    Returns:
        A function ``step(key, x) -> (new_x, accepted)``.
    """
    # Batch the densities over a leading axis of candidate states; the
    # conditioning state of log_p is shared (not batched).
    vmapped_log_p = jax.vmap(log_p, in_axes=[0, None])
    vmapped_log_pi = jax.vmap(log_pi)

    def step(key, x):
        sample_key, auxiliary_key, select_key, accept_key = jax.random.split(key, 4)

        # Forward move: N candidates around x and their log importance weights.
        xs_hat = p(sample_key, x, N)
        log_weights = vmapped_log_pi(xs_hat) - vmapped_log_p(xs_hat, x)
        idx = jax.random.categorical(select_key, log_weights, 0)
        x_star = xs_hat[idx]

        # Reference set: N - 1 fresh draws around x_star plus the *current*
        # state x (Liu, Liang & Wong, 2000). Including x rather than x_star
        # is what makes the acceptance ratio below valid. The insertion
        # position does not affect the logsumexp.
        zs_hat = p(auxiliary_key, x_star, N - 1)
        zs_hat = jnp.insert(zs_hat, idx, x, axis=0)
        auxiliary_log_weights = vmapped_log_pi(zs_hat) - vmapped_log_p(zs_hat, x_star)

        log_alpha = logsumexp(log_weights) - logsumexp(auxiliary_log_weights)
        log_u = jnp.log(jax.random.uniform(accept_key))

        accept = log_u < log_alpha
        x = jax.lax.select(accept, x_star, x)
        return x, accept

    return step


def main():
    """Run the MTM demo on a 5-dimensional Gaussian target.

    Samples with a Laplace random-walk proposal, then prints the true mean,
    the empirical mean of the burned-in and thinned chain, the empirical
    covariance, the true covariance matrix, and the acceptance rate.
    """
    import jax.scipy.stats as jstats
    import numpy as np

    dim = 5
    num_candidates = 100
    num_steps = 100_000
    burn_in, thin = 1_000, 10

    # The numpy RNG is seeded once; the draw order below (mean, factor, then
    # the initial state) determines every random value used in the demo.
    np.random.seed(0)
    mean = np.random.randn(dim)
    factor = np.random.randn(dim, 15)
    cov = factor @ factor.T

    def log_target(x):
        return jstats.multivariate_normal.logpdf(x, mean, cov)

    def log_proposal(x_t, x):
        return jstats.laplace.logpdf(x_t, x).sum()

    def propose(key, x, n_samples):
        return x[None, :] + jax.random.laplace(key, (n_samples, dim))

    chain_keys = jax.random.split(jax.random.PRNGKey(42), num_steps)

    initial_state = np.random.randn(dim)
    kernel = mtm(log_target, log_proposal, propose, num_candidates)

    def transition(state, key):
        next_state, was_accepted = kernel(key, state)
        return next_state, (next_state, was_accepted)

    _, (samples, acceptances) = jax.lax.scan(
        transition, initial_state, chain_keys, num_steps
    )
    kept = samples[burn_in::thin]

    print()
    print(mean)
    print(kept.mean(0))
    print()
    print(np.cov(kept, rowvar=False))
    print(cov)
    print()
    print(np.mean(acceptances))


if __name__ == "__main__":
    # Run the sampling demo only when this file is executed as a script,
    # not when it is imported.
    main()

Metadata

Metadata

Assignees

No one assigned

    Labels

    help wanted (Extra attention is needed), sampler (Issue related to samplers)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions