Fix the example for tfd.Autoregressive to actually be autoregressive.
This is done by zeroing out the diagonal. The autoregressive property (with a
left-to-right order) dictates that when sampling element i, we must look only
at elements 0 ≤ k < i, which implies that the diagonal must be zero. Cf. the
'exclusive' masking for the initial layers in tfb.AutoregressiveNetwork.

Also, add a comment about the superfluous iteration when sampling.

PiperOrigin-RevId: 615131666
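
To make the zero-diagonal point concrete, here is a minimal sketch (an
illustration added for this write-up, not part of the commit; it assumes only
numpy) of how a strictly lower-triangular matrix makes output element i depend
only on input elements k < i:

import numpy as np

event_size = 3
# Lower-triangular with the diagonal zeroed out, as in the fix below.
ar_matrix = np.tril(np.ones([event_size, event_size], np.float32), k=-1)

x = np.array([1., 2., 4.], np.float32)
y = ar_matrix @ x
print(y)  # [0. 1. 3.]: y[0] sees nothing, y[1] sees x[0], y[2] sees x[0], x[1]

With a nonzero diagonal, y[i] would also read x[i] itself, which is exactly
what the old example got wrong.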
SiegeLordEx authored and tensorflower-gardener committed Mar 12, 2024
1 parent 7baa486 commit 11031d5
Showing 3 changed files with 18 additions and 15 deletions.
1 change: 0 additions & 1 deletion tensorflow_probability/python/distributions/BUILD

@@ -2568,7 +2568,6 @@ multi_substrate_py_test(
         # numpy dep,
         # tensorflow dep,
         "//tensorflow_probability/python/bijectors:masked_autoregressive",
-        "//tensorflow_probability/python/bijectors:scale_matvec_tril",
         "//tensorflow_probability/python/internal:reparameterization",
         "//tensorflow_probability/python/internal:test_util",
         "//tensorflow_probability/python/math:linalg",
8 changes: 6 additions & 2 deletions tensorflow_probability/python/distributions/autoregressive.py

@@ -86,9 +86,10 @@ class Autoregressive(distribution.Distribution):
   def _normal_fn(event_size):
     n = event_size * (event_size + 1) // 2
     p = tf.Variable(tfd.Normal(loc=0., scale=1.).sample(n))
-    affine = tfb.FillScaleTriL(tfp.math.fill_triangular(0.25 * p))
+    ar_matrix = tf.linalg.set_diag(tfp.math.fill_triangular(0.25 * p),
+                                   tf.zeros(event_size))
     def _fn(samples):
-      scale = tf.exp(affine(samples))
+      scale = tf.exp(tf.linalg.matvec(ar_matrix, samples))
       return tfd.Independent(
           tfd.Normal(loc=0., scale=scale, validate_args=True),
           reinterpreted_batch_ndims=1)
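
For context, the patched docstring example can be run end to end roughly as
follows (an added sketch, not part of the diff; `event_size = 4` and the
sampling lines are illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def _normal_fn(event_size):
  n = event_size * (event_size + 1) // 2
  p = tf.Variable(tfd.Normal(loc=0., scale=1.).sample(n))
  # Strictly lower triangular: element i only sees elements k < i.
  ar_matrix = tf.linalg.set_diag(tfp.math.fill_triangular(0.25 * p),
                                 tf.zeros(event_size))
  def _fn(samples):
    scale = tf.exp(tf.linalg.matvec(ar_matrix, samples))
    return tfd.Independent(
        tfd.Normal(loc=0., scale=scale, validate_args=True),
        reinterpreted_batch_ndims=1)
  return _fn

event_size = 4
ar = tfd.Autoregressive(_normal_fn(event_size), tf.zeros([event_size]))
x = ar.sample(5)       # shape [5, 4]
lp = ar.log_prob(x)    # shape [5]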
@@ -291,6 +292,9 @@ def _sample_n(self, n, seed=None):

     seed = stateful_seed if is_stateful_sampler else stateless_seed

+    # This runs for 1 more step than strictly necessary because there is no
+    # guarantee that the samples produced by the sample(n) above are the same
+    # as the batched sample() below.
     if num_steps_static is not None:
       for _ in range(num_steps_static):
         # pylint: disable=not-callable
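
The new comment refers to the sampling loop in _sample_n, which starts from
sample0 and re-draws from distribution_fn num_steps times. A deterministic
analogue (an added sketch, numpy only; `step` stands in for one re-sampling
pass) of why strictly lower-triangular dependence reaches a fixed point within
event_size passes:

import numpy as np

event_size = 4
rng = np.random.RandomState(0)
# Strictly lower triangular, as enforced by the zeroed diagonal.
a = np.tril(rng.randn(event_size, event_size).astype(np.float32), k=-1)

def step(x):
  # One pass of "recompute each element from the previous vector": element i
  # reads only elements k < i.
  return np.tanh(a @ x) + 1.

x = np.zeros(event_size, np.float32)  # analogue of sample0
for _ in range(event_size):
  x = step(x)  # pass t pins down element t - 1 for good

# Any further pass is superfluous: x is already a fixed point.
assert np.allclose(step(x), x)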
24 changes: 12 additions & 12 deletions tensorflow_probability/python/distributions/autoregressive_test.py

@@ -21,7 +21,6 @@
 import numpy as np
 import tensorflow.compat.v2 as tf
 from tensorflow_probability.python.bijectors import masked_autoregressive
-from tensorflow_probability.python.bijectors import scale_matvec_tril
 from tensorflow_probability.python.distributions import autoregressive
 from tensorflow_probability.python.distributions import categorical
 from tensorflow_probability.python.distributions import distribution
@@ -43,14 +42,17 @@ def setUp(self):
     super(AutoregressiveTest, self).setUp()
     self._rng = np.random.RandomState(42)

-  def _random_scale_tril(self, event_size):
+  def _random_ar_matrix(self, event_size):
     n = np.int32(event_size * (event_size + 1) // 2)
     p = 2. * self._rng.random_sample(n).astype(np.float32) - 1.
-    return linalg.fill_triangular(0.25 * p)
+    # Zero out the diagonal to ensure the autoregressive property.
+    return tf.linalg.set_diag(
+        linalg.fill_triangular(0.25 * p), tf.zeros(event_size)
+    )

-  def _normal_fn(self, affine_bijector):
+  def _normal_fn(self, affine):
     def _fn(samples):
-      scale = tf.exp(affine_bijector.forward(samples))
+      scale = tf.exp(tf.linalg.matvec(affine, samples))
       return independent.Independent(
           normal.Normal(loc=0., scale=scale, validate_args=True),
           reinterpreted_batch_ndims=1,
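
A quick way to see what the new helper produces (an added check, not part of
the test; it inlines _random_ar_matrix for event_size = 3 and assumes TF2
eager mode):

import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.math import linalg

event_size = 3
n = np.int32(event_size * (event_size + 1) // 2)
p = 2. * np.random.RandomState(42).random_sample(n).astype(np.float32) - 1.
m = tf.linalg.set_diag(linalg.fill_triangular(0.25 * p),
                       tf.zeros(event_size))
# Everything on or above the diagonal is zero: strictly lower triangular.
assert np.allclose(np.triu(m.numpy()), 0.)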
@@ -63,10 +65,9 @@ def testSampleAndLogProbConsistency(self):
     event_size = 2
     batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0)
     sample0 = tf.zeros(batch_event_shape)
-    affine = scale_matvec_tril.ScaleMatvecTriL(
-        scale_tril=self._random_scale_tril(event_size), validate_args=True)
+    ar_matrix = self._random_ar_matrix(event_size)
     ar = autoregressive.Autoregressive(
-        self._normal_fn(affine), sample0, validate_args=True)
+        self._normal_fn(ar_matrix), sample0, validate_args=True)
     self.run_test_sample_consistent_log_prob(
         self.evaluate,
         ar,
@@ -107,13 +108,12 @@ def testCompareToBijector(self):
     event_size = np.int32(2)
     batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0)
     sample0 = tf.zeros(batch_event_shape)
-    affine = scale_matvec_tril.ScaleMatvecTriL(
-        scale_tril=self._random_scale_tril(event_size), validate_args=True)
+    ar_matrix = self._random_ar_matrix(event_size)
     ar = autoregressive.Autoregressive(
-        self._normal_fn(affine), sample0, validate_args=True)
+        self._normal_fn(ar_matrix), sample0, validate_args=True)
     ar_flow = masked_autoregressive.MaskedAutoregressiveFlow(
         is_constant_jacobian=True,
-        shift_and_log_scale_fn=lambda x: [None, affine.forward(x)],
+        shift_and_log_scale_fn=lambda x: [None, tf.linalg.matvec(ar_matrix, x)],
         validate_args=True)
     td = transformed_distribution.TransformedDistribution(
         # TODO(b/137665504): Use batch-adding meta-distribution to set the batch
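
What this test asserts, in compact form (an added sketch using the public tfp
API; the constants are illustrative, and the batch shapes and validate_args of
the real test are omitted): the Autoregressive distribution and the
MaskedAutoregressiveFlow-transformed standard normal define the same log_prob.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

event_size = 2
ar_matrix = tf.constant([[0., 0.], [0.5, 0.]])  # strictly lower triangular

def normal_fn(samples):
  scale = tf.exp(tf.linalg.matvec(ar_matrix, samples))
  return tfd.Independent(tfd.Normal(loc=0., scale=scale),
                         reinterpreted_batch_ndims=1)

ar = tfd.Autoregressive(normal_fn, sample0=tf.zeros([event_size]))
td = tfd.TransformedDistribution(
    distribution=tfd.Sample(tfd.Normal(0., 1.), [event_size]),
    bijector=tfb.MaskedAutoregressiveFlow(
        shift_and_log_scale_fn=lambda x: [None, tf.linalg.matvec(ar_matrix, x)]))

x = tf.constant([0.3, -1.2])
print(ar.log_prob(x).numpy(), td.log_prob(x).numpy())  # should agree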
