diff --git a/tensorflow_probability/python/experimental/autobnn/bnn_tree.py b/tensorflow_probability/python/experimental/autobnn/bnn_tree.py index 4a02f23252..18e7dacaf2 100644 --- a/tensorflow_probability/python/experimental/autobnn/bnn_tree.py +++ b/tensorflow_probability/python/experimental/autobnn/bnn_tree.py @@ -30,6 +30,7 @@ LEAVES = [ kernels.ExponentiatedQuadraticBNN, kernels.MaternBNN, + kernels.ExponentialBNN, kernels.LinearBNN, kernels.QuadraticBNN, kernels.PeriodicBNN, @@ -49,6 +50,7 @@ NON_PERIODIC_KERNELS = [ kernels.ExponentiatedQuadraticBNN, kernels.MaternBNN, + kernels.ExponentialBNN, kernels.LinearBNN, kernels.QuadraticBNN, kernels.OneLayerBNN, @@ -86,7 +88,7 @@ def list_of_all( # Abelian operators that aren't Multiply. if include_sums: for i, c1 in enumerate(non_multiply_children): - for j in range(i + 1): + for j in range(i): c2 = non_multiply_children[j] # Add is also abelian, but WeightedSum is more general. all_bnns.append( diff --git a/tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py b/tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py index 10b38b24c2..67e6f87334 100644 --- a/tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py +++ b/tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py @@ -25,24 +25,25 @@ class TreeTest(parameterized.TestCase): - def test_list_of_all(self): + def test_list_of_all_depth0(self): l0 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 0) - # With no periods, there should be five kernels. - self.assertLen(l0, 5) + # With no periods, there should be six kernels. + self.assertLen(l0, 6) for k in l0: self.assertFalse(k.going_to_be_multiplied) l0 = bnn_tree.list_of_all(100, 0, 50, [20.0, 40.0], parent_is_multiply=True) - self.assertLen(l0, 7) + self.assertLen(l0, 8) for k in l0: self.assertTrue(k.going_to_be_multiplied) + def test_list_of_all_depth1(self): l1 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 1) # With no periods, there should be - # 15 trees with a Multiply top node, - # 15 trees with a WeightedSum top node, and - # 25 trees with a LearnableChangePoint top node. - self.assertLen(l1, 55) + # choose(6+1, 2) = 21 trees with a Multiply top node, + # choose(6, 2) = 15 trees with a WeightedSum top node, and + # 6*6 = 36 trees with a LearnableChangePoint top node. + self.assertLen(l1, 72) # Check that all of the BNNs in the tree can be trained. for k in l1: @@ -63,12 +64,13 @@ def test_list_of_all(self): # nodes, with 7*8/2 = 28 trees. self.assertLen(l1, 28) + def test_list_of_all_depth2(self): l2 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 2) - # With no periods, there should be - # 15*16/2 = 120 trees with a Multiply top node, - # 55*56/2 = 1540 trees with a WeightedSum top node, and - # 55*55 = 3025 trees with a LearnableChangePoint top node. - self.assertLen(l2, 4685) + # There are 66 trees of depth 1, of which 15 are safe to multiply. + # choose(15+1, 2) = 120 trees with a Multiply top node, + # choose(66, 2) = 2145 trees with a WeightedSum top node, and + # 66*66 = 4356 trees with a LearnableChangePoint top node. + self.assertLen(l2, 7860) @parameterized.parameters(0, 1) # depth=2 segfaults on my desktop :( def test_weighted_sum_of_all(self, depth): diff --git a/tensorflow_probability/python/experimental/autobnn/kernels.py b/tensorflow_probability/python/experimental/autobnn/kernels.py index b02ffb7316..f84dca7f4e 100644 --- a/tensorflow_probability/python/experimental/autobnn/kernels.py +++ b/tensorflow_probability/python/experimental/autobnn/kernels.py @@ -162,11 +162,25 @@ def kernel_init(seed, shape, unused_dtype): self.kernel_init = kernel_init super().setup() + def distributions(self): + d = super().distributions() + d['dense1']['kernel'] = student_t_lib.StudentT( + df=2.0 * self.degrees_of_freedom, loc=0.0, scale=1.0) + return d + def summarize(self, params=None, full: bool = False) -> str: """Return a string summarizing the structure of the BNN.""" return f'{self.shortname()}({self.degrees_of_freedom})' +class ExponentialBNN(MaternBNN): + """Matern(0.5), also known as the absolute exponential kernel.""" + degrees_of_freedom: float = 0.5 + + def summarize(self, params=None, full: bool = False) -> str: + return self.shortname() + + class PolynomialBNN(OneLayerBNN): """A BNN where samples are polynomial functions.""" degree: int = 2 diff --git a/tensorflow_probability/python/experimental/autobnn/kernels_test.py b/tensorflow_probability/python/experimental/autobnn/kernels_test.py index 67e574d517..c5f26882eb 100644 --- a/tensorflow_probability/python/experimental/autobnn/kernels_test.py +++ b/tensorflow_probability/python/experimental/autobnn/kernels_test.py @@ -30,6 +30,7 @@ kernels.OneLayerBNN, kernels.ExponentiatedQuadraticBNN, kernels.MaternBNN, + kernels.ExponentialBNN, kernels.PeriodicBNN, kernels.PolynomialBNN, kernels.LinearBNN, @@ -139,6 +140,7 @@ def test_likelihood(self, kernel): (kernels.OneLayerBNN(width=10), 'OneLayer'), (kernels.ExponentiatedQuadraticBNN(width=5), 'RBF'), (kernels.MaternBNN(width=5), 'Matern(2.5)'), + (kernels.ExponentialBNN(width=20), 'Exponential'), (kernels.PeriodicBNN(period=10, width=10), 'Periodic(period=10.00)'), (kernels.PolynomialBNN(degree=3, width=2), 'Polynomial(degree=3)'), (kernels.LinearBNN(width=5), 'Linear'),