diff --git a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py
index 6a297813b2..6f65e2ae5a 100644
--- a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py
+++ b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py
@@ -834,7 +834,8 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
     new_kernel_results = new_kernel_results._replace(
         inner_results=new_inner_results,
         step=previous_kernel_results.step + 1,
-        criterion=criterion)
+        criterion=criterion,
+        seed=seed)
 
     return new_state, new_kernel_results
 
diff --git a/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py b/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py
index f9e9ff36a5..49080e4f4a 100644
--- a/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py
+++ b/tensorflow_probability/python/experimental/mcmc/snaper_hmc.py
@@ -67,6 +67,7 @@ class SNAPERHamiltonianMonteCarloResults(
         'inner_results',
         'ema_mean',
         'ema_variance',
+        'max_ema_variance',
         'state_ema_points',
         'ema_principal_component',
         'principal_component_ema_points',
@@ -80,6 +81,7 @@ class SNAPERHamiltonianMonteCarloResults(
       `GradientBasedTrajectoryLengthAdaptationResults`.
     ema_mean: Exponential moving average cross-chain state mean.
     ema_variance: Exponential moving average cross-chain state variance.
+    max_ema_variance: Maximum of `ema_variance`.
     state_ema_points: Approximate number of points used to compute the
       exponential moving averages.
     ema_principal_component: Exponential moving average cross-chain state
@@ -422,7 +424,7 @@ def _max_part(x, named_axis):
         validate_args=self.validate_args,
         **gbtla_kwargs,
     )
-    return kernel
+    return kernel, max_variance
 
   def _update_state_ema(
       self,
@@ -539,7 +541,7 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
     step = inner_results.step
     state_ema_points = previous_kernel_results.state_ema_points
 
-    kernel = self._make_kernel(
+    kernel, max_variance = self._make_kernel(
        batch_shape=batch_shape,
        step=step,
        state_ema_points=state_ema_points,
@@ -588,6 +590,7 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
         inner_results=inner_results,
         ema_mean=ema_mean,
         ema_variance=ema_variance,
+        max_ema_variance=max_variance,
         state_ema_points=state_ema_points,
         ema_principal_component=ema_principal_component,
         principal_component_ema_points=principal_component_ema_points,
@@ -659,7 +662,7 @@ def bootstrap_results(self, init_state):
     state_ema_points = tf.ones([], tf.int32)
     principal_component_ema_points = tf.ones([], tf.int32)
 
-    kernel = self._make_kernel(
+    kernel, max_variance = self._make_kernel(
        batch_shape=batch_shape,
        step=tf.zeros([], tf.int32),
        state_ema_points=state_ema_points,
@@ -675,6 +678,7 @@ def bootstrap_results(self, init_state):
         inner_results=inner_results,
         ema_mean=ema_mean,
         ema_variance=ema_variance,
+        max_ema_variance=max_variance,
         state_ema_points=state_ema_points,
         ema_principal_component=ema_principal_component,
         principal_component_ema_points=principal_component_ema_points,
@@ -1009,23 +1013,27 @@ def default_snaper_trace_fn(state, is_burnin, kernel_results, reducer,
   # The ~ is here to catch NaNs.
   has_divergence = ~(tf.math.abs(energy_diff) < 500.)
   return state, {
-      'step_size':
-          unnest.get_innermost(kr, 'step_size'),
-      'n_steps':
-          unnest.get_innermost(kr, 'num_leapfrog_steps'),
-      'tune':
-          is_burnin,
-      'max_trajectory_length':
-          unnest.get_innermost(kr, 'max_trajectory_length'),
-      'variance_scaling':
-          tf.nest.map_structure(lambda x: 1. / x,
-                                unnest.get_innermost(kr, 'ema_variance')),
-      'diverging':
-          has_divergence,
-      'accept_ratio':
-          tf.minimum(tf.ones_like(energy_diff), tf.exp(energy_diff)),
-      'is_accepted':
-          unnest.get_innermost(kr, 'is_accepted'),
+      # SNAPER rescales the inner HMC kernel by max_ema_variance, so to aid
+      # comparisons with other algorithms which typically don't do this
+      # rescaling, we undo the rescaling here. This makes the step size
+      # consistent with the target_log_prob_fn scale implied by
+      # `variance_scaling` below.
+      'step_size': unnest.get_innermost(kr, 'step_size') / tf.sqrt(
+          unnest.get_innermost(kr, 'max_ema_variance')
+      ),
+      'n_steps': unnest.get_innermost(kr, 'num_leapfrog_steps'),
+      'tune': is_burnin,
+      'max_trajectory_length': unnest.get_innermost(
+          kr, 'max_trajectory_length'
+      ),
+      'variance_scaling': tf.nest.map_structure(
+          lambda x: 1.0 / x, unnest.get_innermost(kr, 'ema_variance')
+      ),
+      'diverging': has_divergence,
+      'accept_ratio': tf.minimum(
+          tf.ones_like(energy_diff), tf.exp(energy_diff)
+      ),
+      'is_accepted': unnest.get_innermost(kr, 'is_accepted'),
   }
 
diff --git a/tensorflow_probability/python/experimental/mcmc/snaper_hmc_test.py b/tensorflow_probability/python/experimental/mcmc/snaper_hmc_test.py
index e8a3c8dcba..1d367ebe37 100644
--- a/tensorflow_probability/python/experimental/mcmc/snaper_hmc_test.py
+++ b/tensorflow_probability/python/experimental/mcmc/snaper_hmc_test.py
@@ -81,7 +81,7 @@ def testEndToEndAdaptation(self):
     num_mala_steps = 100
 
     eigenvalues = np.exp(np.linspace(0., 3., num_dims))
-    q, r = np.linalg.qr(np.random.randn(num_dims, num_dims))
+    q, r = np.linalg.qr(np.random.RandomState(0).randn(num_dims, num_dims))
     q *= np.sign(np.diag(r))
 
     covariance = (q * eigenvalues).dot(q.T).astype(self.dtype)
@@ -100,20 +100,24 @@ def testEndToEndAdaptation(self):
         num_mala_steps=num_mala_steps,
     )
     kernel = dassa.DualAveragingStepSizeAdaptation(
-        kernel, num_adaptation_steps=num_adaptation_steps)
+        kernel,
+        num_adaptation_steps=num_adaptation_steps,
+        target_accept_prob=0.8,
+    )
 
     def trace_fn(_, pkr):
       return {
-          'step_size':
-              unnest.get_innermost(pkr, 'step_size'),
-          'mean_trajectory_length':
-              unnest.get_innermost(pkr, 'max_trajectory_length') / 2.,
-          'principal_component':
-              unnest.get_innermost(pkr, 'ema_principal_component'),
-          'variance':
-              unnest.get_innermost(pkr, 'ema_variance'),
-          'num_leapfrog_steps':
-              unnest.get_innermost(pkr, 'num_leapfrog_steps'),
+          'step_size': unnest.get_innermost(pkr, 'step_size') / tf.sqrt(
+              unnest.get_innermost(pkr, 'max_ema_variance')
+          ),
+          'mean_trajectory_length': (
+              unnest.get_innermost(pkr, 'max_trajectory_length') / 2.0
+          ),
+          'principal_component': unnest.get_innermost(
+              pkr, 'ema_principal_component'
+          ),
+          'variance': unnest.get_innermost(pkr, 'ema_variance'),
+          'num_leapfrog_steps': unnest.get_innermost(pkr, 'num_leapfrog_steps'),
       }
 
     init_x = tf.zeros([num_chains, num_dims], self.dtype)
@@ -137,7 +141,8 @@ def trace_fn(_, pkr):
     self.assertEqual(self.dtype, trace['principal_component'].dtype)
 
     # Adaptation results.
-    self.assertAllClose(1.75, trace['step_size'][-1], rtol=0.2)
+    # Obtained via a separate run of `windowed_adaptive_nuts`.
+    self.assertAllClose(0.45, trace['step_size'][-1], rtol=0.25)
     self.assertAllClose(4., trace['mean_trajectory_length'][-1], atol=1.)
     self.assertAllClose(np.diag(covariance), trace['variance'][-1], rtol=0.2)
     self.assertAllClose(
@@ -280,7 +285,7 @@ def testEndToEnd(self):
     num_dims = 8
 
     eigenvalues = np.exp(np.linspace(0., 3., num_dims))
-    q, r = np.linalg.qr(np.random.randn(num_dims, num_dims))
+    q, r = np.linalg.qr(np.random.RandomState(0).randn(num_dims, num_dims))
     q *= np.sign(np.diag(r))
 
     covariance = (q * eigenvalues).dot(q.T).astype(self.dtype)
@@ -305,7 +310,8 @@ def run(seed):
         run(test_util.test_seed(sampler_type='stateless')))
 
     self.assertEqual(self.dtype, chain.dtype)
-    self.assertAllClose(1.4, trace['step_size'][-1], rtol=0.2)
+    # Obtained via a separate run of `windowed_adaptive_nuts`.
+    self.assertAllClose(0.45, trace['step_size'][-1], rtol=0.25)
     self.assertAllClose(8., trace['max_trajectory_length'][-1], atol=2.)
     self.assertAllClose(chain.var((0, 1)), np.diag(covariance), rtol=0.2)
     self.assertAllClose(
@@ -518,7 +524,7 @@ def testShardedChainAxes(self):
     num_dims = 8
 
     eigenvalues = np.exp(np.linspace(0., 3., num_dims))
-    q, r = np.linalg.qr(np.random.randn(num_dims, num_dims))
+    q, r = np.linalg.qr(np.random.RandomState(0).randn(num_dims, num_dims))
     q *= np.sign(np.diag(r))
 
     covariance = (q * eigenvalues).dot(q.T).astype(np.float32)
@@ -549,7 +555,8 @@ def run(_):
         )))
 
     # Adaptation results.
-    self.assertAllClose(1.4, trace['step_size'][0, -1], rtol=0.2)
+    # Obtained via a separate run of `windowed_adaptive_nuts`.
+    self.assertAllClose(0.45, trace['step_size'][0, -1], rtol=0.25)
     self.assertAllClose(chain.var((0, 1, 2)), np.diag(covariance), rtol=0.2)
 
     # Shard consistency.
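Note on the step-size convention traced above: per the comment added to `default_snaper_trace_fn`, SNAPER rescales the inner HMC kernel by `max_ema_variance`, so the raw `step_size` in the kernel results lives in the rescaled space, and dividing by `sqrt(max_ema_variance)` reports it on the original `target_log_prob_fn` scale. Below is a minimal sketch of a standalone trace function applying the same transformation; the name `unscaled_step_size_trace_fn` is hypothetical, and only `unnest.get_innermost`, `tf.sqrt`, and the new `max_ema_variance` results field from this change are assumed.

import tensorflow as tf
from tensorflow_probability.python.internal import unnest

def unscaled_step_size_trace_fn(state, kernel_results):
  # SNAPER-HMC reports `step_size` relative to its `max_ema_variance`
  # rescaling of the inner kernel; dividing by sqrt(max_ema_variance) puts
  # the step size back on the unscaled target_log_prob_fn scale, matching
  # samplers that do not rescale.
  step_size = unnest.get_innermost(kernel_results, 'step_size')
  max_ema_variance = unnest.get_innermost(kernel_results, 'max_ema_variance')
  return {'step_size': step_size / tf.sqrt(max_ema_variance)}

A trace function with this `(state, kernel_results)` signature can be passed as `trace_fn` to `tfp.mcmc.sample_chain` in place of the default when only the unscaled step size is of interest.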