diff --git a/tensorflow_probability/python/distributions/joint_distribution_coroutine.py b/tensorflow_probability/python/distributions/joint_distribution_coroutine.py index cef887b10c..d5b442fd2c 100644 --- a/tensorflow_probability/python/distributions/joint_distribution_coroutine.py +++ b/tensorflow_probability/python/distributions/joint_distribution_coroutine.py @@ -219,8 +219,10 @@ def __init__( """ parameters = dict(locals()) with tf.name_scope(name or 'JointDistributionCoroutine') as name: - self._sample_dtype = sample_dtype self._model_coroutine = model + # Hint `no_dependency` to tell tf.Module not to screw up the sample dtype + # with extraneous wrapping (list => ListWrapper, etc.). + self._sample_dtype = self._no_dependency(sample_dtype) self._single_sample_distributions = {} super(JointDistributionCoroutine, self).__init__( dtype=sample_dtype, diff --git a/tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py b/tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py index 4aca2df9d8..5079ae8ebc 100644 --- a/tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py +++ b/tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py @@ -624,6 +624,7 @@ def noncentered_horseshoe_prior(num_features): tfd.Sample(tfd.Normal(0., 1.), num_features)) yield tfd.Independent(tfd.Deterministic(weights_noncentered * scale), reinterpreted_batch_ndims=1) + # Currently sample_dtype is only used for `tf.nest.pack_structure_as`. In # the future we may use it for error checking and/or casting. sample_dtype = collections.namedtuple('Model', [ @@ -645,6 +646,18 @@ def noncentered_horseshoe_prior(num_features): self.assertEqual([3, 4], joint.log_prob( joint.sample([3, 4], seed=test_util.test_seed())).shape) + # Check that a list dtype doesn't get corrupted by `tf.Module` wrapping. + sample_dtype = [None, None, None, None] + joint = tfd.JointDistributionCoroutine( + lambda: noncentered_horseshoe_prior(4), + sample_dtype=sample_dtype, + validate_args=True) + ds, xs = joint.sample_distributions([2, 3], seed=test_util.test_seed()) + self.assertEqual(type(sample_dtype), type(xs)) + self.assertEqual(type(sample_dtype), type(ds)) + tf.nest.assert_same_structure(sample_dtype, ds) + tf.nest.assert_same_structure(sample_dtype, xs) + def test_repr_with_custom_sample_dtype(self): def model(): s = yield tfd.JointDistributionCoroutine.Root(