From 0c26a853535339a3655ce552efb142e5d905d868 Mon Sep 17 00:00:00 2001 From: siege Date: Fri, 16 Feb 2024 15:24:09 -0800 Subject: [PATCH] For SNAPER, use the improved gradient estimator from Riou-Durand et al. The implementation required a bit of refactoring, since we need to be careful to only use validated (by MetropolisHastings) states for both previous and proposed states depending on direction. Also, use harmonic mean for the docstring example, as that is best practice. The use of tfb.Exp() isn't ideal for tfd.HalfNormal(), but it gets to stay as it uncovered the issues mentioned above. PiperOrigin-RevId: 607816872 --- ...ient_based_trajectory_length_adaptation.py | 255 ++++++++++++------ ...based_trajectory_length_adaptation_test.py | 23 +- .../python/experimental/mcmc/snaper_hmc.py | 1 + 3 files changed, 187 insertions(+), 92 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py index 6f65e2ae5a..62bb1390a0 100644 --- a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py +++ b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py @@ -108,6 +108,39 @@ def _reduce_with_axes(index_op, name_op, x, axis_idx=None, axis_names=None): distribute_lib.pmean) +def _estimate_empirical_mean(x, accept_prob, safe, reduce_chain_axis_names): + """Estimates the empirical mean of x.""" + batch_ndims = ps.rank(accept_prob) + batch_axes = ps.range(batch_ndims, dtype=tf.int32) + + if safe: + # Note that we don't do a monte carlo average of the accepted chain + # position, but rather try to get an estimate of the underlying dynamics. + # This is done by only looking at proposed states where the integration + # error is low. + # TODO(mhoffman): Needs more experimentation. + expanded_accept_prob = bu.left_justified_expand_dims_like(accept_prob, x) + + # accept_prob is zero when x is NaN, but we still want to sanitize such + # values. + x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x)) + # If all accept_prob's are zero, the x_center will have a nonsense value, + # but we'll discard the resultant gradients later on, so it's fine. + x_mean = _reduce_sum_with_axes( + expanded_accept_prob * x_safe, batch_axes, reduce_chain_axis_names + ) / ( + _reduce_sum_with_axes( + expanded_accept_prob, batch_axes, reduce_chain_axis_names + ) + + 1e-20 + ) + else: + x_mean = _reduce_mean_with_axes(x, batch_axes, reduce_chain_axis_names) + # The empirical mean here is a stand-in for the true mean, so we drop the + # gradient that flows through this term. + return tf.stop_gradient(x_mean) + + def hmc_like_num_leapfrog_steps_getter_fn(kernel_results): """Getter for `num_leapfrog_steps` so it can be inspected.""" return unnest.get_innermost(kernel_results, 'num_leapfrog_steps') @@ -120,24 +153,32 @@ def hmc_like_num_leapfrog_steps_setter_fn(kernel_results, kernel_results, num_leapfrog_steps=new_num_leapfrog_steps) -def hmc_like_proposed_velocity_getter_fn(kernel_results): - """Getter for `proposed_velocity` so it can be inspected.""" - final_momentum = unnest.get_innermost(kernel_results, 'final_momentum') +def _hmc_like_velocity_getter_fn(kernel_results, momentum_name): + """Getter for a velocity so it can be inspected.""" + momentum = unnest.get_innermost(kernel_results, momentum_name) proposed_state = unnest.get_innermost(kernel_results, 'proposed_state') momentum_distribution = unnest.get_innermost( kernel_results, 'momentum_distribution', default=None) if momentum_distribution is None: - proposed_velocity = final_momentum + velocity = momentum else: momentum_log_prob = getattr(momentum_distribution, '_log_prob_unnormalized', momentum_distribution.log_prob) kinetic_energy_fn = lambda *args: -momentum_log_prob(*args) - _, proposed_velocity = mcmc_util.maybe_call_fn_and_grads( - kinetic_energy_fn, final_momentum) + _, velocity = mcmc_util.maybe_call_fn_and_grads( + kinetic_energy_fn, momentum) # proposed_velocity has the wrong structure when state is a scalar. return tf.nest.pack_sequence_as(proposed_state, - tf.nest.flatten(proposed_velocity)) + tf.nest.flatten(velocity)) + + +hmc_like_proposed_velocity_getter_fn = functools.partial( + _hmc_like_velocity_getter_fn, momentum_name='final_momentum' +) +hmc_like_initial_velocity_getter_fn = functools.partial( + _hmc_like_velocity_getter_fn, momentum_name='initial_momentum' +) def hmc_like_proposed_state_getter_fn(kernel_results): @@ -161,6 +202,7 @@ def chees_criterion(previous_state, proposed_state, accept_prob, trajectory_length, + forward=True, validate_args=False, experimental_shard_axis_names=None, experimental_reduce_chain_axis_names=None): @@ -196,6 +238,8 @@ def chees_criterion(previous_state, accept_prob: Floating `Tensor`. Probability of acceping the proposed state. trajectory_length: Floating `Tensor`. Mean trajectory length (not used in this criterion). + forward: Whether accept_prob refers to the proposed_state (True) or the + previous_state (False). validate_args: Whether to perform non-static argument validation. experimental_shard_axis_names: A structure of string names indicating how members of the state are sharded. @@ -218,7 +262,6 @@ def chees_criterion(previous_state, """ del trajectory_length batch_ndims = ps.rank(accept_prob) - batch_axes = ps.range(batch_ndims, dtype=tf.int32) reduce_chain_axis_names = distribute_lib.canonicalize_named_axis( experimental_reduce_chain_axis_names) @@ -230,33 +273,20 @@ def chees_criterion(previous_state, ) def _center_previous_state(x): - # The empirical mean here is a stand-in for the true mean, so we drop the - # gradient that flows through this term. - x_mean = _reduce_mean_with_axes(x, batch_axes, reduce_chain_axis_names) - return x - tf.stop_gradient(x_mean) + return x - _estimate_empirical_mean( + x, + accept_prob=accept_prob, + safe=not forward, + reduce_chain_axis_names=reduce_chain_axis_names, + ) def _center_proposed_state(x): - # Note that we don't do a monte carlo average of the accepted chain - # position, but rather try to get an estimate of the underlying dynamics. - # This is done by only looking at proposed states where the integration - # error is low. - # TODO(mhoffman): Needs more experimentation. - expanded_accept_prob = bu.left_justified_expand_dims_like(accept_prob, x) - - # accept_prob is zero when x is NaN, but we still want to sanitize such - # values. - x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x)) - # If all accept_prob's are zero, the x_center will have a nonsense value, - # but we'll discard the resultant gradients later on, so it's fine. - x_center = ( - _reduce_sum_with_axes(expanded_accept_prob * x_safe, batch_axes, - reduce_chain_axis_names) / - (_reduce_sum_with_axes(expanded_accept_prob, batch_axes, - reduce_chain_axis_names) + 1e-20)) - - # The empirical mean here is a stand-in for the true mean, so we drop the - # gradient that flows through this term. - return x - tf.stop_gradient(x_center) + return x - _estimate_empirical_mean( + x, + accept_prob=accept_prob, + safe=forward, + reduce_chain_axis_names=reduce_chain_axis_names, + ) def _sum_event_part(x, shard_axes=None): event_axes = ps.range(batch_ndims, ps.rank(x)) @@ -287,6 +317,7 @@ def chees_rate_criterion(previous_state, proposed_state, accept_prob, trajectory_length, + forward=True, validate_args=False, experimental_shard_axis_names=None, experimental_reduce_chain_axis_names=None): @@ -306,6 +337,8 @@ def chees_rate_criterion(previous_state, state of the HMC chain. accept_prob: Floating `Tensor`. Probability of acceping the proposed state. trajectory_length: Floating `Tensor`. Trajectory length. + forward: Whether accept_prob refers to the proposed_state (True) or the + previous_state (False). validate_args: Whether to perform non-static argument validation. experimental_shard_axis_names: A structure of string names indicating how members of the state are sharded. @@ -321,6 +354,7 @@ def chees_rate_criterion(previous_state, proposed_state=proposed_state, accept_prob=accept_prob, trajectory_length=trajectory_length, + forward=forward, validate_args=validate_args, experimental_shard_axis_names=experimental_shard_axis_names, experimental_reduce_chain_axis_names=experimental_reduce_chain_axis_names, @@ -334,6 +368,7 @@ def snaper_criterion(previous_state, direction, state_mean=None, state_mean_weight=0., + forward=True, validate_args=False, experimental_shard_axis_names=None, experimental_reduce_chain_axis_names=None): @@ -383,6 +418,8 @@ def snaper_criterion(previous_state, state_mean: Optional (Possibly nested) floating point `Tensor`. The estimated state mean. state_mean_weight: Floating point `Tensor`. The weight of the `state_mean`. + forward: Whether accept_prob refers to the proposed_state (True) or the + previous_state (False). validate_args: Whether to perform non-static argument validation. experimental_shard_axis_names: A structure of string names indicating how members of the state are sharded. @@ -400,7 +437,6 @@ def snaper_criterion(previous_state, """ batch_ndims = ps.rank(accept_prob) - batch_axes = ps.range(batch_ndims, dtype=tf.int32) reduce_chain_axis_names = distribute_lib.canonicalize_named_axis( experimental_reduce_chain_axis_names) @@ -411,7 +447,10 @@ def snaper_criterion(previous_state, accept_prob, reduce_chain_axis_names=reduce_chain_axis_names, validate_args=validate_args, - message='snaper_criterion requires at least 2 chains when `state_mean` is `None`' + message=( + 'snaper_criterion requires at least 2 chains when `state_mean` is' + ' `None`' + ), ) def _mix_in_state_mean(empirical_mean, state_mean): @@ -422,33 +461,22 @@ def _mix_in_state_mean(empirical_mean, state_mean): state_mean_weight * state_mean) def _center_previous_state(x, x_mean): - # The empirical mean here is a stand-in for the true mean, so we drop the - # gradient that flows through this term. - emp_x_mean = tf.stop_gradient( - distribute_lib.reduce_mean(x, batch_axes, reduce_chain_axis_names)) + emp_x_mean = _estimate_empirical_mean( + x, + accept_prob=accept_prob, + safe=not forward, + reduce_chain_axis_names=reduce_chain_axis_names, + ) x_mean = _mix_in_state_mean(emp_x_mean, x_mean) return x - x_mean def _center_proposed_state(x, x_mean): - # Note that we don't do a monte carlo average of the accepted chain - # position, but rather try to get an estimate of the underlying dynamics. - # This is done by only looking at proposed states where the integration - # error is low. - expanded_accept_prob = bu.left_justified_expand_dims_like(accept_prob, x) - - # accept_prob is zero when x is NaN, but we still want to sanitize such - # values. - x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x)) - # The empirical mean here is a stand-in for the true mean, so we drop the - # gradient that flows through this term. - # If all accept_prob's are zero, the x_center will have a nonsense value, - # but we'll discard the resultant gradients later on, so it's fine. - emp_x_mean = tf.stop_gradient( - distribute_lib.reduce_sum(expanded_accept_prob * x_safe, batch_axes, - reduce_chain_axis_names) / - (distribute_lib.reduce_sum(expanded_accept_prob, batch_axes, - reduce_chain_axis_names) + 1e-20)) - + emp_x_mean = _estimate_empirical_mean( + x, + accept_prob=accept_prob, + safe=forward, + reduce_chain_axis_names=reduce_chain_axis_names, + ) x_mean = _mix_in_state_mean(emp_x_mean, x_mean) return x - x_mean @@ -505,6 +533,13 @@ class GradientBasedTrajectoryLengthAdaptation(kernel_base.TransitionKernel): value during development in order to inspect the behavior of the chain during adaptation. + Optionally, it is possible to use the improved gradient estimator from [3] by + setting `use_reverse_estimator` to `True`. This estimator relies on the + reversibility of HMC proposal to reduce variance and thus improve the + adaptation speed and reliability. If this is set to `true`, `criterion_fn` + needs to also take the `forward` argument to distinguish the implied + integration direction. + #### Examples This implements something similar to ChEES HMC from [2]. @@ -535,7 +570,9 @@ class GradientBasedTrajectoryLengthAdaptation(kernel_base.TransitionKernel): kernel, num_adaptation_steps=num_adaptation_steps) kernel = tfp.mcmc.DualAveragingStepSizeAdaptation( - kernel, num_adaptation_steps=num_adaptation_steps) + kernel, + num_adaptation_steps=num_adaptation_steps, + reduce_fn=tfp.math.reduce_log_harmonic_mean_exp) kernel = tfp.mcmc.TransformedTransitionKernel( kernel, [tfb.Identity(), @@ -560,7 +597,8 @@ def trace_fn(_, pkr): kernel=kernel, trace_fn=trace_fn,)) - # ~0.75 + # ~0.95, because Exp bijector is really bad for HalfNormal. Use Softplus in + # practice. accept_prob = tf.math.exp(tfp.math.reduce_logmeanexp( tf.minimum(log_accept_ratio, 0.))) ``` @@ -574,6 +612,10 @@ def trace_fn(_, pkr): for Setting Trajectory Lengths in Hamiltonian Monte Carlo. + [3]: Riou-Durand, L., Sountsov, P., Vogrinc, J., Margossian, C., Power, S. + (2023) Adaptive Tuning for Metropolis Adjusted Langevin Trajectories. + + """ def __init__( @@ -589,9 +631,11 @@ def __init__( num_leapfrog_steps_getter_fn=hmc_like_num_leapfrog_steps_getter_fn, num_leapfrog_steps_setter_fn=hmc_like_num_leapfrog_steps_setter_fn, step_size_getter_fn=hmc_like_step_size_getter_fn, + initial_velocity_getter_fn=hmc_like_initial_velocity_getter_fn, proposed_velocity_getter_fn=hmc_like_proposed_velocity_getter_fn, log_accept_prob_getter_fn=hmc_like_log_accept_prob_getter_fn, proposed_state_getter_fn=hmc_like_proposed_state_getter_fn, + use_reverse_estimator=False, validate_args=False, experimental_shard_axis_names=None, experimental_reduce_chain_axis_names=None, @@ -636,11 +680,16 @@ def __init__( step_size_getter_fn: A callable with the signature `(kernel_results) -> step_size` where `kernel_results` are the results of the `inner_kernel`, and `step_size` is a floating point `Tensor`. + initial_velocity_getter_fn: A callable with the signature + `(kernel_results) -> initial_velocity` where `kernel_results` are the + results of the `inner_kernel`, and `initial_velocity` is a (possibly + nested) floating point `Tensor`. Velocity is the derivative of state + with respect to trajectory length. proposed_velocity_getter_fn: A callable with the signature `(kernel_results) -> proposed_velocity` where `kernel_results` are the results of the `inner_kernel`, and `proposed_velocity` is a (possibly - nested) floating point `Tensor`. Velocity is derivative of state with - respect to trajectory length. + nested) floating point `Tensor`. Velocity is the derivative of state + with respect to trajectory length. log_accept_prob_getter_fn: A callable with the signature `(kernel_results) -> log_accept_prob` where `kernel_results` are the results of the `inner_kernel`, and `log_accept_prob` is a floating point `Tensor`. @@ -649,6 +698,9 @@ def __init__( -> proposed_state` where `kernel_results` are the results of the `inner_kernel`, and `proposed_state` is a (possibly nested) floating point `Tensor`. + use_reverse_estimator: Whether to use an improved estimator to compute + trajectory length gradients. If `True`, `criterion_fn` needs to take a + `forward` kwarg. validate_args: Python `bool`. When `True` kernel parameters are checked for validity. When `False` invalid inputs may silently render incorrect outputs. @@ -690,9 +742,11 @@ class docstring). num_leapfrog_steps_getter_fn=num_leapfrog_steps_getter_fn, num_leapfrog_steps_setter_fn=num_leapfrog_steps_setter_fn, step_size_getter_fn=step_size_getter_fn, + initial_velocity_getter_fn=initial_velocity_getter_fn, proposed_velocity_getter_fn=proposed_velocity_getter_fn, log_accept_prob_getter_fn=log_accept_prob_getter_fn, proposed_state_getter_fn=hmc_like_proposed_state_getter_fn, + use_reverse_estimator=use_reverse_estimator, validate_args=validate_args, experimental_shard_axis_names=experimental_shard_axis_names, experimental_reduce_chain_axis_names=experimental_reduce_chain_axis_names, @@ -712,7 +766,7 @@ def num_adaptation_steps(self): return self._parameters['num_adaptation_steps'] def criterion_fn(self, previous_state, proposed_state, accept_prob, - trajectory_length): + trajectory_length, forward=True): kwargs = {} if self.experimental_reduce_chain_axis_names is not None: kwargs['experimental_reduce_chain_axis_names'] = ( @@ -720,6 +774,8 @@ def criterion_fn(self, previous_state, proposed_state, accept_prob, if self.experimental_shard_axis_names is not None: kwargs['experimental_shard_axis_names'] = ( self.experimental_shard_axis_names) + if self.use_reverse_estimator: + kwargs['forward'] = forward return self._parameters['criterion_fn'](previous_state, proposed_state, accept_prob, trajectory_length, **kwargs) @@ -743,6 +799,9 @@ def num_leapfrog_steps_setter_fn(self, kernel_results, def step_size_getter_fn(self, kernel_results): return self._parameters['step_size_getter_fn'](kernel_results) + def initial_velocity_getter_fn(self, kernel_results): + return self._parameters['initial_velocity_getter_fn'](kernel_results) + def proposed_velocity_getter_fn(self, kernel_results): return self._parameters['proposed_velocity_getter_fn'](kernel_results) @@ -752,6 +811,10 @@ def log_accept_prob_getter_fn(self, kernel_results): def proposed_state_getter_fn(self, kernel_results): return self._parameters['proposed_state_getter_fn'](kernel_results) + @property + def use_reverse_estimator(self): + return self._parameters['use_reverse_estimator'] + @property def validate_args(self): return self._parameters['validate_args'] @@ -806,6 +869,7 @@ def one_step(self, current_state, previous_kernel_results, seed=None): current_state, previous_kernel_results_with_jitter.inner_results, inner_seed) + initial_velocity = self.initial_velocity_getter_fn(new_inner_results) proposed_state = self.proposed_state_getter_fn(new_inner_results) proposed_velocity = self.proposed_velocity_getter_fn(new_inner_results) accept_prob = tf.exp(self.log_accept_prob_getter_fn(new_inner_results)) @@ -815,13 +879,15 @@ def one_step(self, current_state, previous_kernel_results, seed=None): previous_state=current_state, proposed_state=proposed_state, proposed_velocity=proposed_velocity, + initial_velocity=initial_velocity, trajectory_jitter=trajectory_jitter, accept_prob=accept_prob, step_size=step_size, criterion_fn=self.criterion_fn, max_leapfrog_steps=self.max_leapfrog_steps, experimental_shard_axis_names=self.experimental_shard_axis_names, - reduce_chain_axis_names=self.experimental_reduce_chain_axis_names) + reduce_chain_axis_names=self.experimental_reduce_chain_axis_names, + use_reverse_estimator=self.use_reverse_estimator) # Undo the effect of adaptation if we're not in the burnin phase. We keep # the criterion, however, as that's a diagnostic. We also keep the @@ -930,6 +996,7 @@ def _halton_sequence(i, max_bits=MAX_HALTON_SEQUENCE_BITS): def _update_trajectory_grad(previous_kernel_results, previous_state, + initial_velocity, proposed_state, proposed_velocity, trajectory_jitter, @@ -937,35 +1004,54 @@ def _update_trajectory_grad(previous_kernel_results, step_size, criterion_fn, max_leapfrog_steps, + use_reverse_estimator, experimental_shard_axis_names=None, reduce_chain_axis_names=None): """Updates the trajectory length.""" # Compute criterion grads. def leapfrog_action(dt): - # This represents the effect on the criterion value as the state follows the - # proposed velocity. This implicitly assumes an identity mass matrix. + fwd_start_end_vel = [ + (True, previous_state, proposed_state, proposed_velocity) + ] + if use_reverse_estimator: + fwd_start_end_vel.append(( + False, + proposed_state, + previous_state, + tf.nest.map_structure(lambda x: -x, initial_velocity), + )) + + # This represents the effect on the criterion value as the state follows + # the proposed velocity. This implicitly assumes an identity mass matrix. def adjust_state(x, v, shard_axes=None): broadcasted_dt = distribute_lib.pbroadcast( bu.left_justified_expand_dims_like(dt, v), shard_axes) return x + broadcasted_dt * v - adjusted_state = _map_structure_up_to_with_axes( - proposed_state, - adjust_state, - proposed_state, - proposed_velocity, - experimental_shard_axis_names=experimental_shard_axis_names) - return criterion_fn( - previous_state=previous_state, - proposed_state=adjusted_state, - accept_prob=accept_prob, - # We add the step size here because we effectively do `floor(traj + - # step_size) / step_size` when computing the number of leapfrog steps. - trajectory_length=( - trajectory_jitter * previous_kernel_results.max_trajectory_length + - step_size + dt), - ) + criterion_vals = [] + for forward, start, end, vel in fwd_start_end_vel: + adjusted_end = _map_structure_up_to_with_axes( + end, + adjust_state, + end, + vel, + experimental_shard_axis_names=experimental_shard_axis_names) + criterion_val = criterion_fn( + previous_state=start, + proposed_state=adjusted_end, + accept_prob=accept_prob, + # We add the step size here because we effectively do `floor(traj + + # step_size) / step_size` when computing the number of leapfrog steps. + trajectory_length=( + trajectory_jitter * previous_kernel_results.max_trajectory_length + + step_size + + dt + ), + forward=forward, + ) + criterion_vals.append(criterion_val) + return tf.reduce_mean(criterion_vals, axis=0) criterion, trajectory_grad = gradient.value_and_gradient( leapfrog_action, tf.zeros_like(accept_prob)) @@ -999,8 +1085,9 @@ def adjust_state(x, v, shard_axes=None): # Apply the gradient. Clip absolute value to ~log(2)/2. log_update = tf.clip_by_value(trajectory_step_size * trajectory_grad, -0.35, 0.35) - new_max_trajectory_length = previous_kernel_results.max_trajectory_length * tf.exp( - log_update) + new_max_trajectory_length = ( + previous_kernel_results.max_trajectory_length * tf.exp(log_update) + ) # Iterate averaging. average_weight = iteration_f**(-0.5) diff --git a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py index 49054d6bf0..768c516084 100644 --- a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py +++ b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py @@ -159,9 +159,13 @@ def target_log_prob_fn(*x): num_leapfrog_steps=1, ) kernel = gbtla.GradientBasedTrajectoryLengthAdaptation( - kernel, num_adaptation_steps=num_adaptation_steps, validate_args=True) + kernel, num_adaptation_steps=num_adaptation_steps, validate_args=True, + use_reverse_estimator=True) kernel = dassa.DualAveragingStepSizeAdaptation( - kernel, num_adaptation_steps=num_adaptation_steps) + kernel, + num_adaptation_steps=num_adaptation_steps, + reduce_fn=generic.reduce_log_harmonic_mean_exp, + ) kernel = transformed_kernel.TransformedTransitionKernel( kernel, [identity.Identity(), exp.Exp()]) @@ -192,9 +196,9 @@ def trace_fn(_, pkr): mean_step_size = tf.reduce_mean(step_size) mean_max_trajectory_length = tf.reduce_mean(max_trajectory_length) - self.assertAllClose(0.75, p_accept, atol=0.1) - self.assertAllClose(0.52, mean_step_size, atol=0.2) - self.assertAllClose(46., mean_max_trajectory_length, atol=15) + self.assertAllClose(0.95, p_accept, rtol=0.2) + self.assertAllClose(0.3, mean_step_size, rtol=0.2) + self.assertAllClose(43., mean_max_trajectory_length, rtol=0.2) self.assertAllClose( target.mean(), [tf.reduce_mean(x, axis=[0, 1]) for x in chain], atol=1.5) @@ -328,10 +332,12 @@ def target_log_prob_fn(x, y): final_kernel_results.max_trajectory_length), 0.0005) @parameterized.named_parameters( - ('ChEES', gbtla.chees_rate_criterion), - ('SNAPER', snaper_criterion_2d_direction), + ('ChEES', gbtla.chees_rate_criterion, False), + ('SNAPER', snaper_criterion_2d_direction, False), + ('ChEES_reverse', gbtla.chees_rate_criterion, True), + ('SNAPER_reverse', snaper_criterion_2d_direction, True), ) - def testAdaptation(self, criterion_fn): + def testAdaptation(self, criterion_fn, use_reverse_estimator): if tf.executing_eagerly() and not JAX_MODE: self.skipTest('Too slow for TF Eager.') @@ -353,6 +359,7 @@ def testAdaptation(self, criterion_fn): kernel, num_adaptation_steps=num_adaptation_steps, criterion_fn=criterion_fn, + use_reverse_estimator=use_reverse_estimator, validate_args=True) kernel = dassa.DualAveragingStepSizeAdaptation( kernel, num_adaptation_steps=num_adaptation_steps) diff --git a/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py b/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py index 063e43cc83..d08d3f9014 100644 --- a/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py +++ b/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py @@ -406,6 +406,7 @@ def _max_part(x, named_axis): gbtla_kwargs = ( self.gradient_based_trajectory_length_adaptation_kwargs.copy()) gbtla_kwargs.setdefault('averaged_sq_grad_adaptation_rate', 0.5) + gbtla_kwargs.setdefault('use_reverse_estimator', True) kernel = gbtla.GradientBasedTrajectoryLengthAdaptation( kernel, num_adaptation_steps=self.num_adaptation_steps,