From 11031d54e8002e02f098cbf52adfb2fef33dbf2a Mon Sep 17 00:00:00 2001 From: siege Date: Tue, 12 Mar 2024 12:21:50 -0700 Subject: [PATCH] Fix the example for tfd.Autoregressive to actually be autoregressive. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is done by zeroing out the diagonal. The autoregressive property (with a left-to-right order) dictates that when sampling element i, we must look only at elements 0 ≤ k < i, which implies that the diagonal must be zero. Cf. the 'exclusive' masking for the initial layers in tfb.AutoregressiveNetwork. Also, add a comment about the superfluous iteration when sampling. PiperOrigin-RevId: 615131666 --- .../python/distributions/BUILD | 1 - .../python/distributions/autoregressive.py | 8 +++++-- .../distributions/autoregressive_test.py | 24 +++++++++---------- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/tensorflow_probability/python/distributions/BUILD b/tensorflow_probability/python/distributions/BUILD index 26f80786dd..413f625142 100644 --- a/tensorflow_probability/python/distributions/BUILD +++ b/tensorflow_probability/python/distributions/BUILD @@ -2568,7 +2568,6 @@ multi_substrate_py_test( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/bijectors:masked_autoregressive", - "//tensorflow_probability/python/bijectors:scale_matvec_tril", "//tensorflow_probability/python/internal:reparameterization", "//tensorflow_probability/python/internal:test_util", "//tensorflow_probability/python/math:linalg", diff --git a/tensorflow_probability/python/distributions/autoregressive.py b/tensorflow_probability/python/distributions/autoregressive.py index 7188589353..11afbdc2d9 100644 --- a/tensorflow_probability/python/distributions/autoregressive.py +++ b/tensorflow_probability/python/distributions/autoregressive.py @@ -86,9 +86,10 @@ class Autoregressive(distribution.Distribution): def _normal_fn(event_size): n = event_size * (event_size + 1) // 2 p = 
tf.Variable(tfd.Normal(loc=0., scale=1.).sample(n)) - affine = tfb.FillScaleTriL(tfp.math.fill_triangular(0.25 * p)) + ar_matrix = tf.linalg.set_diag(tfp.math.fill_triangular(0.25 * p), + tf.zeros(event_size)) def _fn(samples): - scale = tf.exp(affine(samples)) + scale = tf.exp(tf.linalg.matvec(ar_matrix, samples)) return tfd.Independent( tfd.Normal(loc=0., scale=scale, validate_args=True), reinterpreted_batch_ndims=1) @@ -291,6 +292,9 @@ def _sample_n(self, n, seed=None): seed = stateful_seed if is_stateful_sampler else stateless_seed + # This runs for 1 more step than strictly necessary because there is no + # guarantee that the samples produced by the sample(n) above is the same as + # batched sample() below. if num_steps_static is not None: for _ in range(num_steps_static): # pylint: disable=not-callable diff --git a/tensorflow_probability/python/distributions/autoregressive_test.py b/tensorflow_probability/python/distributions/autoregressive_test.py index 891b9f6a13..7db54fcc91 100644 --- a/tensorflow_probability/python/distributions/autoregressive_test.py +++ b/tensorflow_probability/python/distributions/autoregressive_test.py @@ -21,7 +21,6 @@ import numpy as np import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import masked_autoregressive -from tensorflow_probability.python.bijectors import scale_matvec_tril from tensorflow_probability.python.distributions import autoregressive from tensorflow_probability.python.distributions import categorical from tensorflow_probability.python.distributions import distribution @@ -43,14 +42,17 @@ def setUp(self): super(AutoregressiveTest, self).setUp() self._rng = np.random.RandomState(42) - def _random_scale_tril(self, event_size): + def _random_ar_matrix(self, event_size): n = np.int32(event_size * (event_size + 1) // 2) p = 2. * self._rng.random_sample(n).astype(np.float32) - 1. - return linalg.fill_triangular(0.25 * p) + # Zero-out the diagonal to ensure auto-regressive property. 
+ return tf.linalg.set_diag( + linalg.fill_triangular(0.25 * p), tf.zeros(event_size) + ) - def _normal_fn(self, affine_bijector): + def _normal_fn(self, affine): def _fn(samples): - scale = tf.exp(affine_bijector.forward(samples)) + scale = tf.exp(tf.linalg.matvec(affine, samples)) return independent.Independent( normal.Normal(loc=0., scale=scale, validate_args=True), reinterpreted_batch_ndims=1, @@ -63,10 +65,9 @@ def testSampleAndLogProbConsistency(self): event_size = 2 batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0) sample0 = tf.zeros(batch_event_shape) - affine = scale_matvec_tril.ScaleMatvecTriL( - scale_tril=self._random_scale_tril(event_size), validate_args=True) + ar_matrix = self._random_ar_matrix(event_size) ar = autoregressive.Autoregressive( - self._normal_fn(affine), sample0, validate_args=True) + self._normal_fn(ar_matrix), sample0, validate_args=True) self.run_test_sample_consistent_log_prob( self.evaluate, ar, @@ -107,13 +108,12 @@ def testCompareToBijector(self): event_size = np.int32(2) batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0) sample0 = tf.zeros(batch_event_shape) - affine = scale_matvec_tril.ScaleMatvecTriL( - scale_tril=self._random_scale_tril(event_size), validate_args=True) + ar_matrix = self._random_ar_matrix(event_size) ar = autoregressive.Autoregressive( - self._normal_fn(affine), sample0, validate_args=True) + self._normal_fn(ar_matrix), sample0, validate_args=True) ar_flow = masked_autoregressive.MaskedAutoregressiveFlow( is_constant_jacobian=True, - shift_and_log_scale_fn=lambda x: [None, affine.forward(x)], + shift_and_log_scale_fn=lambda x: [None, tf.linalg.matvec(ar_matrix, x)], validate_args=True) td = transformed_distribution.TransformedDistribution( # TODO(b/137665504): Use batch-adding meta-distribution to set the batch