
Commit 76eaa7e

davmre authored and tensorflower-gardener committed
Fix NaNs when running the spike-and-slab sampler under XLA.

This now enforces hard sparsity directly by setting sampled weights to zero, instead of relying on an infinite-precision MVN to do the right thing. The previous approach worked in TF, but sometimes returned NaNs under XLA.

PiperOrigin-RevId: 410036367
1 parent 8870950 commit 76eaa7e
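The idea behind the fix, as a minimal self-contained sketch (illustrative only; the diagonal MVN, seed, and sparsity pattern are made-up stand-ins, not the TFP source below): rather than encoding "this weight is exactly zero" as infinite precision inside the MVN, which XLA can reduce to inf/inf = NaN, sample from a well-conditioned finite-precision MVN and mask the spike dimensions afterwards.

import tensorflow as tf
import tensorflow_probability as tfp

# Hypothetical sparsity pattern: True marks 'slab' weights allowed to be nonzero.
nonzeros = tf.constant([True, False, True])

# A well-conditioned, finite-precision MVN stands in for the conditional
# posterior over weights.
mvn = tfp.distributions.MultivariateNormalDiag(
    loc=tf.zeros([3]), scale_diag=tf.ones([3]))

# Enforce hard sparsity by zeroing the 'spike' dimensions of each sample,
# instead of pushing their precision to infinity.
weights = tf.where(nonzeros, mvn.sample(seed=42), 0.)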

File tree

3 files changed: +39 -26 lines changed


tensorflow_probability/python/experimental/sts_gibbs/BUILD

Lines changed: 1 addition & 0 deletions
@@ -111,5 +111,6 @@ py_test(
         "//tensorflow_probability",
         "//tensorflow_probability/python/experimental/sts_gibbs",
         "//tensorflow_probability/python/internal:test_util",
+        # "//third_party/tensorflow/compiler/jit:xla_cpu_jit",  # DisableOnExport
     ],
 )

tensorflow_probability/python/experimental/sts_gibbs/spike_and_slab.py

Lines changed: 27 additions & 19 deletions
@@ -16,7 +16,6 @@
 
 import collections
 
-import numpy as np
 import tensorflow.compat.v2 as tf
 
 from tensorflow_probability.python import math as tfp_math
@@ -53,6 +52,26 @@ def _sample_n(self, n, seed=None):
     return xs
 
 
+class MVNPrecisionFactorHardZeros(
+    MultivariateNormalPrecisionFactorLinearOperator):
+  """Multivariate normal that forces some sample dimensions to zero.
+
+  This is equivalent to setting `loc[d] = 0.` and `precision_factor[d, d] = inf`
+  in the zeroed dimensions, but is numerically better behaved.
+  """
+
+  def __init__(self, loc, precision_factor, nonzeros, **kwargs):
+    self._nonzeros = nonzeros
+    super().__init__(loc=loc, precision_factor=precision_factor, **kwargs)
+
+  def _call_sample_n(self, *args, **kwargs):
+    xs = super()._call_sample_n(*args, **kwargs)
+    return tf.where(self._nonzeros, xs, 0.)
+
+  def _log_prob(self, *args, **kwargs):
+    raise NotImplementedError('Log prob is not currently implemented.')
+
+
 class SpikeSlabSamplerState(collections.namedtuple(
     'SpikeSlabSamplerState',
     ['x_transpose_y',
@@ -513,14 +532,6 @@ def _compute_log_prob(
 
   def _get_conditional_posterior(self, sampler_state):
     """Builds the joint posterior for a sparsity pattern (eqn (7) from [1])."""
-    # Impose a hard, infinite-precision constraint on zeroed-out features, in
-    # place of the identity-matrix representation that we used for numerical
-    # convenience during sampling.
-    hard_precision_factor = _select_nonzero_block(
-        sampler_state.conditional_posterior_precision_chol,
-        nonzeros=sampler_state.nonzeros,
-        identity_multiplier=np.inf)
-
     @joint_distribution_auto_batched.JointDistributionCoroutineAutoBatched
     def posterior_jd():
       observation_noise_variance = yield InverseGammaWithSampleUpperBound(
@@ -529,20 +540,21 @@ def posterior_jd():
           scale=sampler_state.observation_noise_variance_posterior_scale,
           upper_bound=self.observation_noise_variance_upper_bound,
           name='observation_noise_variance')
-      yield MultivariateNormalPrecisionFactorLinearOperator(
+      yield MVNPrecisionFactorHardZeros(
          loc=sampler_state.conditional_weights_mean,
          # Note that the posterior precision varies inversely with the
          # noise variance: in worlds with high noise we're also
          # more uncertain about the values of the weights.
          precision_factor=tf.linalg.LinearOperatorLowerTriangular(
-             hard_precision_factor /
+             sampler_state.conditional_posterior_precision_chol /
             observation_noise_variance[..., tf.newaxis, tf.newaxis]),
+         nonzeros=sampler_state.nonzeros,
          name='weights')
 
     return posterior_jd
 
 
-def _select_nonzero_block(matrix, nonzeros, identity_multiplier=1.):
+def _select_nonzero_block(matrix, nonzeros):
   """Replaces the `i`th row & col with the identity if not `nonzeros[i]`.
 
   This function effectively selects the 'slab' rows (corresponding to
@@ -566,7 +578,6 @@ def _select_nonzero_block(matrix, nonzeros, identity_multiplier=1.):
     matrix: (batch of) float Tensor matrix(s) of shape
       `[num_features, num_features]`.
     nonzeros: (batch of) boolean Tensor vectors of shape `[num_features]`.
-    identity_multiplier: optional scalar multiplier for the identity matrix.
   Returns:
     block_matrix: (batch of) float Tensor matrix(s) of the same shape as
       `matrix`, in which `block_matrix[i, j] = matrix[i, j] if
@@ -578,13 +589,10 @@ def _select_nonzero_block(matrix, nonzeros, identity_multiplier=1.):
   masked = tf.where(nonzeros[..., tf.newaxis],
                     tf.where(nonzeros[..., tf.newaxis, :], matrix, 0.),
                     0.)
-  # Restore a value of `identity_multiplier` on the diagonal of the not-selected
-  # rows. This avoids numerical issues by ensuring that the matrix still has
-  # full rank.
+  # Restore a value of 1 on the diagonal of the not-selected rows. This avoids
+  # numerical issues by ensuring that the matrix still has full rank.
   return tf.linalg.set_diag(masked,
-                            tf.where(nonzeros,
-                                     tf.linalg.diag_part(masked),
-                                     identity_multiplier))
+                            tf.where(nonzeros, tf.linalg.diag_part(masked), 1.))
 
 
 def _update_nonzero_block_chol(chol, idx, psd_matrix, new_nonzeros):
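For intuition, a small worked example of the updated `_select_nonzero_block` (a sketch reusing the ops from the function body above, with made-up values):

import tensorflow as tf

matrix = tf.constant([[4., 1., 2.],
                      [1., 5., 3.],
                      [2., 3., 6.]])
nonzeros = tf.constant([True, False, True])

# Zero out row/column 1, then restore 1. on its diagonal entry.
masked = tf.where(nonzeros[..., tf.newaxis],
                  tf.where(nonzeros[..., tf.newaxis, :], matrix, 0.),
                  0.)
block = tf.linalg.set_diag(
    masked, tf.where(nonzeros, tf.linalg.diag_part(masked), 1.))
# block == [[4., 0., 2.],
#           [0., 1., 0.],
#           [2., 0., 6.]]  (still full rank, so downstream solves stay finite)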

tensorflow_probability/python/experimental/sts_gibbs/spike_and_slab_test.py

Lines changed: 11 additions & 7 deletions
@@ -297,7 +297,8 @@ def loop_body(var_weights_seed, _):
         tf.reduce_mean(nonzero_weight_samples),
         atol=0.03)
 
-  def test_deterministic_given_seed(self):
+  @parameterized.named_parameters(('', False), ('_xla', True))
+  def test_deterministic_given_seed(self, use_xla):
     design_matrix, _, targets = self.evaluate(
         self._random_regression_task(
             num_outputs=3, num_features=4, batch_shape=[],
@@ -307,13 +308,16 @@ def test_deterministic_given_seed(self):
 
     initial_nonzeros = tf.convert_to_tensor([True, False, False, True])
     seed = test_util.test_seed(sampler_type='stateless')
-    variance1, weights1 = self.evaluate(
-        sampler.sample_noise_variance_and_weights(
-            targets, initial_nonzeros, seed=seed))
-    variance2, weights2 = self.evaluate(
-        sampler.sample_noise_variance_and_weights(
-            targets, initial_nonzeros, seed=seed))
+
+    @tf.function(jit_compile=use_xla)
+    def do_sample(seed):
+      return sampler.sample_noise_variance_and_weights(
+          targets, initial_nonzeros, seed=seed)
+    variance1, weights1 = self.evaluate(do_sample(seed))
+    variance2, weights2 = self.evaluate(do_sample(seed))
+    self.assertAllFinite(variance1)
     self.assertAllClose(variance1, variance2)
+    self.assertAllFinite(weights1)
     self.assertAllClose(weights1, weights2)
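The test change toggles XLA compilation through `tf.function(jit_compile=...)`; a standalone sketch of that pattern (the function and values below are hypothetical, not part of the test):

import tensorflow as tf

@tf.function(jit_compile=True)  # True compiles with XLA; False runs plain TF.
def masked_sum(x, keep):
  return tf.reduce_sum(tf.where(keep, x, 0.))

print(masked_sum(tf.ones([4]), tf.constant([True, False, True, True])))
# tf.Tensor(3.0, shape=(), dtype=float32)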