diff --git a/tensorflow_probability/python/experimental/autobnn/bnn_tree.py b/tensorflow_probability/python/experimental/autobnn/bnn_tree.py index 18e7dacaf2..4e63635d2a 100644 --- a/tensorflow_probability/python/experimental/autobnn/bnn_tree.py +++ b/tensorflow_probability/python/experimental/autobnn/bnn_tree.py @@ -27,16 +27,15 @@ Array = jnp.ndarray -LEAVES = [ +NON_PERIODIC_KERNELS = [ kernels.ExponentiatedQuadraticBNN, - kernels.MaternBNN, - kernels.ExponentialBNN, kernels.LinearBNN, kernels.QuadraticBNN, - kernels.PeriodicBNN, - kernels.OneLayerBNN, + # Don't use Matern, Exponential or OneLayer BNN's in the leaves because + # they all give very similar predictions to ExponentiatedQuadratic. ] +LEAVES = NON_PERIODIC_KERNELS + [kernels.PeriodicBNN] OPERATORS = [ operators.Multiply, @@ -47,16 +46,6 @@ ] -NON_PERIODIC_KERNELS = [ - kernels.ExponentiatedQuadraticBNN, - kernels.MaternBNN, - kernels.ExponentialBNN, - kernels.LinearBNN, - kernels.QuadraticBNN, - kernels.OneLayerBNN, -] - - def list_of_all( time_series_xs: Array, depth: int = 2, diff --git a/tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py b/tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py index 67e6f87334..197c535dbe 100644 --- a/tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py +++ b/tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py @@ -27,23 +27,24 @@ class TreeTest(parameterized.TestCase): def test_list_of_all_depth0(self): l0 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 0) - # With no periods, there should be six kernels. - self.assertLen(l0, 6) + # With no periods, there should be three kernels. + self.assertLen(l0, 3) for k in l0: self.assertFalse(k.going_to_be_multiplied) l0 = bnn_tree.list_of_all(100, 0, 50, [20.0, 40.0], parent_is_multiply=True) - self.assertLen(l0, 8) + # With two periods specified, 3+2 = 5. + self.assertLen(l0, 5) for k in l0: self.assertTrue(k.going_to_be_multiplied) def test_list_of_all_depth1(self): l1 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 1) # With no periods, there should be - # choose(6+1, 2) = 21 trees with a Multiply top node, - # choose(6, 2) = 15 trees with a WeightedSum top node, and - # 6*6 = 36 trees with a LearnableChangePoint top node. - self.assertLen(l1, 72) + # choose(3+1, 2) = 6 trees with a Multiply top node, + # choose(3, 2) = 3 trees with a WeightedSum top node, and + # 3*3 = 9 trees with a LearnableChangePoint top node. + self.assertLen(l1, 18) # Check that all of the BNNs in the tree can be trained. for k in l1: @@ -61,16 +62,16 @@ def test_list_of_all_depth1(self): parent_is_multiply=True, ) # With 2 periods and parent_is_multiply, there are only WeightedSum top - # nodes, with 7*8/2 = 28 trees. - self.assertLen(l1, 28) + # nodes, with 4*5/2 = 10 trees. + self.assertLen(l1, 10) def test_list_of_all_depth2(self): l2 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 2) - # There are 66 trees of depth 1, of which 15 are safe to multiply. - # choose(15+1, 2) = 120 trees with a Multiply top node, - # choose(66, 2) = 2145 trees with a WeightedSum top node, and - # 66*66 = 4356 trees with a LearnableChangePoint top node. - self.assertLen(l2, 7860) + # There are 18 trees of depth 1, of which 3 are safe to multiply. + # choose(3+1, 2) = 6 trees with a Multiply top node, + # choose(18, 2) = 153 trees with a WeightedSum top node, and + # 18*18 = 324 trees with a LearnableChangePoint top node. + self.assertLen(l2, 483) @parameterized.parameters(0, 1) # depth=2 segfaults on my desktop :( def test_weighted_sum_of_all(self, depth):