16
16
17
17
import collections
18
18
19
- import numpy as np
20
19
import tensorflow .compat .v2 as tf
21
20
22
21
from tensorflow_probability .python import math as tfp_math
@@ -53,6 +52,26 @@ def _sample_n(self, n, seed=None):
53
52
return xs
54
53
55
54
55
+ class MVNPrecisionFactorHardZeros (
56
+ MultivariateNormalPrecisionFactorLinearOperator ):
57
+ """Multivariate normal that forces some sample dimensions to zero.
58
+
59
+ This is equivalent to setting `loc[d] = 0.` and `precision_factor[d, d]=`inf`
60
+ in the zeroed dimensions, but is numerically better behaved.
61
+ """
62
+
63
+ def __init__ (self , loc , precision_factor , nonzeros , ** kwargs ):
64
+ self ._nonzeros = nonzeros
65
+ super ().__init__ (loc = loc , precision_factor = precision_factor , ** kwargs )
66
+
67
+ def _call_sample_n (self , * args , ** kwargs ):
68
+ xs = super ()._call_sample_n (* args , ** kwargs )
69
+ return tf .where (self ._nonzeros , xs , 0. )
70
+
71
+ def _log_prob (self , * args , ** kwargs ):
72
+ raise NotImplementedError ('Log prob is not currently implemented.' )
73
+
74
+
56
75
class SpikeSlabSamplerState (collections .namedtuple (
57
76
'SpikeSlabSamplerState' ,
58
77
['x_transpose_y' ,
@@ -513,14 +532,6 @@ def _compute_log_prob(
513
532
514
533
def _get_conditional_posterior (self , sampler_state ):
515
534
"""Builds the joint posterior for a sparsity pattern (eqn (7) from [1])."""
516
- # Impose a hard, infinite-precision constraint on zeroed-out features, in
517
- # place of the identity-matrix representation that we used for numerical
518
- # convenience during sampling.
519
- hard_precision_factor = _select_nonzero_block (
520
- sampler_state .conditional_posterior_precision_chol ,
521
- nonzeros = sampler_state .nonzeros ,
522
- identity_multiplier = np .inf )
523
-
524
535
@joint_distribution_auto_batched .JointDistributionCoroutineAutoBatched
525
536
def posterior_jd ():
526
537
observation_noise_variance = yield InverseGammaWithSampleUpperBound (
@@ -529,20 +540,21 @@ def posterior_jd():
529
540
scale = sampler_state .observation_noise_variance_posterior_scale ,
530
541
upper_bound = self .observation_noise_variance_upper_bound ,
531
542
name = 'observation_noise_variance' )
532
- yield MultivariateNormalPrecisionFactorLinearOperator (
543
+ yield MVNPrecisionFactorHardZeros (
533
544
loc = sampler_state .conditional_weights_mean ,
534
545
# Note that the posterior precision varies inversely with the
535
546
# noise variance: in worlds with high noise we're also
536
547
# more uncertain about the values of the weights.
537
548
precision_factor = tf .linalg .LinearOperatorLowerTriangular (
538
- hard_precision_factor /
549
+ sampler_state . conditional_posterior_precision_chol /
539
550
observation_noise_variance [..., tf .newaxis , tf .newaxis ]),
551
+ nonzeros = sampler_state .nonzeros ,
540
552
name = 'weights' )
541
553
542
554
return posterior_jd
543
555
544
556
545
- def _select_nonzero_block (matrix , nonzeros , identity_multiplier = 1. ):
557
+ def _select_nonzero_block (matrix , nonzeros ):
546
558
"""Replaces the `i`th row & col with the identity if not `nonzeros[i]`.
547
559
548
560
This function effectively selects the 'slab' rows (corresponding to
@@ -566,7 +578,6 @@ def _select_nonzero_block(matrix, nonzeros, identity_multiplier=1.):
566
578
matrix: (batch of) float Tensor matrix(s) of shape
567
579
`[num_features, num_features]`.
568
580
nonzeros: (batch of) boolean Tensor vectors of shape `[num_features]`.
569
- identity_multiplier: optional scalar multiplier for the identity matrix.
570
581
Returns:
571
582
block_matrix: (batch of) float Tensor matrix(s) of the same shape as
572
583
`matrix`, in which `block_matrix[i, j] = matrix[i, j] if
@@ -578,13 +589,10 @@ def _select_nonzero_block(matrix, nonzeros, identity_multiplier=1.):
578
589
masked = tf .where (nonzeros [..., tf .newaxis ],
579
590
tf .where (nonzeros [..., tf .newaxis , :], matrix , 0. ),
580
591
0. )
581
- # Restore a value of `identity_multiplier` on the diagonal of the not-selected
582
- # rows. This avoids numerical issues by ensuring that the matrix still has
583
- # full rank.
592
+ # Restore a value of 1 on the diagonal of the not-selected rows. This avoids
593
+ # numerical issues by ensuring that the matrix still has full rank.
584
594
return tf .linalg .set_diag (masked ,
585
- tf .where (nonzeros ,
586
- tf .linalg .diag_part (masked ),
587
- identity_multiplier ))
595
+ tf .where (nonzeros , tf .linalg .diag_part (masked ), 1. ))
588
596
589
597
590
598
def _update_nonzero_block_chol (chol , idx , psd_matrix , new_nonzeros ):
0 commit comments