From c9a22d62da6b7e6d424f39502ca68b1f5a1d1903 Mon Sep 17 00:00:00 2001
From: siege
Date: Wed, 14 Feb 2024 12:09:39 -0800
Subject: [PATCH] Fix SNAPER step size reporting.

Internally, SNAPER absorbs part of the diagonal preconditioner into the
step size rather than placing it entirely in the mass matrix. This makes
it difficult to compare the final step sizes with those obtained via,
e.g., NUTS. This change undoes that scaling when constructing the
default trace.

Also fix GradientBasedTrajectoryLengthAdaptation not storing the seed in
its kernel results.

PiperOrigin-RevId: 607064555
---
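Note for reviewers: the relationship this change implements in the
default trace is sketched below. This is illustrative only; `kr` stands
for any kernel-results structure produced by SNAPER sampling, and the
helper name `reported_step_size` is invented for this note.

    import tensorflow as tf
    from tensorflow_probability.python.internal import unnest

    def reported_step_size(kr):
      # SNAPER folds sqrt(max(ema_variance)) of the diagonal
      # preconditioner into the internal step size; dividing it back out
      # yields a step size on the same scale as one adapted by, e.g.,
      # NUTS on the unpreconditioned target.
      return unnest.get_innermost(kr, 'step_size') / tf.sqrt(
          unnest.get_innermost(kr, 'max_ema_variance'))
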
 ...ient_based_trajectory_length_adaptation.py |  3 +-
 .../python/experimental/mcmc/snaper_hmc.py    | 48 +++++++++++--------
 .../experimental/mcmc/snaper_hmc_test.py      | 41 +++++++++-------
 3 files changed, 54 insertions(+), 38 deletions(-)

diff --git a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py
index 6a297813b2..6f65e2ae5a 100644
--- a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py
+++ b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py
@@ -834,7 +834,8 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
     new_kernel_results = new_kernel_results._replace(
         inner_results=new_inner_results,
         step=previous_kernel_results.step + 1,
-        criterion=criterion)
+        criterion=criterion,
+        seed=seed)

     return new_state, new_kernel_results

diff --git a/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py b/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py
index f9e9ff36a5..49080e4f4a 100644
--- a/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py
+++ b/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py
@@ -67,6 +67,7 @@ class SNAPERHamiltonianMonteCarloResults(
         'inner_results',
         'ema_mean',
         'ema_variance',
+        'max_ema_variance',
         'state_ema_points',
         'ema_principal_component',
         'principal_component_ema_points',
@@ -80,6 +81,7 @@ class SNAPERHamiltonianMonteCarloResults(
       `GradientBasedTrajectoryLengthAdaptationResults`.
     ema_mean: Exponential moving average cross-chain state mean.
     ema_variance: Exponential moving average cross-chain state variance.
+    max_ema_variance: Maximum of `ema_variance`.
     state_ema_points: Approximate number of points used to compute the
       exponential moving averages.
     ema_principal_component: Exponential moving average cross-chain state
@@ -422,7 +424,7 @@ def _max_part(x, named_axis):
         validate_args=self.validate_args,
         **gbtla_kwargs,
     )
-    return kernel
+    return kernel, max_variance

   def _update_state_ema(
       self,
@@ -539,7 +541,7 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
     step = inner_results.step
     state_ema_points = previous_kernel_results.state_ema_points

-    kernel = self._make_kernel(
+    kernel, max_variance = self._make_kernel(
         batch_shape=batch_shape,
         step=step,
         state_ema_points=state_ema_points,
@@ -588,6 +590,7 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
         inner_results=inner_results,
         ema_mean=ema_mean,
         ema_variance=ema_variance,
+        max_ema_variance=max_variance,
         state_ema_points=state_ema_points,
         ema_principal_component=ema_principal_component,
         principal_component_ema_points=principal_component_ema_points,
@@ -659,7 +662,7 @@ def bootstrap_results(self, init_state):
     state_ema_points = tf.ones([], tf.int32)
     principal_component_ema_points = tf.ones([], tf.int32)

-    kernel = self._make_kernel(
+    kernel, max_variance = self._make_kernel(
         batch_shape=batch_shape,
         step=tf.zeros([], tf.int32),
         state_ema_points=state_ema_points,
@@ -675,6 +678,7 @@ def bootstrap_results(self, init_state):
         inner_results=inner_results,
         ema_mean=ema_mean,
         ema_variance=ema_variance,
+        max_ema_variance=max_variance,
         state_ema_points=state_ema_points,
         ema_principal_component=ema_principal_component,
         principal_component_ema_points=principal_component_ema_points,
@@ -1009,23 +1013,27 @@ def default_snaper_trace_fn(state, is_burnin, kernel_results, reducer,
   # The ~ is here to catch NaNs.
   has_divergence = ~(tf.math.abs(energy_diff) < 500.)
   return state, {
-      'step_size':
-          unnest.get_innermost(kr, 'step_size'),
-      'n_steps':
-          unnest.get_innermost(kr, 'num_leapfrog_steps'),
-      'tune':
-          is_burnin,
-      'max_trajectory_length':
-          unnest.get_innermost(kr, 'max_trajectory_length'),
-      'variance_scaling':
-          tf.nest.map_structure(lambda x: 1. / x,
-                                unnest.get_innermost(kr, 'ema_variance')),
-      'diverging':
-          has_divergence,
-      'accept_ratio':
-          tf.minimum(tf.ones_like(energy_diff), tf.exp(energy_diff)),
-      'is_accepted':
-          unnest.get_innermost(kr, 'is_accepted'),
+      # SNAPER folds part of the diagonal preconditioner into the inner
+      # HMC kernel's step size via sqrt(max_ema_variance). To aid
+      # comparisons with algorithms that don't do this rescaling, we undo
+      # it here, making the step size consistent with the
+      # target_log_prob_fn scale implied by `variance_scaling` below.
+      'step_size': unnest.get_innermost(kr, 'step_size') / tf.sqrt(
+          unnest.get_innermost(kr, 'max_ema_variance')
+      ),
+      'n_steps': unnest.get_innermost(kr, 'num_leapfrog_steps'),
+      'tune': is_burnin,
+      'max_trajectory_length': unnest.get_innermost(
+          kr, 'max_trajectory_length'
+      ),
+      'variance_scaling': tf.nest.map_structure(
+          lambda x: 1.0 / x, unnest.get_innermost(kr, 'ema_variance')
+      ),
+      'diverging': has_divergence,
+      'accept_ratio': tf.minimum(
+          tf.ones_like(energy_diff), tf.exp(energy_diff)
+      ),
+      'is_accepted': unnest.get_innermost(kr, 'is_accepted'),
   }

diff --git a/tensorflow_probability/python/experimental/mcmc/snaper_hmc_test.py b/tensorflow_probability/python/experimental/mcmc/snaper_hmc_test.py
index e8a3c8dcba..1d367ebe37 100644
--- a/tensorflow_probability/python/experimental/mcmc/snaper_hmc_test.py
+++ b/tensorflow_probability/python/experimental/mcmc/snaper_hmc_test.py
@@ -81,7 +81,7 @@ def testEndToEndAdaptation(self):
     num_mala_steps = 100

     eigenvalues = np.exp(np.linspace(0., 3., num_dims))
-    q, r = np.linalg.qr(np.random.randn(num_dims, num_dims))
+    q, r = np.linalg.qr(np.random.RandomState(0).randn(num_dims, num_dims))
     q *= np.sign(np.diag(r))
     covariance = (q * eigenvalues).dot(q.T).astype(self.dtype)

@@ -100,20 +100,24 @@ def testEndToEndAdaptation(self):
         num_mala_steps=num_mala_steps,
     )
     kernel = dassa.DualAveragingStepSizeAdaptation(
-        kernel, num_adaptation_steps=num_adaptation_steps)
+        kernel,
+        num_adaptation_steps=num_adaptation_steps,
+        target_accept_prob=0.8,
+    )

     def trace_fn(_, pkr):
       return {
-          'step_size':
-              unnest.get_innermost(pkr, 'step_size'),
-          'mean_trajectory_length':
-              unnest.get_innermost(pkr, 'max_trajectory_length') / 2.,
-          'principal_component':
-              unnest.get_innermost(pkr, 'ema_principal_component'),
-          'variance':
-              unnest.get_innermost(pkr, 'ema_variance'),
-          'num_leapfrog_steps':
-              unnest.get_innermost(pkr, 'num_leapfrog_steps'),
+          'step_size': unnest.get_innermost(pkr, 'step_size') / tf.sqrt(
+              unnest.get_innermost(pkr, 'max_ema_variance')
+          ),
+          'mean_trajectory_length': (
+              unnest.get_innermost(pkr, 'max_trajectory_length') / 2.0
+          ),
+          'principal_component': unnest.get_innermost(
+              pkr, 'ema_principal_component'
+          ),
+          'variance': unnest.get_innermost(pkr, 'ema_variance'),
+          'num_leapfrog_steps': unnest.get_innermost(pkr, 'num_leapfrog_steps'),
       }

     init_x = tf.zeros([num_chains, num_dims], self.dtype)
@@ -137,7 +141,8 @@ def trace_fn(_, pkr):
     self.assertEqual(self.dtype, trace['principal_component'].dtype)

     # Adaptation results.
-    self.assertAllClose(1.75, trace['step_size'][-1], rtol=0.2)
+    # Obtained via a separate run of `windowed_adaptive_nuts`.
+    self.assertAllClose(0.45, trace['step_size'][-1], rtol=0.25)
     self.assertAllClose(4., trace['mean_trajectory_length'][-1], atol=1.)
     self.assertAllClose(np.diag(covariance), trace['variance'][-1], rtol=0.2)
     self.assertAllClose(
@@ -280,7 +285,7 @@ def testEndToEnd(self):
     num_dims = 8

     eigenvalues = np.exp(np.linspace(0., 3., num_dims))
-    q, r = np.linalg.qr(np.random.randn(num_dims, num_dims))
+    q, r = np.linalg.qr(np.random.RandomState(0).randn(num_dims, num_dims))
     q *= np.sign(np.diag(r))
     covariance = (q * eigenvalues).dot(q.T).astype(self.dtype)

@@ -305,7 +310,8 @@ def run(seed):
         run(test_util.test_seed(sampler_type='stateless')))

     self.assertEqual(self.dtype, chain.dtype)
-    self.assertAllClose(1.4, trace['step_size'][-1], rtol=0.2)
+    # Obtained via a separate run of `windowed_adaptive_nuts`.
+    self.assertAllClose(0.45, trace['step_size'][-1], rtol=0.25)
     self.assertAllClose(8., trace['max_trajectory_length'][-1], atol=2.)
     self.assertAllClose(chain.var((0, 1)), np.diag(covariance), rtol=0.2)
     self.assertAllClose(
@@ -518,7 +524,7 @@ def testShardedChainAxes(self):
     num_dims = 8

     eigenvalues = np.exp(np.linspace(0., 3., num_dims))
-    q, r = np.linalg.qr(np.random.randn(num_dims, num_dims))
+    q, r = np.linalg.qr(np.random.RandomState(0).randn(num_dims, num_dims))
     q *= np.sign(np.diag(r))
     covariance = (q * eigenvalues).dot(q.T).astype(np.float32)

@@ -549,7 +555,8 @@ def run(_):
         )))

     # Adaptation results.
-    self.assertAllClose(1.4, trace['step_size'][0, -1], rtol=0.2)
+    # Obtained via a separate run of `windowed_adaptive_nuts`.
+    self.assertAllClose(0.45, trace['step_size'][0, -1], rtol=0.25)
     self.assertAllClose(chain.var((0, 1, 2)), np.diag(covariance), rtol=0.2)

     # Shard consistency.
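--
Reviewer note: the tests describe the new expected step size of ~0.45 as
coming from a separate `windowed_adaptive_nuts` run. A rough sketch of
such a cross-check is below; the chain count, draw and adaptation
budgets, seed, and the single-component JointDistributionSequential
wrapper are all assumptions, not taken from this patch.

    import numpy as np
    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    # Same ill-conditioned Gaussian target as the tests.
    num_dims = 8
    eigenvalues = np.exp(np.linspace(0., 3., num_dims))
    q, r = np.linalg.qr(np.random.RandomState(0).randn(num_dims, num_dims))
    q *= np.sign(np.diag(r))
    covariance = (q * eigenvalues).dot(q.T).astype(np.float32)

    target = tfd.JointDistributionSequential([
        tfd.MultivariateNormalTriL(
            loc=tf.zeros(num_dims),
            scale_tril=tf.linalg.cholesky(covariance)),
    ])

    # Assumed invocation; the default trace of windowed_adaptive_nuts
    # includes 'step_size', which is directly comparable to SNAPER's
    # newly unscaled value.
    _, trace = tfp.experimental.mcmc.windowed_adaptive_nuts(
        500, target, n_chains=64, num_adaptation_steps=500, seed=(1, 2))
    final_step_size = trace['step_size'][-1]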