Skip to content

Commit

Permalink
Only use RBF kernels in the leaves when making trees of BNN's. The other
Browse files Browse the repository at this point in the history
kernel types (Matern, Exponential, and OneLayer) aren't needed because they
make very similar predictions.

PiperOrigin-RevId: 615850103
  • Loading branch information
ThomasColthurst authored and tensorflower-gardener committed Mar 14, 2024
1 parent 9e996db commit 1df44cc
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 29 deletions.
19 changes: 4 additions & 15 deletions tensorflow_probability/python/experimental/autobnn/bnn_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,15 @@
Array = jnp.ndarray


LEAVES = [
NON_PERIODIC_KERNELS = [
kernels.ExponentiatedQuadraticBNN,
kernels.MaternBNN,
kernels.ExponentialBNN,
kernels.LinearBNN,
kernels.QuadraticBNN,
kernels.PeriodicBNN,
kernels.OneLayerBNN,
# Don't use Matern, Exponential or OneLayer BNN's in the leaves because
# they all give very similar predictions to ExponentiatedQuadratic.
]

LEAVES = NON_PERIODIC_KERNELS + [kernels.PeriodicBNN]

OPERATORS = [
operators.Multiply,
Expand All @@ -47,16 +46,6 @@
]


NON_PERIODIC_KERNELS = [
kernels.ExponentiatedQuadraticBNN,
kernels.MaternBNN,
kernels.ExponentialBNN,
kernels.LinearBNN,
kernels.QuadraticBNN,
kernels.OneLayerBNN,
]


def list_of_all(
time_series_xs: Array,
depth: int = 2,
Expand Down
29 changes: 15 additions & 14 deletions tensorflow_probability/python/experimental/autobnn/bnn_tree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,24 @@ class TreeTest(parameterized.TestCase):

def test_list_of_all_depth0(self):
l0 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 0)
# With no periods, there should be six kernels.
self.assertLen(l0, 6)
# With no periods, there should be three kernels.
self.assertLen(l0, 3)
for k in l0:
self.assertFalse(k.going_to_be_multiplied)

l0 = bnn_tree.list_of_all(100, 0, 50, [20.0, 40.0], parent_is_multiply=True)
self.assertLen(l0, 8)
# With two periods specified, 3+2 = 5.
self.assertLen(l0, 5)
for k in l0:
self.assertTrue(k.going_to_be_multiplied)

def test_list_of_all_depth1(self):
l1 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 1)
# With no periods, there should be
# choose(6+1, 2) = 21 trees with a Multiply top node,
# choose(6, 2) = 15 trees with a WeightedSum top node, and
# 6*6 = 36 trees with a LearnableChangePoint top node.
self.assertLen(l1, 72)
# choose(3+1, 2) = 6 trees with a Multiply top node,
# choose(3, 2) = 3 trees with a WeightedSum top node, and
# 3*3 = 9 trees with a LearnableChangePoint top node.
self.assertLen(l1, 18)

# Check that all of the BNNs in the tree can be trained.
for k in l1:
Expand All @@ -61,16 +62,16 @@ def test_list_of_all_depth1(self):
parent_is_multiply=True,
)
# With 2 periods and parent_is_multiply, there are only WeightedSum top
# nodes, with 7*8/2 = 28 trees.
self.assertLen(l1, 28)
# nodes, with 4*5/2 = 10 trees.
self.assertLen(l1, 10)

def test_list_of_all_depth2(self):
l2 = bnn_tree.list_of_all(jnp.linspace(0.0, 100.0, 100), 2)
# There are 66 trees of depth 1, of which 15 are safe to multiply.
# choose(15+1, 2) = 120 trees with a Multiply top node,
# choose(66, 2) = 2145 trees with a WeightedSum top node, and
# 66*66 = 4356 trees with a LearnableChangePoint top node.
self.assertLen(l2, 7860)
# There are 18 trees of depth 1, of which 3 are safe to multiply.
# choose(3+1, 2) = 6 trees with a Multiply top node,
# choose(18, 2) = 153 trees with a WeightedSum top node, and
# 18*18 = 324 trees with a LearnableChangePoint top node.
self.assertLen(l2, 483)

@parameterized.parameters(0, 1) # depth=2 segfaults on my desktop :(
def test_weighted_sum_of_all(self, depth):
Expand Down

0 comments on commit 1df44cc

Please sign in to comment.