
Implement Ensemble MCMC #176

Open
rlouf opened this issue Feb 9, 2022 · 0 comments
Labels: help wanted (Extra attention is needed) · sampler (Issue related to samplers)

rlouf commented Feb 9, 2022

We should implement Ensemble MCMC (https://arxiv.org/abs/1801.09065), as it has the potential to speed up inference on GPUs. @AdrienCorenflos shared the following one-file implementation, which whoever implements this in blackjax can reuse:
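For context (my summary of the multiple-try Metropolis scheme the snippet below follows, not part of the original comment): from the current state $x$, draw candidates $y_1, \dots, y_N \sim p(\cdot \mid x)$, select one candidate $y_J$ with probability proportional to its importance weight, draw auxiliary points $z_1, \dots, z_{N-1} \sim p(\cdot \mid y_J)$ with the current state $x$ completing the auxiliary set, and accept $y_J$ with probability

```latex
\alpha = \min\left(1,\; \frac{\sum_{i=1}^{N} w(y_i, x)}{\sum_{i=1}^{N} w(z_i, y_J)}\right),
\qquad w(a, b) = \frac{\pi(a)}{p(a \mid b)}.
```

Since all $N$ candidate evaluations are independent, they map naturally onto `jax.vmap`, which is where the GPU speed-up comes from.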

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def mtm(log_pi, log_p, p, N):
    # log_pi: log-density of the target; log_p(x_t, x): log-density of the
    # proposal p(x_t | x); p(key, x, n): draws n proposals given the state x.
    vmapped_log_p = jax.vmap(log_p, in_axes=[0, None])
    vmapped_log_pi = jax.vmap(log_pi)

    def step(key, x):
        sample_key, auxiliary_key, select_key, accept_key = jax.random.split(key, 4)

        # Draw N candidates from the current state and compute their
        # importance weights w(y, x) = pi(y) / p(y | x).
        xs_hat = p(sample_key, x, N)
        log_weights = vmapped_log_pi(xs_hat) - vmapped_log_p(xs_hat, x)

        # Select one candidate with probability proportional to its weight.
        idx = jax.random.categorical(select_key, log_weights, 0)
        x_star = xs_hat[idx]

        # Build the auxiliary set: N - 1 draws from the selected candidate,
        # plus the current state x (standard multiple-try Metropolis).
        zs_hat = p(auxiliary_key, x_star, N - 1)
        zs_hat = jnp.insert(zs_hat, idx, x, axis=0)
        auxiliary_log_weights = vmapped_log_pi(zs_hat) - vmapped_log_p(zs_hat, x_star)

        # Accept with probability min(1, sum(weights) / sum(auxiliary weights)).
        log_alpha = logsumexp(log_weights) - logsumexp(auxiliary_log_weights)
        log_u = jnp.log(jax.random.uniform(accept_key))

        accept = log_u < log_alpha
        x = jax.lax.select(accept, x_star, x)
        return x, accept

    return step


def main():
    import jax.scipy.stats as jstats
    import numpy as np

    np.random.seed(0)
    # Random mean and (positive semi-definite) covariance for the Gaussian target.
    m = np.random.randn(5)
    P = np.random.randn(5, 15)
    P = P @ P.T

    # Target log-density, proposal log-density, and proposal sampler
    # (a Laplace random walk around the current state).
    log_pi = lambda x: jstats.multivariate_normal.logpdf(x, m, P)
    log_p = lambda x_t, x: jstats.laplace.logpdf(x_t, x).sum()
    p = lambda k, x, n_samples: x[None, :] + jax.random.laplace(k, (n_samples, 5))

    N = 100  # number of proposals per MTM step
    T = 100_000  # number of MCMC steps

    keys = jax.random.split(jax.random.PRNGKey(42), T)

    x0 = np.random.randn(5)
    step = mtm(log_pi, log_p, p, N)

    def mcmc_body(x, k):
        x, accepted = step(k, x)
        return x, (x, accepted)

    # Run the chain, recording every state and acceptance flag.
    _, (xs, are_accepted) = jax.lax.scan(mcmc_body, x0, keys, T)

    # Compare empirical moments (after burn-in, thinned by 10) with the
    # target's, and report the average acceptance rate.
    print()
    print(m)
    print(xs[1_000::10].mean(0))
    print()
    print(np.cov(xs[1_000::10], rowvar=False))
    print(P)
    print()
    print(np.mean(are_accepted))


if __name__ == "__main__":
    main()
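Since the motivation is GPU throughput, one way to exploit it (a sketch for illustration; the chain count, toy standard-normal target, and parameter values here are my choices, not from the snippet above) is to run many chains in parallel by vmapping the step function over per-chain states and keys:

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jstats
from jax.scipy.special import logsumexp


def mtm(log_pi, log_p, p, N):
    # Same construction as the snippet above, condensed.
    vmapped_log_pi = jax.vmap(log_pi)
    vmapped_log_p = jax.vmap(log_p, in_axes=[0, None])

    def step(key, x):
        sample_key, auxiliary_key, select_key, accept_key = jax.random.split(key, 4)
        xs_hat = p(sample_key, x, N)
        log_weights = vmapped_log_pi(xs_hat) - vmapped_log_p(xs_hat, x)
        idx = jax.random.categorical(select_key, log_weights)
        x_star = xs_hat[idx]
        zs_hat = p(auxiliary_key, x_star, N - 1)
        zs_hat = jnp.concatenate([zs_hat, x[None]], axis=0)  # current state joins the auxiliary set
        auxiliary_log_weights = vmapped_log_pi(zs_hat) - vmapped_log_p(zs_hat, x_star)
        log_alpha = logsumexp(log_weights) - logsumexp(auxiliary_log_weights)
        accept = jnp.log(jax.random.uniform(accept_key)) < log_alpha
        return jax.lax.select(accept, x_star, x), accept

    return step


# Toy setup: standard-normal target, Laplace random-walk proposal.
d, N, C, T = 3, 32, 64, 200  # dimension, proposals per step, chains, steps
log_pi = lambda x: -0.5 * jnp.sum(x**2)
log_p = lambda x_t, x: jstats.laplace.logpdf(x_t, x).sum()
p = lambda k, x, n: x[None, :] + jax.random.laplace(k, (n, d))

step = mtm(log_pi, log_p, p, N)
parallel_step = jax.vmap(step)  # one key and one state per chain


def body(xs, keys):
    xs, accepted = parallel_step(keys, xs)
    return xs, accepted


key_run, key_init = jax.random.split(jax.random.PRNGKey(0))
x0 = jax.random.normal(key_init, (C, d))
keys = jax.random.split(key_run, T * C).reshape(T, C, 2)
xs_final, are_accepted = jax.lax.scan(body, x0, keys)
print(xs_final.shape, are_accepted.shape)  # (64, 3) (200, 64)
```

The per-chain work is identical, so `vmap` batches the N candidate evaluations across all C chains into single kernels; on a GPU, C can be pushed much higher at little extra wall-clock cost.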
rlouf added the help wanted and sampler labels on Feb 9, 2022
rlouf mentioned this issue on Dec 11, 2022