Fix NaNs when running the spike-and-slab sampler under XLA.
This now enforces hard sparsity directly by setting sampled weights to zero, instead of relying on an infinite-precision MVN to do the right thing. The previous approach worked in TF, but sometimes returned NaNs under XLA.
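As a rough illustration of the hard-sparsity idea (a minimal sketch only; `MultivariateNormalDiag` and the example values below are stand-ins, not the sampler's actual posterior): sample from the unconstrained normal, then force the excluded dimensions to exact zeros with `tf.where`, so no infinite precisions ever enter the computation.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

nonzeros = tf.constant([True, False, True])  # 'slab' (included) dimensions
unconstrained = tfd.MultivariateNormalDiag(
    loc=[0.3, 1.2, -0.7], scale_diag=[0.1, 0.1, 0.1])
# Hard sparsity: zero the excluded dimensions after sampling, rather than
# representing them with an infinite-precision component.
weights = tf.where(nonzeros, unconstrained.sample(), 0.)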

PiperOrigin-RevId: 410036367
davmre authored and tensorflower-gardener committed Nov 15, 2021
1 parent 8870950 commit 76eaa7e
Showing 3 changed files with 39 additions and 26 deletions.
1 change: 1 addition & 0 deletions tensorflow_probability/python/experimental/sts_gibbs/BUILD
@@ -111,5 +111,6 @@ py_test(
"//tensorflow_probability",
"//tensorflow_probability/python/experimental/sts_gibbs",
"//tensorflow_probability/python/internal:test_util",
# "//third_party/tensorflow/compiler/jit:xla_cpu_jit", # DisableOnExport
],
)
tensorflow_probability/python/experimental/sts_gibbs/spike_and_slab.py
@@ -16,7 +16,6 @@

import collections

import numpy as np
import tensorflow.compat.v2 as tf

from tensorflow_probability.python import math as tfp_math
@@ -53,6 +52,26 @@ def _sample_n(self, n, seed=None):
return xs


class MVNPrecisionFactorHardZeros(
MultivariateNormalPrecisionFactorLinearOperator):
"""Multivariate normal that forces some sample dimensions to zero.
This is equivalent to setting `loc[d] = 0.` and `precision_factor[d, d] = inf`
in the zeroed dimensions, but is numerically better behaved.
"""

def __init__(self, loc, precision_factor, nonzeros, **kwargs):
self._nonzeros = nonzeros
super().__init__(loc=loc, precision_factor=precision_factor, **kwargs)

def _call_sample_n(self, *args, **kwargs):
xs = super()._call_sample_n(*args, **kwargs)
return tf.where(self._nonzeros, xs, 0.)

def _log_prob(self, *args, **kwargs):
raise NotImplementedError('Log prob is not currently implemented.')


class SpikeSlabSamplerState(collections.namedtuple(
'SpikeSlabSamplerState',
['x_transpose_y',
@@ -513,14 +532,6 @@ def _compute_log_prob(

def _get_conditional_posterior(self, sampler_state):
"""Builds the joint posterior for a sparsity pattern (eqn (7) from [1])."""
# Impose a hard, infinite-precision constraint on zeroed-out features, in
# place of the identity-matrix representation that we used for numerical
# convenience during sampling.
hard_precision_factor = _select_nonzero_block(
sampler_state.conditional_posterior_precision_chol,
nonzeros=sampler_state.nonzeros,
identity_multiplier=np.inf)

@joint_distribution_auto_batched.JointDistributionCoroutineAutoBatched
def posterior_jd():
observation_noise_variance = yield InverseGammaWithSampleUpperBound(
@@ -529,20 +540,21 @@ def posterior_jd():
scale=sampler_state.observation_noise_variance_posterior_scale,
upper_bound=self.observation_noise_variance_upper_bound,
name='observation_noise_variance')
yield MultivariateNormalPrecisionFactorLinearOperator(
yield MVNPrecisionFactorHardZeros(
loc=sampler_state.conditional_weights_mean,
# Note that the posterior precision varies inversely with the
# noise variance: in worlds with high noise we're also
# more uncertain about the values of the weights.
precision_factor=tf.linalg.LinearOperatorLowerTriangular(
hard_precision_factor /
sampler_state.conditional_posterior_precision_chol /
observation_noise_variance[..., tf.newaxis, tf.newaxis]),
nonzeros=sampler_state.nonzeros,
name='weights')

return posterior_jd


def _select_nonzero_block(matrix, nonzeros, identity_multiplier=1.):
def _select_nonzero_block(matrix, nonzeros):
"""Replaces the `i`th row & col with the identity if not `nonzeros[i]`.
This function effectively selects the 'slab' rows (corresponding to
@@ -566,7 +578,6 @@ def _select_nonzero_block(matrix, nonzeros):
matrix: (batch of) float Tensor matrix(s) of shape
`[num_features, num_features]`.
nonzeros: (batch of) boolean Tensor vectors of shape `[num_features]`.
identity_multiplier: optional scalar multiplier for the identity matrix.
Returns:
block_matrix: (batch of) float Tensor matrix(s) of the same shape as
`matrix`, in which `block_matrix[i, j] = matrix[i, j] if
@@ -578,13 +589,10 @@ def _select_nonzero_block(matrix, nonzeros, identity_multiplier=1.):
masked = tf.where(nonzeros[..., tf.newaxis],
tf.where(nonzeros[..., tf.newaxis, :], matrix, 0.),
0.)
# Restore a value of `identity_multiplier` on the diagonal of the not-selected
# rows. This avoids numerical issues by ensuring that the matrix still has
# full rank.
# Restore a value of 1 on the diagonal of the not-selected rows. This avoids
# numerical issues by ensuring that the matrix still has full rank.
return tf.linalg.set_diag(masked,
tf.where(nonzeros,
tf.linalg.diag_part(masked),
identity_multiplier))
tf.where(nonzeros, tf.linalg.diag_part(masked), 1.))


def _update_nonzero_block_chol(chol, idx, psd_matrix, new_nonzeros):
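A small self-contained sketch of the block-selection behavior documented above for `_select_nonzero_block` (the helper is re-implemented here to keep the example runnable on its own; the matrix and mask values are made up):

import tensorflow as tf

def select_nonzero_block(matrix, nonzeros):
  # Zero every row and column whose feature is excluded by `nonzeros` ...
  masked = tf.where(nonzeros[..., tf.newaxis],
                    tf.where(nonzeros[..., tf.newaxis, :], matrix, 0.),
                    0.)
  # ... then put 1s back on the excluded diagonal entries so the result
  # keeps full rank.
  return tf.linalg.set_diag(
      masked, tf.where(nonzeros, tf.linalg.diag_part(masked), 1.))

matrix = tf.constant([[4., 1., 2.],
                      [1., 3., 0.],
                      [2., 0., 5.]])
nonzeros = tf.constant([True, False, True])
print(select_nonzero_block(matrix, nonzeros).numpy())
# [[4. 0. 2.]
#  [0. 1. 0.]
#  [2. 0. 5.]]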
tensorflow_probability/python/experimental/sts_gibbs/spike_and_slab_test.py
@@ -297,7 +297,8 @@ def loop_body(var_weights_seed, _):
tf.reduce_mean(nonzero_weight_samples),
atol=0.03)

def test_deterministic_given_seed(self):
@parameterized.named_parameters(('', False), ('_xla', True))
def test_deterministic_given_seed(self, use_xla):
design_matrix, _, targets = self.evaluate(
self._random_regression_task(
num_outputs=3, num_features=4, batch_shape=[],
@@ -307,13 +308,16 @@ def test_deterministic_given_seed(self):

initial_nonzeros = tf.convert_to_tensor([True, False, False, True])
seed = test_util.test_seed(sampler_type='stateless')
variance1, weights1 = self.evaluate(
sampler.sample_noise_variance_and_weights(
targets, initial_nonzeros, seed=seed))
variance2, weights2 = self.evaluate(
sampler.sample_noise_variance_and_weights(
targets, initial_nonzeros, seed=seed))

@tf.function(jit_compile=use_xla)
def do_sample(seed):
return sampler.sample_noise_variance_and_weights(
targets, initial_nonzeros, seed=seed)
variance1, weights1 = self.evaluate(do_sample(seed))
variance2, weights2 = self.evaluate(do_sample(seed))
self.assertAllFinite(variance1)
self.assertAllClose(variance1, variance2)
self.assertAllFinite(weights1)
self.assertAllClose(weights1, weights2)

