-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_mixture_of_gaussian_bijector.py
63 lines (52 loc) · 2.13 KB
/
test_mixture_of_gaussian_bijector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import test_util
import bijector_test_util
from mixture_of_gaussian_bijector import MixtureOfGaussians, InverseMixtureOfGaussians
tfb = tfp.bijectors
tfd = tfp.distributions
# NOTE(review): `Root` is not referenced anywhere in this file; kept in case
# other code imports it from here.
Root = tfd.JointDistributionCoroutine.Root

# Prior: a batch of `n_dims` independent scalar mixtures, each built from
# `n_components` Normal components.
n_components = 100
n_dims = 2
# Every component gets the same logit, so the mixture weights are uniform
# (softmax of a constant vector); the particular value 1 / n_components is
# therefore immaterial.
component_logits = tf.fill([n_dims, n_components], 1. / n_components)
# Every component is a standard normal: loc 0, scale 1.
locs = tf.zeros([n_dims, n_components])
scales = tf.ones([n_dims, n_components])
# NOTE(review): these tensors are created at import time, presumably as
# EagerTensors (TF2 runs eagerly by default), so `dist` should not be used
# from inside a TF1 graph (e.g. graph-mode test runs) -- build a fresh prior
# there instead.
dist = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=component_logits),
    components_distribution=tfd.Normal(loc=locs, scale=scales),
    name="prior")
@test_util.test_all_tf_execution_regimes
class GateBijectorForNormalTests(test_util.TestCase):
    """Tests for `InverseMixtureOfGaussians` built on a Gaussian-mixture prior.

    The prior is rebuilt inside every test instead of reusing a module-level
    distribution: `test_all_tf_execution_regimes` also runs each test inside a
    TF1 graph, and tensors created eagerly at import time cannot be captured
    by such a graph.
    """

    def _make_prior(self):
        """Return a uniform mixture of `n_components` standard normals.

        The mixture has batch shape [n_dims]; all component logits are equal,
        so the mixture weights are uniform.
        """
        return tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(
                logits=tf.fill([n_dims, n_components], 1. / n_components)),
            components_distribution=tfd.Normal(
                loc=tf.zeros([n_dims, n_components]),
                scale=tf.ones([n_dims, n_components])),
            name="prior")

    def _make_bijector_and_input(self):
        """Build the bijector under test and a deterministic input batch.

        Returns:
          bijector: `InverseMixtureOfGaussians` wrapping a freshly built prior,
            with its trainable variables initialized.
          x: `tfb.NormalCDF` applied to 10 prior samples drawn with the fixed
            stateless seed (0, 0).
        """
        prior = self._make_prior()
        x = tfb.NormalCDF()(prior.sample(10, seed=(0, 0)))
        bijector = InverseMixtureOfGaussians(prior)
        # Variables are not initialized automatically in graph mode.
        self.evaluate([v.initializer for v in bijector.trainable_variables])
        return bijector, x

    def testBijector(self):
        """inverse(forward(x)) recovers x, and the bijector name is as expected."""
        bijector, x = self._make_bijector_and_input()
        self.assertStartsWith(bijector.name, 'mixture_of_gaussians')
        # `tf.identity` hands `inverse` a new tensor object, so a
        # forward/inverse result cache (as `tfb.Bijector` keeps) cannot
        # short-circuit the round trip.
        self.assertAllClose(
            x, bijector.inverse(tf.identity(bijector.forward(x))))

    def testTheoreticalFldj(self):
        """Check bijectivity/finiteness and compare FLDJ to a reference value."""
        bijector, x = self._make_bijector_and_input()
        # Fresh tensor for `y` so `inverse(y)` is actually computed rather than
        # served from a cache (same reasoning as in testBijector).
        y = tf.identity(bijector.forward(x))
        bijector_test_util.assert_bijective_and_finite(
            bijector,
            x,
            y,
            eval_func=self.evaluate,
            event_ndims=1,
            inverse_event_ndims=1,
            rtol=1e-5)
        fldj = bijector.forward_log_det_jacobian(x, event_ndims=1)
        # Reference FLDJ from the test utility; presumably derived by
        # differentiating `bijector.forward` -- confirm in bijector_test_util.
        fldj_theoretical = bijector_test_util.get_fldj_theoretical(
            bijector, x, event_ndims=1)
        self.assertAllClose(
            self.evaluate(fldj_theoretical),
            self.evaluate(fldj),
            atol=1e-5,
            rtol=1e-5)


if __name__ == '__main__':
    test_util.main()