Commit b107f9f

ismael-mendoza, AdrienCorenflos, and junpenglao authored
Add pre-conditioning matrix to Barker proposal (#731)
* Draft pre-conditioning matrix in Barker proposal. This is a first draft of adding pre-conditioning to the Barker proposal, following Algorithms 4 and 5 in Appendix G of the original Barker proposal paper. It is somewhat unclear from the paper, but the separate step size that was already implemented serves as a global scale for the normal distribution of the proposal. The function `_compute_acceptance_probability` now takes the transposed square-root mass matrix and its inverse, and has been flattened to accommodate the corresponding matrix multiplications.
* Fix typing of the inverse_mass_matrix argument.
* Fix docstrings. The original docstring of step_size was incorrect; there is no symplectic integrator.
* Make the Barker test in test_sampling run again. We make this possible by adding an identity pre-conditioning matrix, which should make the test run the same way as before.
* Add a test to ensure correctness of the pre-conditioning matrix. We add a new test to barker.py to check that our implementation of the pre-conditioning matrix is correct. Appendix G of the paper notes that Algorithms 4 and 5 (which we implemented) should be equivalent to rescaling the parameters and the log-density in a specific way. We implement both approaches when using the Barker proposal to infer the mean and sigma of a normal distribution, and check that, for two different random seeds, the output chains agree up to some tolerance. We also patch the original test in this file by adding an identity mass matrix.
* Fix dimensionality of the identity matrix.
* Add missing mass matrix in remaining tests.
* Add an option to transpose the matrix when scaling. Transposing mass_matrix_sqrt or inv_mass_matrix_sqrt is necessary for the Barker algorithm, as far as I can tell. This has not been propagated to the Riemannian metric.
* Use the metric scaling function in barker. We use the new metric.scale function to perform the operations required by the Barker proposal algorithm instead of passing mass_matrix_sqrt and inv_mass_matrix_sqrt around directly. We also make the `inverse_mass_matrix` argument optional to avoid breaking the API.
* Update test_sampling with the Barker API; the mass matrix is now an optional argument in barker.
* Update test_barker so it works with metric.scale.
* Fix tests: add trans to scale.
* Add trans argument to Riemannian scaling.
* No default.
* Update barker.py: make the acceptance function metric-agnostic.
* Update test_barker.py: add an invariance test.
* Simplify logic to remove _barker_sample_nd.
* Fix a bug so that everything is tree_mapped in barker.
* Fix test to not use _barker_sample_nd.
* Update blackjax/mcmc/metrics.py: make inv and trans required keyword arguments with type bool in metric.scale. (Co-authored-by: Junpeng Lao <[email protected]>)
* Update blackjax/mcmc/metrics.py: lax.cond may not be needed in metric.scale since inv and trans are static keyword arguments. (Co-authored-by: Junpeng Lao <[email protected]>)
* Propagate the change of inv and trans to required keyword arguments.
* Fix test metrics.

---------

Co-authored-by: Adrien Corenflos <[email protected]>
Co-authored-by: Junpeng Lao <[email protected]>
1 parent 5a25352 commit b107f9f
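For orientation, here is a minimal sketch of how the extended interface could be called after this change. It assumes the usual top-level alias `blackjax.barker_proposal` forwards to `as_top_level_api`; the target density, step size, and sampling loop below are purely illustrative.

```python
import jax
import jax.numpy as jnp
import blackjax

# Illustrative target: a standard normal in 2 dimensions.
logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)

# The step size acts as the global scale of the Gaussian base proposal;
# the (optional) inverse mass matrix pre-conditions it, per Appendix G.
inverse_mass_matrix = jnp.ones(2)  # diagonal pre-conditioning, same as the default
barker = blackjax.barker_proposal(
    logdensity_fn, step_size=0.5, inverse_mass_matrix=inverse_mass_matrix
)

rng_key = jax.random.PRNGKey(0)
state = barker.init(jnp.zeros(2))
for _ in range(100):
    rng_key, step_key = jax.random.split(rng_key)
    state, info = barker.step(step_key, state)
```

Leaving `inverse_mass_matrix` unset falls back to the identity pre-conditioning that the kernel builds internally, so existing callers are unaffected.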

File tree

5 files changed: +269 −93 lines


blackjax/mcmc/barker.py

+79 −67
@@ -18,11 +18,13 @@
 import jax.numpy as jnp
 from jax.flatten_util import ravel_pytree
 from jax.scipy import stats
-from jax.tree_util import tree_leaves, tree_map
 
+import blackjax.mcmc.metrics as metrics
 from blackjax.base import SamplingAlgorithm
+from blackjax.mcmc.metrics import Metric
 from blackjax.mcmc.proposal import static_binomial_sampling
-from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
+from blackjax.types import ArrayLikeTree, ArrayTree, Numeric, PRNGKey
+from blackjax.util import generate_gaussian_noise
 
 
 __all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "as_top_level_api"]
 
@@ -81,44 +83,70 @@ def build_kernel():
     """
 
     def _compute_acceptance_probability(
-        state: BarkerState,
-        proposal: BarkerState,
-    ) -> float:
+        state: BarkerState, proposal: BarkerState, metric: Metric
+    ) -> Numeric:
         """Compute the acceptance probability of the Barker's proposal kernel."""
 
-        def ratio_proposal_nd(y, x, log_y, log_x):
-            num = -_log1pexp(-log_y * (x - y))
-            den = -_log1pexp(-log_x * (y - x))
+        x = state.position
+        y = proposal.position
+        log_x = state.logdensity_grad
+        log_y = proposal.logdensity_grad
 
-            return jnp.sum(num - den)
+        y_minus_x = jax.tree_util.tree_map(lambda a, b: a - b, y, x)
+        x_minus_y = jax.tree_util.tree_map(lambda a: -a, y_minus_x)
+        z_tilde_x_to_y = metric.scale(x, y_minus_x, inv=True, trans=True)
+        z_tilde_y_to_x = metric.scale(y, x_minus_y, inv=True, trans=True)
 
-        ratios_proposals = tree_map(
-            ratio_proposal_nd,
-            proposal.position,
-            state.position,
-            proposal.logdensity_grad,
-            state.logdensity_grad,
+        c_x_to_y = metric.scale(x, log_x, inv=False, trans=True)
+        c_y_to_x = metric.scale(y, log_y, inv=False, trans=True)
+
+        z_tilde_x_to_y_flat, _ = ravel_pytree(z_tilde_x_to_y)
+        z_tilde_y_to_x_flat, _ = ravel_pytree(z_tilde_y_to_x)
+
+        c_x_to_y_flat, _ = ravel_pytree(c_x_to_y)
+        c_y_to_x_flat, _ = ravel_pytree(c_y_to_x)
+
+        num = metric.kinetic_energy(x_minus_y, y) - _log1pexp(
+            -z_tilde_y_to_x_flat * c_y_to_x_flat
        )
-        ratio_proposal = sum(tree_leaves(ratios_proposals))
+        denom = metric.kinetic_energy(y_minus_x, x) - _log1pexp(
+            -z_tilde_x_to_y_flat * c_x_to_y_flat
+        )
+
+        ratio_proposal = jnp.sum(num - denom)
+
         return proposal.logdensity - state.logdensity + ratio_proposal
 
     def kernel(
-        rng_key: PRNGKey, state: BarkerState, logdensity_fn: Callable, step_size: float
+        rng_key: PRNGKey,
+        state: BarkerState,
+        logdensity_fn: Callable,
+        step_size: float,
+        inverse_mass_matrix: metrics.MetricTypes | None = None,
     ) -> tuple[BarkerState, BarkerInfo]:
-        """Generate a new sample with the MALA kernel."""
+        """Generate a new sample with the Barker kernel."""
+        if inverse_mass_matrix is None:
+            p, _ = ravel_pytree(state.position)
+            (m,) = p.shape
+            inverse_mass_matrix = jnp.ones((m,))
+        metric = metrics.default_metric(inverse_mass_matrix)
         grad_fn = jax.value_and_grad(logdensity_fn)
-
         key_sample, key_rmh = jax.random.split(rng_key)
 
         proposed_pos = _barker_sample(
-            key_sample, state.position, state.logdensity_grad, step_size
+            key_sample,
+            state.position,
+            state.logdensity_grad,
+            step_size,
+            metric,
         )
+
         proposed_logdensity, proposed_logdensity_grad = grad_fn(proposed_pos)
         proposed_state = BarkerState(
             proposed_pos, proposed_logdensity, proposed_logdensity_grad
         )
 
-        log_p_accept = _compute_acceptance_probability(state, proposed_state)
+        log_p_accept = _compute_acceptance_probability(state, proposed_state, metric)
         accepted_state, info = static_binomial_sampling(
             key_rmh, log_p_accept, state, proposed_state
         )
@@ -131,6 +159,7 @@ def kernel(
 def as_top_level_api(
     logdensity_fn: Callable,
     step_size: float,
+    inverse_mass_matrix: metrics.MetricTypes | None = None,
 ) -> SamplingAlgorithm:
     """Implements the (basic) user interface for the Barker's proposal :cite:p:`Livingstone2022Barker` kernel with a
     Gaussian base kernel.
@@ -174,7 +203,9 @@ def as_top_level_api(
     logdensity_fn
         The log-density function we wish to draw samples from.
     step_size
-        The value to use for the step size in the symplectic integrator.
+        The value of the step_size corresponding to the global scale of the proposal distribution.
+    inverse_mass_matrix
+        The inverse mass matrix to use for pre-conditioning (see Appendix G of :cite:p:`Livingstone2022Barker`).
 
     Returns
     -------
@@ -189,74 +220,55 @@ def init_fn(position: ArrayLikeTree, rng_key=None):
         return init(position, logdensity_fn)
 
     def step_fn(rng_key: PRNGKey, state):
-        return kernel(rng_key, state, logdensity_fn, step_size)
+        return kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix)
 
     return SamplingAlgorithm(init_fn, step_fn)
 
 
-def _barker_sample_nd(key, mean, a, scale):
-    """
-    Sample from a multivariate Barker's proposal distribution. In 1D, this has the following probability density function:
-
-    .. math::
-        p(x; \\mu, a, \\sigma) = 2 \\frac{N(x; \\mu, \\sigma^2)}{1 + \\exp(-a (x - \\mu))}
+def _generate_bernoulli(
+    rng_key: PRNGKey, position: ArrayLikeTree, p: ArrayLikeTree
+) -> ArrayTree:
+    pos, unravel_fn = ravel_pytree(position)
+    p_flat, _ = ravel_pytree(p)
+    sample = jax.random.bernoulli(rng_key, p=p_flat, shape=pos.shape)
+    return unravel_fn(sample)
 
-    where :math:`N(x; \\mu, \\sigma^2)` is the normal distribution with mean :math:`\\mu` and standard deviation :math:`\\sigma`.
-    The multivariate Barker's proposal distribution is the product of one-dimensional Barker's proposal distributions.
 
+def _barker_sample(key, mean, a, scale, metric):
+    r"""
+    Sample from a multivariate Barker's proposal distribution for PyTrees.
 
     Parameters
     ----------
     key
         A PRNG key.
     mean
-        The mean of the normal distribution, an Array. This corresponds to :math:`\\mu` in the equation above.
+        The mean of the normal distribution, a PyTree. This corresponds to :math:`\mu` in the equation above.
     a
-        The parameter :math:`a` in the equation above, an Array. This is a skewness parameter.
+        The parameter :math:`a` in the equation above, the same PyTree as `mean`. This is a skewness parameter.
     scale
-        The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\\sigma` in the equation above.
+        The global scale, a scalar. This corresponds to :math:`\\sigma` in the equation above.
         It encodes the step size of the proposal.
-
-    Returns
-    -------
-    A sample from the Barker's multidimensional proposal distribution.
-
+    metric
+        A `metrics.MetricTypes` object encoding the mass matrix information.
     """
 
     key1, key2 = jax.random.split(key)
-    z = scale * jax.random.normal(key1, shape=mean.shape)
+
+    z = generate_gaussian_noise(key1, mean, sigma=scale)
+    c = metric.scale(mean, a, inv=False, trans=True)
 
     # Sample b=1 with probability p and 0 with probability 1 - p where
     # p = 1 / (1 + exp(-a * (z - mean)))
-    log_p = -_log1pexp(-a * z)
-    b = jax.random.bernoulli(key2, p=jnp.exp(log_p), shape=mean.shape)
-
-    # return mean + z if b == 1 else mean - z
-    return mean + b * z - (1 - b) * z
-
+    log_p = jax.tree_util.tree_map(lambda x, y: -_log1pexp(-x * y), c, z)
+    p = jax.tree_util.tree_map(lambda x: jnp.exp(x), log_p)
+    b = _generate_bernoulli(key2, mean, p=p)
 
-def _barker_sample(key, mean, a, scale):
-    r"""
-    Sample from a multivariate Barker's proposal distribution for PyTrees.
-
-    Parameters
-    ----------
-    key
-        A PRNG key.
-    mean
-        The mean of the normal distribution, a PyTree. This corresponds to :math:`\mu` in the equation above.
-    a
-        The parameter :math:`a` in the equation above, the same PyTree as `mean`. This is a skewness parameter.
-    scale
-        The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\sigma` in the equation above.
-        It encodes the step size of the proposal.
-
-    """
+    bz = jax.tree_util.tree_map(lambda x, y: x * y - (1 - x) * y, b, z)
 
-    flat_mean, unravel_fn = ravel_pytree(mean)
-    flat_a, _ = ravel_pytree(a)
-    flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale)
-    return unravel_fn(flat_sample)
+    return jax.tree_util.tree_map(
+        lambda a, b: a + b, mean, metric.scale(mean, bz, inv=False, trans=False)
    )
 
 
 def _log1pexp(a):
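For readers new to the Barker proposal, the following self-contained 1-D sketch shows the sampling mechanism that `_barker_sample` generalises to PyTrees and to a pre-conditioned metric. The function name `barker_step_1d` and the identity pre-conditioning are illustrative assumptions, not part of the diff.

```python
import jax
import jax.numpy as jnp


def barker_step_1d(key, x, grad_logdensity_at_x, step_size):
    """One draw from the 1-D Barker proposal with identity pre-conditioning (illustrative)."""
    key_z, key_b = jax.random.split(key)
    z = step_size * jax.random.normal(key_z)      # Gaussian increment with scale step_size
    p = jax.nn.sigmoid(grad_logdensity_at_x * z)  # P(keep sign of z) = 1 / (1 + exp(-grad * z))
    b = jax.random.bernoulli(key_b, p=p)          # skews the move towards higher density
    return x + jnp.where(b, z, -z)
```

With a non-identity metric, the diff applies the same recipe after mapping the increment through M^{1/2} and the gradient through (M^{1/2})^T, which is what the `metric.scale(..., inv=..., trans=...)` calls above implement.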

blackjax/mcmc/metrics.py

+39 −16
@@ -30,7 +30,6 @@
 """
 from typing import Callable, NamedTuple, Optional, Protocol, Union
 
-import jax
 import jax.numpy as jnp
 import jax.scipy as jscipy
 from jax.flatten_util import ravel_pytree
@@ -62,7 +61,12 @@ def __call__(
 
 class Scale(Protocol):
     def __call__(
-        self, position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
+        self,
+        position: ArrayLikeTree,
+        element: ArrayLikeTree,
+        *,
+        inv: bool,
+        trans: bool,
     ) -> ArrayLikeTree:
         ...
 
@@ -187,7 +191,11 @@ def is_turning(
         return turning_at_left | turning_at_right
 
     def scale(
-        position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
+        position: ArrayLikeTree,
+        element: ArrayLikeTree,
+        *,
+        inv: bool,
+        trans: bool,
     ) -> ArrayLikeTree:
         """Scale elements by the mass matrix.
 
@@ -197,10 +205,11 @@ def scale(
             The current position. Not used in this metric.
         elements
             Elements to scale
-        invs
+        inv
             Whether to scale the elements by the inverse mass matrix or the mass matrix.
             If True, the element is scaled by the inverse square root mass matrix, i.e., elem <- (M^{1/2})^{-1} elem.
-            Same pytree structure as `elements`.
+        trans
+            Whether to transpose the mass matrix when scaling.
 
         Returns
         -------
@@ -209,11 +218,16 @@
         """
 
         ravelled_element, unravel_fn = ravel_pytree(element)
-        scaled = jax.lax.cond(
-            inv,
-            lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
-            lambda: linear_map(mass_matrix_sqrt, ravelled_element),
-        )
+
+        if inv:
+            left_hand_side_matrix = inv_mass_matrix_sqrt
+        else:
+            left_hand_side_matrix = mass_matrix_sqrt
+        if trans:
+            left_hand_side_matrix = left_hand_side_matrix.T
+
+        scaled = linear_map(left_hand_side_matrix, ravelled_element)
+
         return unravel_fn(scaled)
 
     return Metric(momentum_generator, kinetic_energy, is_turning, scale)
@@ -279,7 +293,11 @@ def is_turning(
         # return turning_at_left | turning_at_right
 
     def scale(
-        position: ArrayLikeTree, element: ArrayLikeTree, inv: ArrayLikeTree
+        position: ArrayLikeTree,
+        element: ArrayLikeTree,
+        *,
+        inv: bool,
+        trans: bool,
     ) -> ArrayLikeTree:
         """Scale elements by the mass matrix.
 
@@ -298,11 +316,16 @@
             mass_matrix, is_inv=False
         )
         ravelled_element, unravel_fn = ravel_pytree(element)
-        scaled = jax.lax.cond(
-            inv,
-            lambda: linear_map(inv_mass_matrix_sqrt, ravelled_element),
-            lambda: linear_map(mass_matrix_sqrt, ravelled_element),
-        )
+
+        if inv:
+            left_hand_side_matrix = inv_mass_matrix_sqrt
+        else:
+            left_hand_side_matrix = mass_matrix_sqrt
+        if trans:
+            left_hand_side_matrix = left_hand_side_matrix.T
+
+        scaled = linear_map(left_hand_side_matrix, ravelled_element)
+
         return unravel_fn(scaled)
 
     return Metric(momentum_generator, kinetic_energy, is_turning, scale)
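As a quick illustration of the new keyword-only flags, the sketch below scales a vector with a diagonal metric built by `metrics.default_metric`, the same factory the Barker kernel calls; the concrete numbers are arbitrary.

```python
import jax.numpy as jnp

from blackjax.mcmc import metrics

# A diagonal inverse mass matrix; default_metric also accepts a dense one.
inverse_mass_matrix = jnp.array([4.0, 0.25])
metric = metrics.default_metric(inverse_mass_matrix)

position = jnp.zeros(2)          # ignored by the Euclidean metric's scale
element = jnp.array([1.0, 1.0])

# inv=False: scale by M^{1/2}; trans=True additionally transposes it.
scaled = metric.scale(position, element, inv=False, trans=True)

# inv=True: scale by (M^{1/2})^{-1}, i.e. undo the mass-matrix scaling.
unscaled = metric.scale(position, element, inv=True, trans=False)
```

Making `inv` and `trans` required keyword-only booleans (rather than a value routed through `jax.lax.cond`) lets the branch be resolved at trace time, which is why the `lax.cond` calls could be dropped from both metrics.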
