Open
Description
We should implement Ensemble MCMC (https://arxiv.org/abs/1801.09065), as it has the potential to speed up inference on GPUs. @AdrienCorenflos shared the following single-file implementation, which whoever implements this in BlackJAX
can reuse:
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
def mtm(log_pi, log_p, p, N):
    """Build a multiple-try Metropolis (MTM) transition kernel.

    Each step does the following:

    1. Draw ``N`` trial points from the proposal.
    2. Pick one trial with probability proportional to its importance weight
       ``pi(y) / q(y | x)``.
    3. Accept or reject it with the generalized Metropolis-Hastings ratio.
       The ratio is computed against a reference set made of ``N - 1`` draws
       around the selected trial plus the current state.

    Args:
        log_pi: Log target density, ``log_pi(x) -> scalar``. An additive
            constant cancels in the acceptance ratio, so it may be
            unnormalized.
        log_p: Log proposal density, ``log_p(x_new, x) -> scalar``. This is
            the log density of proposing ``x_new`` from the point ``x``.
        p: Proposal sampler, ``p(key, x, n) -> array``. It returns ``n``
            draws around ``x``, stacked along axis 0.
        N: Number of trial points per step (``N >= 1``).

    Returns:
        A function ``step(key, x) -> (new_x, accepted)``. ``accepted`` is a
        boolean scalar telling whether the selected trial was accepted.
    """
    # Batch the densities over the leading (sample) axis; the conditioning
    # point is shared by all samples.
    vmapped_log_p = jax.vmap(log_p, in_axes=[0, None])
    vmapped_log_pi = jax.vmap(log_pi)

    def step(key, x):
        sample_key, auxiliary_key, select_key, accept_key = jax.random.split(key, 4)

        # Trials y_1..y_N ~ q(. | x), with log weights log pi(y_j) - log q(y_j | x).
        xs_hat = p(sample_key, x, N)
        log_weights = vmapped_log_pi(xs_hat) - vmapped_log_p(xs_hat, x)

        # Select one trial with probability proportional to its weight.
        idx = jax.random.categorical(select_key, log_weights, 0)
        x_star = xs_hat[idx]

        # Reference set: N - 1 fresh draws from q(. | x_star) plus the
        # *current* state x. Using x_star in place of x would drop the
        # pi(x) / q(x | x_star) term from the denominator and break detailed
        # balance. Only the logsumexp of these weights is used, so the
        # position of x in the set is irrelevant.
        zs_hat = p(auxiliary_key, x_star, N - 1)
        zs_hat = jnp.concatenate([zs_hat, jnp.expand_dims(x, 0)], axis=0)
        auxiliary_log_weights = vmapped_log_pi(zs_hat) - vmapped_log_p(zs_hat, x_star)

        # Generalized MH acceptance: sum of trial weights over sum of
        # reference weights, compared in log space.
        log_alpha = logsumexp(log_weights) - logsumexp(auxiliary_log_weights)
        log_u = jnp.log(jax.random.uniform(accept_key))
        accept = log_u < log_alpha
        x = jax.lax.select(accept, x_star, x)
        return x, accept

    return step
def main():
    """Sanity-check ``mtm`` on a 5-D Gaussian target with a Laplace random-walk proposal.

    The function prints three things:

    - the target mean, next to the mean of the thinned post-burn-in samples;
    - the sample covariance, next to the target covariance;
    - the overall acceptance rate.
    """
    import jax.scipy.stats as jstats
    import numpy as np

    dim = 5
    np.random.seed(0)
    mean = np.random.randn(dim)
    factor = np.random.randn(dim, 15)
    cov = factor @ factor.T

    def log_target(x):
        return jstats.multivariate_normal.logpdf(x, mean, cov)

    def log_proposal(x_new, x):
        return jstats.laplace.logpdf(x_new, x).sum()

    def propose(key, x, n_samples):
        return x[None, :] + jax.random.laplace(key, (n_samples, dim))

    n_trials = 100
    n_steps = 100_000
    step_keys = jax.random.split(jax.random.PRNGKey(42), n_steps)
    x_init = np.random.randn(dim)
    step = mtm(log_target, log_proposal, propose, n_trials)

    def mcmc_body(state, key):
        new_state, accepted = step(key, state)
        return new_state, (new_state, accepted)

    _, (samples, accept_flags) = jax.lax.scan(mcmc_body, x_init, step_keys, n_steps)

    # Discard the first 1000 draws as burn-in and keep every 10th draw.
    kept = samples[1_000::10]

    print()
    print(mean)
    print(kept.mean(0))
    print()
    print(np.cov(kept, rowvar=False))
    print(cov)
    print()
    print(np.mean(accept_flags))
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not on import.
    main()