diff --git a/tensorflow_probability/python/experimental/autobnn/training_util.py b/tensorflow_probability/python/experimental/autobnn/training_util.py index 44d344788a..8af6ccc4c2 100644 --- a/tensorflow_probability/python/experimental/autobnn/training_util.py +++ b/tensorflow_probability/python/experimental/autobnn/training_util.py @@ -42,14 +42,16 @@ def _make_bayeux_model( # which in turn imports (through __init__.py files) autobnn. import bayeux as bx # pylint:disable=g-bad-import-order,g-import-not-at-top - test_seed, init_seed = jax.random.split(seed) - test_point = net.init(test_seed, x_train) transform, inverse_transform, ildj = util.make_transforms(net) - def _init(seed): - return net.init(seed, x_train) + @jax.jit + def _init(rand_seed): + return net.init(rand_seed, x_train) - initial_state = jax.vmap(_init)(jax.random.split(init_seed, num_particles)) + initial_state = jax.vmap(_init)(jax.random.split(seed, num_particles)) + # It is okay to reuse the initial_state[0] as the test point, as Bayeux + # only uses it to figure out the treedef. + test_point = jax.tree_map(lambda t: t[0], initial_state) if for_vi: @@ -67,7 +69,7 @@ def log_density(params, *, seed=None): log_density = functools.partial( net.log_prob, data=x_train, observations=y_train) return bx.Model( - log_density=log_density, + log_density=jax.jit(log_density), test_point=test_point, initial_state=initial_state, transform_fn=transform,