Skip to content

Commit

Permalink
Jit the BNN's log_prob and init calculations.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 617210248
  • Loading branch information
ThomasColthurst authored and tensorflower-gardener committed Mar 19, 2024
1 parent 1df44cc commit b0abbc7
Showing 1 changed file with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,16 @@ def _make_bayeux_model(
# which in turn imports (through __init__.py files) autobnn.
import bayeux as bx # pylint:disable=g-bad-import-order,g-import-not-at-top

test_seed, init_seed = jax.random.split(seed)
test_point = net.init(test_seed, x_train)
transform, inverse_transform, ildj = util.make_transforms(net)

def _init(seed):
return net.init(seed, x_train)
@jax.jit
def _init(rand_seed):
return net.init(rand_seed, x_train)

initial_state = jax.vmap(_init)(jax.random.split(init_seed, num_particles))
initial_state = jax.vmap(_init)(jax.random.split(seed, num_particles))
# It is okay to reuse the initial_state[0] as the test point, as Bayeux
# only uses it to figure out the treedef.
test_point = jax.tree_map(lambda t: t[0], initial_state)

if for_vi:

Expand All @@ -67,7 +69,7 @@ def log_density(params, *, seed=None):
log_density = functools.partial(
net.log_prob, data=x_train, observations=y_train)
return bx.Model(
log_density=log_density,
log_density=jax.jit(log_density),
test_point=test_point,
initial_state=initial_state,
transform_fn=transform,
Expand Down

0 comments on commit b0abbc7

Please sign in to comment.