diff --git a/spinoffs/oryx/oryx/version.py b/spinoffs/oryx/oryx/version.py
index 5bfa72f870..d2e775b301 100644
--- a/spinoffs/oryx/oryx/version.py
+++ b/spinoffs/oryx/oryx/version.py
@@ -17,7 +17,7 @@
 # We follow Semantic Versioning (https://semver.org/)
 _MAJOR_VERSION = '0'
 _MINOR_VERSION = '1'
-_PATCH_VERSION = '2'
+_PATCH_VERSION = '3'

 # When building releases, we can update this value on the release branch to
 # reflect the current release candidate ('rc0', 'rc1') or, finally, the official
diff --git a/spinoffs/oryx/setup.py b/spinoffs/oryx/setup.py
index b4f8124408..ab51d04ae6 100644
--- a/spinoffs/oryx/setup.py
+++ b/spinoffs/oryx/setup.py
@@ -20,11 +20,10 @@
 REQUIRED_PACKAGES = [
     'dataclasses;python_version<"3.7"',
-    'jax==0.2.0',
-    'jaxlib==0.1.55',
+    'jax==0.2.5',
+    'jaxlib==0.1.56',
     # Pin a TF version while TFP-on-JAX still depends on TF
-    'tfp-nightly==0.12.0.dev20200923',
-    'inference_gym',
+    'tfp-nightly==0.12.0.dev20201107',
 ]
diff --git a/tensorflow_probability/python/bijectors/BUILD b/tensorflow_probability/python/bijectors/BUILD
index 8259e3d66b..2ecf6e472b 100644
--- a/tensorflow_probability/python/bijectors/BUILD
+++ b/tensorflow_probability/python/bijectors/BUILD
@@ -130,6 +130,7 @@ multi_substrate_py_library(
         "//tensorflow_probability/python/internal:name_util",
         "//tensorflow_probability/python/internal:nest_util",
         "//tensorflow_probability/python/internal:tensorshape_util",
+        "//tensorflow_probability/python/math:gradient",
     ],
 )
diff --git a/tensorflow_probability/python/bijectors/bijector.py b/tensorflow_probability/python/bijectors/bijector.py
index a6a0a65b68..4be79019bc 100644
--- a/tensorflow_probability/python/bijectors/bijector.py
+++ b/tensorflow_probability/python/bijectors/bijector.py
@@ -32,6 +32,7 @@
 from tensorflow_probability.python.internal import name_util
 from tensorflow_probability.python.internal import nest_util
 from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow_probability.python.math import gradient
 from tensorflow.python.util import nest  # pylint: disable=g-direct-tensorflow-import
@@ -604,6 +605,11 @@ def _is_injective(self):
     """
     return True

+  @property
+  def _is_scalar(self):
+    return (tf.get_static_value(self._forward_min_event_ndims) == 0 and
+            tf.get_static_value(self._inverse_min_event_ndims) == 0)
+
   @property
   def validate_args(self):
     """Returns True if Tensor arguments will be validated."""
@@ -1033,6 +1039,8 @@ def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
       elif hasattr(self, '_forward_log_det_jacobian'):
         x = self.inverse(y, **kwargs)  # Fall back to computing `-fldj(x)`
         ildj = attrs['ildj'] = -self._forward_log_det_jacobian(x, **kwargs)
+      elif self._is_scalar:
+        ildj = _autodiff_log_det_jacobian(self._inverse, y)
       else:
         raise NotImplementedError(
             'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian '
@@ -1136,6 +1144,8 @@ def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
       elif hasattr(self, '_inverse_log_det_jacobian'):
         y = self.forward(x, **kwargs)  # Fall back to computing `ildj(y)`
         ildj = attrs['ildj'] = self._inverse_log_det_jacobian(y, **kwargs)
+      elif self._is_scalar:
+        ildj = -_autodiff_log_det_jacobian(self._forward, x)
       else:
         raise NotImplementedError(
             'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian '
@@ -1670,3 +1680,12 @@ def ldj_reduction_shape(shape_structure,
                   'LDJ reduction shape.')))

   return ldj_reduce_shape, assertions
+
+
+def _autodiff_log_det_jacobian(fn, x):
+  """Automatically compute the log det jacobian of a scalar function."""
+  _, grads = gradient.value_and_gradient(fn, x)
+  if grads is None:
+    raise ValueError('Cannot compute log det jacobian; function {} has `None` '
+                     'gradient.'.format(fn))
+  return tf.math.log(tf.abs(grads))
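Note on the bijector.py change: a scalar bijector (both `forward_min_event_ndims` and `inverse_min_event_ndims` statically 0) that defines neither `_forward_log_det_jacobian` nor `_inverse_log_det_jacobian` now gets its Jacobian from autodiff, since for a scalar map `log|det J| = log|f'(x)|`. Because the fallback invokes `_inverse` (or `_forward`) instead of raising immediately, a bijector that implements neither method now surfaces `'inverse not implemented'` rather than the old `'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian...'` message, which is why the expected regexes in `bijector_test.py` change below. A minimal sketch of the new behavior (the `Cube` bijector here is illustrative, not part of this patch):

```python
import tensorflow as tf
import tensorflow_probability as tfp


class Cube(tfp.bijectors.Bijector):
  """y = x**3, defining no log-det-jacobian methods of its own."""

  def __init__(self):
    super(Cube, self).__init__(forward_min_event_ndims=0)

  def _forward(self, x):
    return x**3.

  def _inverse(self, y):
    return tf.sign(y) * tf.abs(y)**(1. / 3.)


b = Cube()
x = tf.constant([1., 2.])
# Previously raised NotImplementedError; now computed by autodiff as
# log|f'(x)| = log(3 * x**2), i.e. [log(3), log(12)].
fldj = b.forward_log_det_jacobian(x, event_ndims=0)
```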
diff --git a/tensorflow_probability/python/bijectors/bijector_test.py b/tensorflow_probability/python/bijectors/bijector_test.py
index 3f098dbd5e..ab48289fef 100644
--- a/tensorflow_probability/python/bijectors/bijector_test.py
+++ b/tensorflow_probability/python/bijectors/bijector_test.py
@@ -80,12 +80,12 @@ def __init__(self):

     with self.assertRaisesRegexp(
         NotImplementedError,
-        'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian.*'):
+        'inverse not implemented'):
       bij.inverse_log_det_jacobian(0, event_ndims=0)

     with self.assertRaisesRegexp(
         NotImplementedError,
-        'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian.*'):
+        'forward not implemented'):
       bij.forward_log_det_jacobian(0, event_ndims=0)

   @test_util.disable_test_for_backend(
@@ -124,6 +124,53 @@ def _forward(self, x):
         error_clazz, 'Tensor conversion requested dtype'):
       b64.forward(x32)

+  @test_util.numpy_disable_gradient_test
+  def testAutodiffLogDetJacobian(self):
+
+    class NoJacobianBijector(tfb.Bijector):
+      """Bijector with no log det jacobian methods."""
+
+      def __init__(self, scale=2.):
+        parameters = dict(locals())
+        self._scale = tensor_util.convert_nonref_to_tensor(scale)
+        super(NoJacobianBijector, self).__init__(
+            validate_args=True,
+            forward_min_event_ndims=0,
+            parameters=parameters)
+
+      def _forward(self, x):
+        return tf.exp(self._scale * x)
+
+      def _inverse(self, y):
+        return tf.math.log(y) / self._scale
+
+    b = NoJacobianBijector(scale=1.4)
+    x = tf.convert_to_tensor([2., -3.])
+    [
+        fldj,
+        true_fldj,
+        ildj
+    ] = self.evaluate([
+        b.forward_log_det_jacobian(x, event_ndims=0),
+        tf.math.log(b._scale) + b._scale * x,
+        b.inverse_log_det_jacobian(b.forward(x), event_ndims=0)
+    ])
+    self.assertAllClose(fldj, true_fldj)
+    self.assertAllClose(fldj, -ildj)
+
+    y = tf.convert_to_tensor([27., 5.])
+    [
+        ildj,
+        true_ildj,
+        fldj
+    ] = self.evaluate([
+        b.inverse_log_det_jacobian(y, event_ndims=0),
+        -tf.math.log(tf.abs(y * b._scale)),
+        b.forward_log_det_jacobian(b.inverse(y), event_ndims=0)
+    ])
+    self.assertAllClose(ildj, true_ildj)
+    self.assertAllClose(ildj, -fldj)
+

 class IntentionallyMissingError(Exception):
   pass
diff --git a/tensorflow_probability/python/bijectors/glow.py b/tensorflow_probability/python/bijectors/glow.py
index 54aed0961b..9deaf17126 100644
--- a/tensorflow_probability/python/bijectors/glow.py
+++ b/tensorflow_probability/python/bijectors/glow.py
@@ -205,6 +205,8 @@ class Glow(chain.Chain):
   from functools import reduce
   from operator import mul

+  import tensorflow as tf
+  import tensorflow_datasets as tfds
   import tensorflow_probability as tfp
   tfb = tfp.bijectors
   tfd = tfp.distributions
diff --git a/tensorflow_probability/python/distributions/joint_distribution_coroutine.py b/tensorflow_probability/python/distributions/joint_distribution_coroutine.py
index cef887b10c..d5b442fd2c 100644
--- a/tensorflow_probability/python/distributions/joint_distribution_coroutine.py
+++ b/tensorflow_probability/python/distributions/joint_distribution_coroutine.py
@@ -219,8 +219,10 @@ def __init__(
     """
     parameters = dict(locals())
     with tf.name_scope(name or 'JointDistributionCoroutine') as name:
-      self._sample_dtype = sample_dtype
       self._model_coroutine = model
+      # Hint `no_dependency` to tell tf.Module not to screw up the sample dtype
+      # with extraneous wrapping (list => ListWrapper, etc.).
+      self._sample_dtype = self._no_dependency(sample_dtype)
       self._single_sample_distributions = {}
       super(JointDistributionCoroutine, self).__init__(
           dtype=sample_dtype,
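Note on the `joint_distribution_coroutine.py` fix: `tf.Module`'s automatic attribute tracking silently wraps mutable containers (a `list` becomes a `ListWrapper`), which breaks later structure comparisons against the user-supplied `sample_dtype`. The private `_no_dependency` helper used above opts out of that tracking. A standalone sketch of the behavior (class and attribute names are illustrative):

```python
import tensorflow as tf


class Holder(tf.Module):

  def __init__(self):
    self.tracked = [1, 2, 3]  # tf.Module wraps this in a ListWrapper.
    # `_no_dependency` opts out of tracking, preserving the plain list.
    self.untracked = self._no_dependency([1, 2, 3])


h = Holder()
print(type(h.tracked).__name__)    # ListWrapper
print(type(h.untracked).__name__)  # list
```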
diff --git a/tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py b/tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py
index 4aca2df9d8..5079ae8ebc 100644
--- a/tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py
+++ b/tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py
@@ -624,6 +624,7 @@ def noncentered_horseshoe_prior(num_features):
                   tfd.Sample(tfd.Normal(0., 1.), num_features))
       yield tfd.Independent(tfd.Deterministic(weights_noncentered * scale),
                             reinterpreted_batch_ndims=1)
+
     # Currently sample_dtype is only used for `tf.nest.pack_sequence_as`. In
     # the future we may use it for error checking and/or casting.
     sample_dtype = collections.namedtuple('Model', [
@@ -645,6 +646,18 @@ def noncentered_horseshoe_prior(num_features):
     self.assertEqual([3, 4], joint.log_prob(
         joint.sample([3, 4], seed=test_util.test_seed())).shape)

+    # Check that a list dtype doesn't get corrupted by `tf.Module` wrapping.
+    sample_dtype = [None, None, None, None]
+    joint = tfd.JointDistributionCoroutine(
+        lambda: noncentered_horseshoe_prior(4),
+        sample_dtype=sample_dtype,
+        validate_args=True)
+    ds, xs = joint.sample_distributions([2, 3], seed=test_util.test_seed())
+    self.assertEqual(type(sample_dtype), type(xs))
+    self.assertEqual(type(sample_dtype), type(ds))
+    tf.nest.assert_same_structure(sample_dtype, ds)
+    tf.nest.assert_same_structure(sample_dtype, xs)
+
   def test_repr_with_custom_sample_dtype(self):
     def model():
       s = yield tfd.JointDistributionCoroutine.Root(
diff --git a/tensorflow_probability/python/distributions/platform_compatibility_test.py b/tensorflow_probability/python/distributions/platform_compatibility_test.py
index 41cd5622d2..b38794cc5a 100644
--- a/tensorflow_probability/python/distributions/platform_compatibility_test.py
+++ b/tensorflow_probability/python/distributions/platform_compatibility_test.py
@@ -122,7 +122,7 @@
     'BetaBinomial': 1e-5,
     'CholeskyLKJ': 1e-4,
     'LKJ': 1e-3,
-    'PowerSpherical': 1e-5,
+    'PowerSpherical': 2e-5,
 })

 VECTORIZED_LOGPROB_RTOL = collections.defaultdict(lambda: 1e-6)
diff --git a/tensorflow_probability/python/distributions/sample.py b/tensorflow_probability/python/distributions/sample.py
index a916f08480..c1d8325034 100644
--- a/tensorflow_probability/python/distributions/sample.py
+++ b/tensorflow_probability/python/distributions/sample.py
@@ -327,9 +327,9 @@ def _parameter_control_dependencies(self, is_init):

     return assertions

-  _composite_tensor_nonshape_params = ('distribution,')
+  _composite_tensor_nonshape_params = ('distribution',)

-  _composite_tensor_shape_params = ('sample_shape,')
+  _composite_tensor_shape_params = ('sample_shape',)


class _DefaultSampleBijector(bijector_lib.Bijector):
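Note on the `sample.py` hunk: this fixes a classic Python slip, a comma inside the string rather than after it. `('distribution,')` is a parenthesized string, not a tuple, so code iterating the parameter names would walk over characters. For the record:

```python
not_a_tuple = ('distribution,')  # parentheses alone don't make a tuple
a_tuple = ('distribution',)      # the trailing comma does

print(type(not_a_tuple).__name__)            # str
print(list(not_a_tuple)[:3])                 # ['d', 'i', 's'] -- characters
print(type(a_tuple).__name__, len(a_tuple))  # tuple 1
```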
diff --git a/tensorflow_probability/python/experimental/mcmc/expectations_reducer_test.py b/tensorflow_probability/python/experimental/mcmc/expectations_reducer_test.py
index 6de891acd2..03330fe0fa 100644
--- a/tensorflow_probability/python/experimental/mcmc/expectations_reducer_test.py
+++ b/tensorflow_probability/python/experimental/mcmc/expectations_reducer_test.py
@@ -126,9 +126,9 @@ def test_in_with_reductions(self):
     )
     pkr = reduced_kernel.bootstrap_results(8)
     _, kernel_results = reduced_kernel.one_step(8, pkr)
-    streaming_calculations = self.evaluate(
-        mean_reducer.finalize(kernel_results.streaming_calculations))
-    self.assertEqual(9, streaming_calculations)
+    reduction_results = self.evaluate(
+        mean_reducer.finalize(kernel_results.reduction_results))
+    self.assertEqual(9, reduction_results)

   def test_in_step_kernel(self):
     fake_kernel = test_fixtures.TestTransitionKernel()
@@ -142,9 +142,9 @@ def test_in_step_kernel(self):
         kernel=reduced_kernel,
         return_final_kernel_results=True,
     )
-    streaming_calculations = self.evaluate(
-        mean_reducer.finalize(kernel_results.streaming_calculations))
-    self.assertEqual(11, streaming_calculations)
+    reduction_results = self.evaluate(
+        mean_reducer.finalize(kernel_results.reduction_results))
+    self.assertEqual(11, reduction_results)


if __name__ == '__main__':
diff --git a/tensorflow_probability/python/experimental/mcmc/kernel_outputs.py b/tensorflow_probability/python/experimental/mcmc/kernel_outputs.py
index 7d21e9aafc..3723895cdf 100644
--- a/tensorflow_probability/python/experimental/mcmc/kernel_outputs.py
+++ b/tensorflow_probability/python/experimental/mcmc/kernel_outputs.py
@@ -77,7 +77,7 @@ def _process_results(self):
         reducers,
         lambda r, s: r.finalize(s),
         reducers,
-        unnest.get_outermost(self.results, 'streaming_calculations'),
+        unnest.get_outermost(self.results, 'reduction_results'),
         check_types=False)

     # Grab useful reductions.
diff --git a/tensorflow_probability/python/experimental/mcmc/potential_scale_reduction_reducer.py b/tensorflow_probability/python/experimental/mcmc/potential_scale_reduction_reducer.py
index 8a36e7b55a..a6215cb83e 100644
--- a/tensorflow_probability/python/experimental/mcmc/potential_scale_reduction_reducer.py
+++ b/tensorflow_probability/python/experimental/mcmc/potential_scale_reduction_reducer.py
@@ -35,7 +35,7 @@


 PotentialScaleReductionReducerState = collections.namedtuple(
-    'PotentialScaleReductionReducerState', 'init_state, rhat_state')
+    'PotentialScaleReductionReducerState', 'rhat_state')


 class PotentialScaleReductionReducer(reducer_base.Reducer):
@@ -114,18 +114,21 @@ def initialize(self, initial_chain_state, initial_kernel_results=None):
         mcmc_util.make_name(
             self.name, 'potential_scale_reduction_reducer', 'initialize')):
       initial_chain_state = tf.nest.map_structure(
-          tf.convert_to_tensor,
+          tf.convert_to_tensor, initial_chain_state)
+      sample_shape = tf.nest.map_structure(
+          lambda chain_state: tuple(ps.shape(chain_state)),
+          initial_chain_state)
+      chain_ndims = tf.nest.map_structure(
+          lambda chain_state: self.independent_chain_ndims,
+          initial_chain_state)
+      dtype = tf.nest.map_structure(
+          lambda chain_state: chain_state.dtype, initial_chain_state)
-      sample_shape, chain_ndims, dtype = _prepare_args(
-          initial_chain_state, self.independent_chain_ndims
-      )
-      running_rhat = sample_stats.RunningPotentialScaleReduction(
+      rhat = sample_stats.RunningPotentialScaleReduction.from_shape(
           shape=sample_shape,
           independent_chain_ndims=chain_ndims,
-          dtype=dtype
-      )
-      return PotentialScaleReductionReducerState(
-          initial_chain_state, running_rhat.initialize())
+          dtype=dtype)
+      return PotentialScaleReductionReducerState(rhat)

   def one_step(
       self,
       new_chain_state,
       current_reducer_state,
       previous_kernel_results=None,
       axis=None):
       new_chain_state = tf.nest.map_structure(
           tf.convert_to_tensor,
           new_chain_state)
-      sample_shape, chain_ndims, dtype = _prepare_args(
-          new_chain_state, self.independent_chain_ndims
-      )
-      running_rhat = sample_stats.RunningPotentialScaleReduction(
-          shape=sample_shape,
-          independent_chain_ndims=chain_ndims,
-          dtype=dtype
-      )
-      new_rhat_state = running_rhat.update(
-          current_reducer_state.rhat_state,
-          new_chain_state)
-      return PotentialScaleReductionReducerState(
-          current_reducer_state.init_state,
-          new_rhat_state)
+      new_rhat = current_reducer_state.rhat_state.update(new_chain_state)
+      return PotentialScaleReductionReducerState(new_rhat)

   def finalize(self, final_reducer_state):
     """Finalizes R-hat calculation from the `final_reducer_state`.

     Args:
       final_reducer_state: `PotentialScaleReductionReducerState` that
         represents the current state of running statistics.

     Returns:
       rhat: an estimate of the R-hat.
     """
-    sample_shape, chain_ndims, dtype = _prepare_args(
-        final_reducer_state.init_state, self.independent_chain_ndims
-    )
     with tf.name_scope(
         mcmc_util.make_name(
             self.name, 'potential_scale_reduction_reducer', 'finalize')):
-      running_rhat = sample_stats.RunningPotentialScaleReduction(
-          shape=sample_shape,
-          independent_chain_ndims=chain_ndims,
-          dtype=dtype,
-      )
-      return running_rhat.finalize(final_reducer_state.rhat_state)
+      return final_reducer_state.rhat_state.potential_scale_reduction()

   @property
   def parameters(self):
     return self._parameters

   @property
   def independent_chain_ndims(self):
     return self._parameters['independent_chain_ndims']

   @property
   def name(self):
     return self._parameters['name']
-
-
-def _prepare_args(target, chain_ndims):
-  """Infers metadata to instantiate a streaming rhat object from `target`."""
-  sample_shape = tf.nest.map_structure(
-      lambda chain_state: tuple(ps.shape(chain_state)),
-      target
-  )
-  nested_chain_ndims = tf.nest.map_structure(
-      lambda _: chain_ndims,
-      target
-  )
-  dtype = tf.nest.map_structure(
-      lambda chain_state: chain_state.dtype,
-      target
-  )
-  return sample_shape, nested_chain_ndims, dtype
diff --git a/tensorflow_probability/python/experimental/mcmc/potential_scale_reduction_reducer_test.py b/tensorflow_probability/python/experimental/mcmc/potential_scale_reduction_reducer_test.py
index e75ba27910..6bf1080df5 100644
--- a/tensorflow_probability/python/experimental/mcmc/potential_scale_reduction_reducer_test.py
+++ b/tensorflow_probability/python/experimental/mcmc/potential_scale_reduction_reducer_test.py
@@ -142,7 +142,7 @@ def test_in_with_reductions(self):
       chain_state, pkr = reduced_kernel.one_step(
           chain_state, pkr)
     rhat = self.evaluate(
-        rhat_reducer.finalize(pkr.streaming_calculations))
+        rhat_reducer.finalize(pkr.reduction_results))
     self.assertEqual(0.5, rhat)

   def test_iid_normal_passes(self):
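Note on the reducer refactor: `PotentialScaleReductionReducerState` now carries a single `rhat_state` (a `RunningPotentialScaleReduction`) that is updated functionally, instead of re-building a helper object from cached shape/dtype metadata on every call. A sketch of driving the reducer by hand, assuming the experimental API as it appears in this diff:

```python
import numpy as np
import tensorflow_probability as tfp

reducer = tfp.experimental.mcmc.PotentialScaleReductionReducer(
    independent_chain_ndims=1)
chains = np.float32(np.random.randn(100, 4, 3))  # 100 draws, 4 chains, 3 dims

state = reducer.initialize(chains[0])
for draw in chains:
  state = reducer.one_step(draw, state)
rhat = reducer.finalize(state)  # one R-hat estimate per variable, shape (3,)
```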
diff --git a/tensorflow_probability/python/experimental/mcmc/sample_discarding_kernel_test.py b/tensorflow_probability/python/experimental/mcmc/sample_discarding_kernel_test.py
index b71f1277cd..0d07cf3a9e 100644
--- a/tensorflow_probability/python/experimental/mcmc/sample_discarding_kernel_test.py
+++ b/tensorflow_probability/python/experimental/mcmc/sample_discarding_kernel_test.py
@@ -146,7 +146,7 @@ def test_with_composed_kernel(self):
     for _ in range(2):
       current_state, kernel_results = reducer_kernel.one_step(
           current_state, kernel_results)
-    cov = cov_reducer.finalize(kernel_results.streaming_calculations)
+    cov = cov_reducer.finalize(kernel_results.reduction_results)
     self.assertAllEqual(16, current_state)
     self.assertAllEqual(2, kernel_results.inner_results.call_counter)
     self.assertAllEqual(
diff --git a/tensorflow_probability/python/experimental/mcmc/sample_fold.py b/tensorflow_probability/python/experimental/mcmc/sample_fold.py
index 7fe5a472fc..ea82ac8275 100644
--- a/tensorflow_probability/python/experimental/mcmc/sample_fold.py
+++ b/tensorflow_probability/python/experimental/mcmc/sample_fold.py
@@ -157,7 +157,7 @@ def sample_fold(
       reducer,
       lambda r, s: r.finalize(s),
       reducer,
-      final_kernel_results.streaming_calculations,
+      final_kernel_results.reduction_results,
      check_types=False)
   if reducer_was_none:
     reduction_results = None
@@ -169,7 +169,7 @@
     return (reduction_results,
             end_state,
             final_kernel_results.inner_results.inner_results,
-            final_kernel_results.streaming_calculations)
+            final_kernel_results.reduction_results)
   else:
     return (reduction_results,
             end_state,
""" with tf.name_scope( diff --git a/tensorflow_probability/python/experimental/mcmc/with_reductions_test.py b/tensorflow_probability/python/experimental/mcmc/with_reductions_test.py index 6177cd09f2..3a323de55c 100644 --- a/tensorflow_probability/python/experimental/mcmc/with_reductions_test.py +++ b/tensorflow_probability/python/experimental/mcmc/with_reductions_test.py @@ -42,7 +42,7 @@ def test_simple_operation(self): new_sample, kernel_results = reducer_kernel.one_step(0., pkr) new_sample, kernel_results = self.evaluate([ new_sample, kernel_results]) - self.assertEqual(1, kernel_results.streaming_calculations) + self.assertEqual(1, kernel_results.reduction_results) self.assertEqual(1, new_sample) self.assertEqual(1, kernel_results.inner_results.counter_1) self.assertEqual(2, kernel_results.inner_results.counter_2) @@ -55,7 +55,7 @@ def test_boostrap_results(self): reducer=fake_reducer, ) pkr = self.evaluate(reducer_kernel.bootstrap_results(9.)) - self.assertEqual(0, pkr.streaming_calculations, 0) + self.assertEqual(0, pkr.reduction_results, 0) self.assertEqual(0, pkr.inner_results.counter_1, 0) self.assertEqual(0, pkr.inner_results.counter_2, 0) @@ -99,7 +99,7 @@ def _loop_body(i, curr_state, pkr): new_sample, kernel_results = self.evaluate([ new_sample, kernel_results]) - self.assertEqual(6, kernel_results.streaming_calculations) + self.assertEqual(6, kernel_results.reduction_results) self.assertEqual(6, new_sample) self.assertEqual(6, kernel_results.inner_results.counter_1) self.assertEqual(12, kernel_results.inner_results.counter_2) @@ -120,16 +120,16 @@ def test_nested_reducers(self): new_sample, kernel_results]) self.assertEqual( - 2, len(kernel_results.streaming_calculations[0])) + 2, len(kernel_results.reduction_results[0])) self.assertEqual( - 1, len(kernel_results.streaming_calculations[1])) + 1, len(kernel_results.reduction_results[1])) self.assertEqual( (2,), - np.array(kernel_results.streaming_calculations).shape) + np.array(kernel_results.reduction_results).shape) self.assertAllEqual( [[1, 1], [1]], - kernel_results.streaming_calculations) + kernel_results.reduction_results) self.assertEqual(1, new_sample) self.assertEqual(1, kernel_results.inner_results.counter_1) self.assertEqual(2, kernel_results.inner_results.counter_2) @@ -151,16 +151,16 @@ def test_nested_state_dependent_reducers(self): self.assertEqual( 2, - len(kernel_results.streaming_calculations[0])) + len(kernel_results.reduction_results[0])) self.assertEqual( 1, - len(kernel_results.streaming_calculations[1])) + len(kernel_results.reduction_results[1])) self.assertEqual( (2,), - np.array(kernel_results.streaming_calculations).shape) + np.array(kernel_results.reduction_results).shape) self.assertAllEqualNested( - kernel_results.streaming_calculations, + kernel_results.reduction_results, [[[1, 1], [1, 1]], [[1, 1]]], ) self.assertEqual(1, new_sample) @@ -186,9 +186,9 @@ def test_covariance_reducer(self, ddof): chain_state, kernel_results) final_cov = self.evaluate( - cov_reducer.finalize(kernel_results.streaming_calculations)) + cov_reducer.finalize(kernel_results.reduction_results)) self.assertAllEqual( - 3.5, kernel_results.streaming_calculations.cov_state.mean) + 3.5, kernel_results.reduction_results.cov_state.mean) self.assertNear( np.cov(np.arange(1, 7), ddof=ddof).tolist(), final_cov, @@ -208,9 +208,9 @@ def test_covariance_with_batching(self): for _ in range(6): state, kernel_results = reducer_kernel.one_step( state, kernel_results) - final_cov = 
diff --git a/tensorflow_probability/python/experimental/mcmc/with_reductions_test.py b/tensorflow_probability/python/experimental/mcmc/with_reductions_test.py
index 6177cd09f2..3a323de55c 100644
--- a/tensorflow_probability/python/experimental/mcmc/with_reductions_test.py
+++ b/tensorflow_probability/python/experimental/mcmc/with_reductions_test.py
@@ -42,7 +42,7 @@ def test_simple_operation(self):
     new_sample, kernel_results = reducer_kernel.one_step(0., pkr)
     new_sample, kernel_results = self.evaluate([
         new_sample, kernel_results])
-    self.assertEqual(1, kernel_results.streaming_calculations)
+    self.assertEqual(1, kernel_results.reduction_results)
     self.assertEqual(1, new_sample)
     self.assertEqual(1, kernel_results.inner_results.counter_1)
     self.assertEqual(2, kernel_results.inner_results.counter_2)
@@ -55,7 +55,7 @@ def test_boostrap_results(self):
         reducer=fake_reducer,
     )
     pkr = self.evaluate(reducer_kernel.bootstrap_results(9.))
-    self.assertEqual(0, pkr.streaming_calculations, 0)
+    self.assertEqual(0, pkr.reduction_results, 0)
     self.assertEqual(0, pkr.inner_results.counter_1, 0)
     self.assertEqual(0, pkr.inner_results.counter_2, 0)

@@ -99,7 +99,7 @@ def _loop_body(i, curr_state, pkr):
     new_sample, kernel_results = self.evaluate([
         new_sample, kernel_results])

-    self.assertEqual(6, kernel_results.streaming_calculations)
+    self.assertEqual(6, kernel_results.reduction_results)
     self.assertEqual(6, new_sample)
     self.assertEqual(6, kernel_results.inner_results.counter_1)
     self.assertEqual(12, kernel_results.inner_results.counter_2)
@@ -120,16 +120,16 @@ def test_nested_reducers(self):
         new_sample, kernel_results])

     self.assertEqual(
-        2, len(kernel_results.streaming_calculations[0]))
+        2, len(kernel_results.reduction_results[0]))
     self.assertEqual(
-        1, len(kernel_results.streaming_calculations[1]))
+        1, len(kernel_results.reduction_results[1]))
     self.assertEqual(
         (2,),
-        np.array(kernel_results.streaming_calculations).shape)
+        np.array(kernel_results.reduction_results).shape)
     self.assertAllEqual(
         [[1, 1], [1]],
-        kernel_results.streaming_calculations)
+        kernel_results.reduction_results)
     self.assertEqual(1, new_sample)
     self.assertEqual(1, kernel_results.inner_results.counter_1)
     self.assertEqual(2, kernel_results.inner_results.counter_2)
@@ -151,16 +151,16 @@ def test_nested_state_dependent_reducers(self):
     self.assertEqual(
         2,
-        len(kernel_results.streaming_calculations[0]))
+        len(kernel_results.reduction_results[0]))
     self.assertEqual(
         1,
-        len(kernel_results.streaming_calculations[1]))
+        len(kernel_results.reduction_results[1]))
     self.assertEqual(
         (2,),
-        np.array(kernel_results.streaming_calculations).shape)
+        np.array(kernel_results.reduction_results).shape)
     self.assertAllEqualNested(
-        kernel_results.streaming_calculations,
+        kernel_results.reduction_results,
         [[[1, 1], [1, 1]], [[1, 1]]],
     )
     self.assertEqual(1, new_sample)
@@ -186,9 +186,9 @@ def test_covariance_reducer(self, ddof):
         chain_state, kernel_results)

     final_cov = self.evaluate(
-        cov_reducer.finalize(kernel_results.streaming_calculations))
+        cov_reducer.finalize(kernel_results.reduction_results))
     self.assertAllEqual(
-        3.5, kernel_results.streaming_calculations.cov_state.mean)
+        3.5, kernel_results.reduction_results.cov_state.mean)
     self.assertNear(
         np.cov(np.arange(1, 7), ddof=ddof).tolist(),
         final_cov,
@@ -208,9 +208,9 @@ def test_covariance_with_batching(self):
     for _ in range(6):
       state, kernel_results = reducer_kernel.one_step(
           state, kernel_results)
-    final_cov = cov_reducer.finalize(kernel_results.streaming_calculations)
+    final_cov = cov_reducer.finalize(kernel_results.reduction_results)
     self.assertEqual(
-        (9, 3), kernel_results.streaming_calculations.cov_state.mean.shape)
+        (9, 3), kernel_results.reduction_results.cov_state.mean.shape)
     self.assertEqual((9, 3, 3), final_cov.shape)

@@ -228,9 +228,9 @@ def test_variance_reducer(self, ddof):
         chain_state, kernel_results)

     final_var = self.evaluate(
-        reducer.finalize(kernel_results.streaming_calculations))
+        reducer.finalize(kernel_results.reduction_results))
     self.assertAllEqual(
-        3.5, kernel_results.streaming_calculations.cov_state.mean)
+        3.5, kernel_results.reduction_results.cov_state.mean)
     self.assertNear(
         np.var(np.arange(1, 7), ddof=ddof).tolist(),
         final_var,
@@ -266,8 +266,8 @@ def test_multivariate_normal_covariance_with_sample_chain(self):
     )
     samples, mean, final_cov = self.evaluate([
         samples,
-        kernel_results.streaming_calculations.cov_state.mean,
-        cov_reducer.finalize(kernel_results.streaming_calculations)
+        kernel_results.reduction_results.cov_state.mean,
+        cov_reducer.finalize(kernel_results.reduction_results)
     ])
     self.assertAllClose(np.mean(samples, axis=0), mean, rtol=1e-6)
     self.assertAllClose(np.cov(samples.T, ddof=0), final_cov, rtol=1e-6)
@@ -286,10 +286,10 @@ def test_covariance_with_step_kernel(self):
         return_final_kernel_results=True,
     )
     final_cov = self.evaluate(
-        cov_reducer.finalize(kernel_results.streaming_calculations))
+        cov_reducer.finalize(kernel_results.reduction_results))
     self.assertAllEqual(6, chain_state)
     self.assertAllEqual(
-        3.5, kernel_results.streaming_calculations.cov_state.mean)
+        3.5, kernel_results.reduction_results.cov_state.mean)
     self.assertNear(
         np.cov(np.arange(1, 7), ddof=0).tolist(),
         final_cov,
@@ -319,11 +319,11 @@ def test_covariance_before_transformation(self):
     samples, final_cov = self.evaluate([
         samples,
         cov_reducer.finalize(
-            kernel_results.inner_results.streaming_calculations)
+            kernel_results.inner_results.reduction_results)
     ])
     self.assertAllClose(
         np.mean(np.log(samples), axis=0),
-        kernel_results.inner_results.streaming_calculations.cov_state.mean,
+        kernel_results.inner_results.reduction_results.cov_state.mean,
         rtol=1e-6)
     self.assertAllClose(
         np.cov(np.log(samples).T, ddof=0), final_cov, rtol=1e-6)
@@ -350,11 +350,11 @@ def test_covariance_after_transformation(self):
     samples, final_cov = self.evaluate([
         samples,
         cov_reducer.finalize(
-            kernel_results.streaming_calculations)
+            kernel_results.reduction_results)
     ])
     self.assertAllClose(
         np.mean(samples, axis=0),
-        kernel_results.streaming_calculations.cov_state.mean,
+        kernel_results.reduction_results.cov_state.mean,
         rtol=1e-6)
     self.assertAllClose(
         np.cov(samples.T, ddof=0), final_cov, rtol=1e-6)
@@ -382,12 +382,12 @@ def test_nested_in_step_size_adaptation(self):
         trace_fn=None,
         return_final_kernel_results=True,
         seed=test_util.test_seed())
-    mean = kernel_results.inner_results.streaming_calculations.cov_state.mean
+    mean = kernel_results.inner_results.reduction_results.cov_state.mean
     samples, mean, final_cov = self.evaluate([
         samples,
         mean,
         cov_reducer.finalize(
-            kernel_results.inner_results.streaming_calculations)
+            kernel_results.inner_results.reduction_results)
     ])
     self.assertEqual((2,), mean.shape)
@@ -413,18 +413,18 @@ def test_nested_reducers(self):
     final_cov, final_mean = self.evaluate([
         cov_reducer.finalize(
-            kernel_results.streaming_calculations[0][1]),
+            kernel_results.reduction_results[0][1]),
         mean_reducer.finalize(
-            kernel_results.streaming_calculations[0][0])
+            kernel_results.reduction_results[0][0])
     ])
-    self.assertEqual(2, len(kernel_results.streaming_calculations))
-    self.assertEqual(2, len(kernel_results.streaming_calculations[0]))
-    self.assertEqual(1, len(kernel_results.streaming_calculations[1]))
+    self.assertEqual(2, len(kernel_results.reduction_results))
+    self.assertEqual(2, len(kernel_results.reduction_results[0]))
+    self.assertEqual(1, len(kernel_results.reduction_results[1]))
     self.assertEqual(3.5, final_mean)
     self.assertAllEqual(
-        3.5, kernel_results.streaming_calculations[0][1].cov_state.mean)
-    self.assertAllEqual(6, kernel_results.streaming_calculations[1][0])
+        3.5, kernel_results.reduction_results[0][1].cov_state.mean)
+    self.assertAllEqual(6, kernel_results.reduction_results[1][0])
     self.assertNear(
         np.cov(np.arange(1, 7), ddof=0).tolist(),
         final_cov,
diff --git a/tensorflow_probability/python/experimental/stats/__init__.py b/tensorflow_probability/python/experimental/stats/__init__.py
index bd11ee02e7..b2ea559b69 100644
--- a/tensorflow_probability/python/experimental/stats/__init__.py
+++ b/tensorflow_probability/python/experimental/stats/__init__.py
@@ -24,7 +24,6 @@
 from tensorflow_probability.python.experimental.stats.sample_stats import RunningMean
 from tensorflow_probability.python.experimental.stats.sample_stats import RunningMeanState
 from tensorflow_probability.python.experimental.stats.sample_stats import RunningPotentialScaleReduction
-from tensorflow_probability.python.experimental.stats.sample_stats import RunningPotentialScaleReductionState
 from tensorflow_probability.python.experimental.stats.sample_stats import RunningVariance
@@ -35,6 +34,5 @@
     'RunningMean',
     'RunningMeanState',
     'RunningPotentialScaleReduction',
-    'RunningPotentialScaleReductionState',
     'RunningVariance',
 ]
diff --git a/tensorflow_probability/python/experimental/stats/sample_stats.py b/tensorflow_probability/python/experimental/stats/sample_stats.py
index 708f99f3b7..7ab6ea77c8 100644
--- a/tensorflow_probability/python/experimental/stats/sample_stats.py
+++ b/tensorflow_probability/python/experimental/stats/sample_stats.py
@@ -41,7 +41,6 @@
     'RunningMean',
     'RunningMeanState',
     'RunningPotentialScaleReduction',
-    'RunningPotentialScaleReductionState',
     'RunningVariance',
 ]
@@ -593,12 +592,9 @@ def _n_choose_k(self, n, k):
     return math.factorial(n) // math.factorial(k) // math.factorial(n - k)


-RunningPotentialScaleReductionState = collections.namedtuple(
-    'RunningPotentialScaleReductionState', 'chain_var')
-
-
+@auto_composite_tensor.auto_composite_tensor(omit_kwargs='name')
 class RunningPotentialScaleReduction(object):
-  """Holds metadata for and computes a running R-hat diagnostic statistic.
+  """A running R-hat diagnostic.

   `RunningPotentialScaleReduction` uses Gelman and Rubin (1992)'s potential
   scale reduction (also known as R-hat) for chain convergence [1].

   independent chain dimensions is defined by the `independent_chain_ndims`
   parameter at initialization.

-  `RunningPotentialScaleReduction` objects do not hold state information. That
-  information, which includes intermediate calculations, are held in a
-  `RunningPotentialScaleReductionState` as returned via `initialize` and
-  `update` method calls.
-
   `RunningPotentialScaleReduction` is meant to serve general streaming R-hat.
   For a specialized version that fits streaming over MCMC samples, see
-  `RhatReducer` in `tfp.experimental.mcmc`.
+  `PotentialScaleReductionReducer` in `tfp.experimental.mcmc`.

   #### References

   [1]: Andrew Gelman and Donald B. Rubin. Inference from Iterative Simulation
        Using Multiple Sequences. _Statistical Science_, 7(4):457-472, 1992.
   """

-  def __init__(self, shape, independent_chain_ndims, dtype=tf.float32):
-    """Instantiates this object.
+  def __init__(self, chain_variances, independent_chain_ndims):
+    """Construct a `RunningPotentialScaleReduction`.

     Args:
-      shape: Python `Tuple` or `TensorShape` representing the shape of
-          incoming samples. Using a collection implies that future samples
-          will mimic that exact structure.
+      chain_variances: A `RunningVariance` or nested structure of
+        `RunningVariance`s, giving the variance estimates for the variables of
+        interest.
+      independent_chain_ndims: A Python `int` or structure of Python `ints`
+        parallel to `chain_variances` giving the number of leading dimensions
+        in `chain_variances` that index the independent chains over which the
+        potential scale reduction factor should be computed. Must be at least
+        1.
+    """
+    self.chain_variances = chain_variances
+    self.independent_chain_ndims = independent_chain_ndims
+
+  @classmethod
+  def from_shape(cls, shape=(), independent_chain_ndims=1, dtype=tf.float32):
+    """Starts an empty `RunningPotentialScaleReduction` from metadata.
+
+    Args:
+      shape: Python `Tuple` or `TensorShape` representing the shape of incoming
+        samples. Using a collection implies that future samples will mimic that
+        exact structure. This is useful to supply if the
+        `RunningPotentialScaleReduction` will be carried by a `tf.while_loop`,
+        so that broadcasting does not change the shape across loop iterations.
       independent_chain_ndims: Integer or Integer type `Tensor` with value
         `>= 1` giving the number of leading dimensions holding independent
         chain results to be tested for convergence. Using a collection
         implies that future samples will mimic that exact structure.
       dtype: Dtype of incoming samples and the resulting statistics.
         By default, the dtype is `tf.float32`. Any integer dtypes will be
         cast to corresponding floats (i.e. `tf.int32` will be cast to
         `tf.float32`), as intermediate calculations should be performing
         floating-point division.
-    """
-    self.shape = shape
-    self.independent_chain_ndims = independent_chain_ndims
-
-    def _cast_dtype(dtype):
-      if dtype_util.as_numpy_dtype(dtype) is np.int64:
-        return tf.float64
-      elif dtype_util.is_integer(dtype):
-        return tf.float32
-      return dtype
-    self.dtype = tf.nest.map_structure(_cast_dtype, dtype)
-
-  def initialize(self):
-    """Initializes an empty `RunningPotentialScaleReductionState`.

     Returns:
       state: `RunningPotentialScaleReduction` representing a stream
         of no inputs.
     """
-    broadcasted_dtype = nest_util.broadcast_structure(
-        self.independent_chain_ndims, self.dtype)
-    chain_var = nest.map_structure_up_to(
-        self.independent_chain_ndims,
+    dtype = tf.nest.map_structure(_float_dtype_like, dtype)
+
+    dtype = nest_util.broadcast_structure(independent_chain_ndims, dtype)
+    chain_variances = nest.map_structure_up_to(
+        independent_chain_ndims,
         RunningVariance.from_shape,
-        self.shape,
-        broadcasted_dtype,
-        check_types=False
-    )
-    return RunningPotentialScaleReductionState(chain_var)
+        shape,
+        dtype,
+        check_types=False)
+    return cls(chain_variances, independent_chain_ndims)
-  def update(self, state, new_sample):
-    """Update the `RunningPotentialScaleReductionState` with a new sample.
+  def update(self, new_sample):
+    """Update the `RunningPotentialScaleReduction` with a new sample.

     Args:
-      state: `RunningPotentialScaleReductionState` that represents the
-          current state of running statistics.
       new_sample: Incoming `Tensor` sample or (possibly nested) collection of
         `Tensor`s with shape and dtype compatible with those used to form the
-        `RunningPotentialScaleReductionState`.
+        `RunningPotentialScaleReduction`.

     Returns:
-      state: `RunningPotentialScaleReductionState` with updated calculations.
+      state: `RunningPotentialScaleReduction` updated to include the new
+        sample.
     """
-    def _update_for_one_state(chain_var, new_sample):
+    def _update_for_one_state(chain_variances, new_sample):
       """Updates the running variance for one group of Markov chains."""
       # TODO(axch): chunking could be reasonably added here by accepting and
       # including the chunked axis to the running variance object
-      return chain_var.update(new_sample)
-    updated_chain_vars = nest.map_structure_up_to(
-        self.independent_chain_ndims,
+      return chain_variances.update(new_sample)
+    updated_chain_variances = tf.nest.map_structure(
         _update_for_one_state,
-        state.chain_var,
+        self.chain_variances,
         new_sample,
         check_types=False
     )
-    return RunningPotentialScaleReductionState(updated_chain_vars)
+    return type(self)(updated_chain_variances, self.independent_chain_ndims)

-  def finalize(self, state):
-    """Finalizes potential scale reduction computation for the `state`.
-
-    Args:
-      state: `RunningPotentialScaleReductionState` that represents
-          the current state of running statistics.
+  def potential_scale_reduction(self):
+    """Computes the potential scale reduction for samples accumulated so far.

     Returns:
       rhat: An estimate of the R-hat.
     """
-    def _finalize_for_one_state(shape, chain_ndims, chain_var):
+    def _finalize_for_one_state(chain_ndims, chain_variances):
       """Calculates R-hat for one group of Markov chains."""
       # using notation from Brooks and Gelman (1998),
       # n := num samples / chain; m := number of chains
-      n = chain_var.num_samples
+      n = chain_variances.num_samples
+      shape = chain_variances.mean.shape
       m = tf.cast(
           functools.reduce((lambda x, y: x * y), (shape[:chain_ndims])),
           n.dtype)

       # b/n is the between-chain variance (the variance of the chain means)
       b_div_n = diagnostic._reduce_variance(  # pylint:disable=protected-access
-          tf.convert_to_tensor(chain_var.mean),
+          tf.convert_to_tensor(chain_variances.mean),
           axis=tf.range(chain_ndims),
           biased=False)

       # W is the within sequence variance (the mean of the chain variances)
       sum_of_chain_squared_residuals = tf.reduce_sum(
-          chain_var.sum_squared_residuals, axis=tf.range(chain_ndims))
+          chain_variances.sum_squared_residuals, axis=tf.range(chain_ndims))
       w = sum_of_chain_squared_residuals / (m * (n - 1))

       # the `true_variance_estimate` is denoted as sigma^2_+ in the 1998 paper
       true_variance_estimate = ((n - 1) / n) * w + b_div_n
       return ((m + 1.) / m) * true_variance_estimate / w - (n - 1.) / (m * n)
-    return nest.map_structure_up_to(
-        self.independent_chain_ndims,
+    return tf.nest.map_structure(
         _finalize_for_one_state,
-        self.shape,
         self.independent_chain_ndims,
-        state.chain_var,
+        self.chain_variances,
         check_types=False
     )
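Note on the new surface: the class is now an immutable accumulator. `from_shape` starts an empty stream, `update` returns a new accumulator, and `potential_scale_reduction` (formerly `finalize`) reads off the estimate. Usage, condensed from the updated tests below:

```python
import numpy as np
import tensorflow_probability as tfp

# 5 draws from each of 3 independent Markov chains.
x = np.arange(15, dtype=np.float32).reshape((5, 3))

running_rhat = (
    tfp.experimental.stats.RunningPotentialScaleReduction.from_shape(
        shape=(3,), independent_chain_ndims=1))
for sample in x:
  running_rhat = running_rhat.update(sample)  # returns a new accumulator
rhat = running_rhat.potential_scale_reduction()
```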
diff --git a/tensorflow_probability/python/experimental/stats/sample_stats_test.py b/tensorflow_probability/python/experimental/stats/sample_stats_test.py
index 6e0f4882e1..2d67aa8150 100644
--- a/tensorflow_probability/python/experimental/stats/sample_stats_test.py
+++ b/tensorflow_probability/python/experimental/stats/sample_stats_test.py
@@ -345,16 +345,14 @@ def test_tf_while(self):
 class RunningPotentialScaleReductionTest(test_util.TestCase):

   def test_simple_operation(self):
-    running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction(
+    running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction.from_shape(
         shape=(3,),
-        independent_chain_ndims=1
     )
-    state = running_rhat.initialize()
     # 5 samples from 3 independent Markov chains
     x = np.arange(15, dtype=np.float32).reshape((5, 3))
     for sample in x:
-      state = running_rhat.update(state, sample)
-    rhat = running_rhat.finalize(state)
+      running_rhat = running_rhat.update(sample)
+    rhat = running_rhat.potential_scale_reduction()
     true_rhat = tfp.mcmc.potential_scale_reduction(
         chains_states=x,
         independent_chain_ndims=1,

   def test_random_scalar_computation(self):
     rng = test_util.test_np_rng()
     x = rng.rand(100, 10) * 100
-    running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction(
+    running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction.from_shape(
         shape=(10,),
-        independent_chain_ndims=1
     )
-    state = running_rhat.initialize()
     for sample in x:
-      state = running_rhat.update(state, sample)
-    rhat = running_rhat.finalize(state)
+      running_rhat = running_rhat.update(sample)
+    rhat = running_rhat.potential_scale_reduction()
     true_rhat = tfp.mcmc.potential_scale_reduction(
         chains_states=x,
         independent_chain_ndims=1,

   def test_non_scalar_samples(self):
     rng = test_util.test_np_rng()
     x = rng.rand(100, 2, 2, 3, 5) * 100
-    running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction(
+    running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction.from_shape(
         shape=(2, 2, 3, 5),
-        independent_chain_ndims=1
     )
-    state = running_rhat.initialize()
     for sample in x:
-      state = running_rhat.update(state, sample)
-    rhat = running_rhat.finalize(state)
+      running_rhat = running_rhat.update(sample)
+    rhat = running_rhat.potential_scale_reduction()
     true_rhat = tfp.mcmc.potential_scale_reduction(
         chains_states=x,
         independent_chain_ndims=1,
@@ -407,28 +401,26 @@ def test_batching(self):
     # shifted.
     offset = np.array([1., -1., 2.]).reshape(3, 1)
     state_1 = np.random.randn(n_samples, 3, 4) + offset
-    running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction(
+    running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction.from_shape(
         shape=[(2,), (3, 4)],
         independent_chain_ndims=[1, 1]
     )
-    state = running_rhat.initialize()
     for sample in zip(state_0, state_1):
-      state = running_rhat.update(state, sample)
-    rhat = self.evaluate(running_rhat.finalize(state))
+      running_rhat = running_rhat.update(sample)
+    rhat = self.evaluate(running_rhat.potential_scale_reduction())
     true_rhat = tfp.mcmc.potential_scale_reduction(
         chains_states=[state_0, state_1],
         independent_chain_ndims=1)
     self.assertAllClose(true_rhat, rhat, rtol=1e-6)

   def test_independent_chain_ndims(self):
-    running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction(
+    running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction.from_shape(
         shape=(5, 3),
         independent_chain_ndims=2,
     )
-    state = running_rhat.initialize()
     x = np.arange(30, dtype=np.float32).reshape((2, 5, 3))
     for sample in x:
-      state = running_rhat.update(state, sample)
-    rhat = running_rhat.finalize(state)
+      running_rhat = running_rhat.update(sample)
+    rhat = running_rhat.potential_scale_reduction()
     true_rhat = tfp.mcmc.potential_scale_reduction(
         chains_states=x,
         independent_chain_ndims=2,
@@ -440,20 +432,19 @@ def test_tf_while(self):
     rng = test_util.test_np_rng()
     x = rng.rand(100, 10) * 100
     tensor_x = tf.convert_to_tensor(x)
-    running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction(
+    running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction.from_shape(
         shape=(10,),
         independent_chain_ndims=1
     )
-    state = running_rhat.initialize()
-    def _loop_body(i, state):
-      new_state = running_rhat.update(state, tensor_x[i])
-      return i + 1, new_state
-    _, state = tf.while_loop(
+    def _loop_body(i, running_rhat):
+      running_rhat = running_rhat.update(tensor_x[i])
+      return i + 1, running_rhat
+    _, running_rhat = tf.while_loop(
         lambda i, _: i < 100,
         _loop_body,
-        (0, state)
+        (0, running_rhat)
     )
-    rhat = running_rhat.finalize(state)
+    rhat = running_rhat.potential_scale_reduction()
     true_rhat = tfp.mcmc.potential_scale_reduction(
         chains_states=x,
         independent_chain_ndims=1,
diff --git a/tensorflow_probability/python/math/psd_kernels/kumaraswamy_transformed.py b/tensorflow_probability/python/math/psd_kernels/kumaraswamy_transformed.py
index d71c5e55fe..0a23ef213f 100644
--- a/tensorflow_probability/python/math/psd_kernels/kumaraswamy_transformed.py
+++ b/tensorflow_probability/python/math/psd_kernels/kumaraswamy_transformed.py
@@ -22,7 +22,6 @@

 import tensorflow.compat.v2 as tf

-from tensorflow_probability.python.bijectors.kumaraswamy_cdf import KumaraswamyCDF
 from tensorflow_probability.python.internal import assert_util
 from tensorflow_probability.python.internal import tensor_util
 from tensorflow_probability.python.math.psd_kernels import feature_transformed
@@ -62,6 +61,11 @@ def __init__(self,
       name: Python `str` name given to ops managed by this object.
     """
     parameters = dict(locals())
+
+    # Delayed import to avoid circular dependency between `tfp.bijectors` and
+    # `tfp.math`
+    from tensorflow_probability.python.bijectors import kumaraswamy_cdf  # pylint: disable=g-import-not-at-top
+
     with tf.name_scope(name):
       self._concentration1 = tensor_util.convert_nonref_to_tensor(
           concentration1, name='concentration1')
@@ -78,9 +82,9 @@ def transform_by_kumaraswamy(x, feature_ndims, example_ndims):
           self.concentration0, example_ndims, start=-(feature_ndims + 1))
-      bij = KumaraswamyCDF(concentration1,
-                           concentration0,
-                           validate_args=validate_args)
+      bij = kumaraswamy_cdf.KumaraswamyCDF(concentration1,
+                                           concentration0,
+                                           validate_args=validate_args)
       return bij.forward(x)

     super(KumaraswamyTransformed, self).__init__(
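Note on the delayed imports: the kumaraswamy change above and the schur_complement change below both break the `tfp.bijectors` / `tfp.math` import cycle the same way, by importing at call time rather than module-import time, so neither package needs the other to finish loading first. The shape of the pattern (module names here are hypothetical):

```python
# a.py -- imports nothing from b at module scope.

def make_widget():
  # Deferred import: b (which itself imports a) is only loaded when
  # make_widget() is first called, after both modules fully exist.
  from b import Widget  # pylint: disable=g-import-not-at-top
  return Widget()
```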
""" parameters = dict(locals()) + + # Delayed import to avoid circular dependency between `tfp.bijectors` and + # `tfp.math` + from tensorflow_probability.python.bijectors import kumaraswamy_cdf # pylint: disable=g-import-not-at-top + with tf.name_scope(name): self._concentration1 = tensor_util.convert_nonref_to_tensor( concentration1, name='concentration1') @@ -78,9 +82,9 @@ def transform_by_kumaraswamy(x, feature_ndims, example_ndims): self.concentration0, example_ndims, start=-(feature_ndims + 1)) - bij = KumaraswamyCDF(concentration1, - concentration0, - validate_args=validate_args) + bij = kumaraswamy_cdf.KumaraswamyCDF(concentration1, + concentration0, + validate_args=validate_args) return bij.forward(x) super(KumaraswamyTransformed, self).__init__( diff --git a/tensorflow_probability/python/math/psd_kernels/schur_complement.py b/tensorflow_probability/python/math/psd_kernels/schur_complement.py index f4a64a8ae9..ab455158c6 100644 --- a/tensorflow_probability/python/math/psd_kernels/schur_complement.py +++ b/tensorflow_probability/python/math/psd_kernels/schur_complement.py @@ -21,8 +21,6 @@ import functools import tensorflow.compat.v2 as tf -from tensorflow_probability.python.bijectors import cholesky_outer_product -from tensorflow_probability.python.bijectors import invert from tensorflow_probability.python.internal import distribution_util from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import tensor_util @@ -203,6 +201,13 @@ def __init__(self, Default value: `"SchurComplement"` """ parameters = dict(locals()) + + # Delayed import to avoid circular dependency between `tfp.bijectors` and + # `tfp.math` + # pylint: disable=g-import-not-at-top + from tensorflow_probability.python.bijectors import cholesky_outer_product + from tensorflow_probability.python.bijectors import invert + # pylint: enable=g-import-not-at-top with tf.name_scope(name) as name: dtype = dtype_util.common_dtype( [base_kernel, fixed_inputs, diag_shift], tf.float32) diff --git a/tensorflow_probability/python/version.py b/tensorflow_probability/python/version.py index a76151793c..2d6a3e8646 100644 --- a/tensorflow_probability/python/version.py +++ b/tensorflow_probability/python/version.py @@ -24,7 +24,7 @@ # stable release (indicated by `_VERSION_SUFFIX = ''`). Outside the context of a # release branch, the current version is by default assumed to be a # 'development' version, labeled 'dev'. -_VERSION_SUFFIX = 'rc0' +_VERSION_SUFFIX = 'rc1' # Example, '0.4.0-dev' __version__ = '.'.join([