From 4b44fd81949cf9c58ff214280eed6d5b360563a8 Mon Sep 17 00:00:00 2001
From: thomaswc
Date: Tue, 6 Feb 2024 10:25:52 -0800
Subject: [PATCH] Annotate potentially expensive methods with @jax.named_call
 to enable identification of hotspots.

PiperOrigin-RevId: 604690173
---
 .../python/experimental/autobnn/bnn.py           | 2 ++
 .../python/experimental/autobnn/estimators.py    | 1 +
 .../python/experimental/autobnn/kernels.py       | 5 +++++
 .../python/experimental/autobnn/likelihoods.py   | 1 +
 .../python/experimental/autobnn/operators.py     | 9 +++++++++
 .../python/experimental/autobnn/training_util.py | 3 +++
 6 files changed, 21 insertions(+)

diff --git a/tensorflow_probability/python/experimental/autobnn/bnn.py b/tensorflow_probability/python/experimental/autobnn/bnn.py
index cb33e373c1..d153d7d9ba 100644
--- a/tensorflow_probability/python/experimental/autobnn/bnn.py
+++ b/tensorflow_probability/python/experimental/autobnn/bnn.py
@@ -18,12 +18,14 @@
 
 import flax
 from flax import linen as nn
+import jax
 import jax.numpy as jnp
 from jaxtyping import Array, Float, PyTree  # pylint: disable=g-importing-member,g-multiple-import
 from tensorflow_probability.python.experimental.autobnn import likelihoods
 from tensorflow_probability.substrates.jax.distributions import distribution as distribution_lib
 
 
+@jax.named_call
 def log_prior_of_parameters(params, distributions) -> Float:
   """Return the prior of the parameters according to the distributions."""
   if 'params' in params:
diff --git a/tensorflow_probability/python/experimental/autobnn/estimators.py b/tensorflow_probability/python/experimental/autobnn/estimators.py
index d1d91b5637..18f3fdfeb2 100644
--- a/tensorflow_probability/python/experimental/autobnn/estimators.py
+++ b/tensorflow_probability/python/experimental/autobnn/estimators.py
@@ -130,6 +130,7 @@ def summary(self) -> str:
     summaries = [self.net_.summarize(p) for p in params_per_particle]
     return '\n'.join(summaries)
 
+  @jax.named_call
   def predict_quantiles(
       self, X: jax.Array, q=(2.5, 50.0, 97.5), axis: tuple[int, ...] = (0,)  # pylint: disable=invalid-name
   ) -> jax.Array:
diff --git a/tensorflow_probability/python/experimental/autobnn/kernels.py b/tensorflow_probability/python/experimental/autobnn/kernels.py
index f84dca7f4e..cc8cfc4c23 100644
--- a/tensorflow_probability/python/experimental/autobnn/kernels.py
+++ b/tensorflow_probability/python/experimental/autobnn/kernels.py
@@ -14,6 +14,7 @@
 # ============================================================================
 """`Leaf` BNNs, most of which correspond to some known GP kernel."""
 
+import functools
 from flax import linen as nn
 from flax.linen import initializers
 import jax
@@ -104,10 +105,12 @@ def distributions(self):
     }
     return super().distributions() | d
 
+  @functools.partial(jax.named_call, name='OneLayer::penultimate')
   def penultimate(self, inputs):
     y = self.input_warping(inputs)
     return self.activation_function(self.dense1(y))
 
+  @functools.partial(jax.named_call, name='OneLayer::__call__')
   def __call__(self, inputs, deterministic=True):
     return self.dense2(self.penultimate(inputs))
 
@@ -142,6 +145,7 @@ def distributions(self):
         },
     }
 
+  @functools.partial(jax.named_call, name='RBF::__call__')
   def __call__(self, inputs, deterministic=True):
     return self.amplitude * self.dense2(self.penultimate(inputs))
 
@@ -214,6 +218,7 @@ def bias_init(seed, shape, dtype=jnp.float32):
                     for _ in range(self.degree)]
     super().setup()
 
+  @functools.partial(jax.named_call, name='Polynomial::penultimate')
   def penultimate(self, inputs):
     x = inputs - self.shift
     ys = jnp.stack([h(x) for h in self.hiddens], axis=-1)
diff --git a/tensorflow_probability/python/experimental/autobnn/likelihoods.py b/tensorflow_probability/python/experimental/autobnn/likelihoods.py
index 384d9c6735..4c61c58381 100644
--- a/tensorflow_probability/python/experimental/autobnn/likelihoods.py
+++ b/tensorflow_probability/python/experimental/autobnn/likelihoods.py
@@ -49,6 +49,7 @@ def distributions(self):
     """Like BayesianModule::distributions but for the model's parameters."""
     return {}
 
+  @jax.named_call
   def log_likelihood(
       self, params, nn_out: jax.Array, observations: jax.Array
   ) -> jax.Array:
diff --git a/tensorflow_probability/python/experimental/autobnn/operators.py b/tensorflow_probability/python/experimental/autobnn/operators.py
index 6772ab56d7..2872bc46ec 100644
--- a/tensorflow_probability/python/experimental/autobnn/operators.py
+++ b/tensorflow_probability/python/experimental/autobnn/operators.py
@@ -14,8 +14,10 @@
 # ============================================================================
 """Flax.linen modules for combining BNNs."""
 
+import functools
 from typing import Optional
 from flax import linen as nn
+import jax
 import jax.numpy as jnp
 from tensorflow_probability.python.experimental.autobnn import bnn
 from tensorflow_probability.python.experimental.autobnn import likelihoods
@@ -54,6 +56,7 @@ def set_likelihood_model(self, likelihood_model: likelihoods.LikelihoodModel):
     for b in self.bnns:
       b.set_likelihood_model(dummy_ll_model)
 
+  @jax.named_call
   def log_prior(self, params):
     if 'params' in params:
       params = params['params']
@@ -114,6 +117,7 @@ def penultimate(self, inputs):
 class Add(MultipliableBnnOperator):
   """Add two or more BNNs."""
 
+  @functools.partial(jax.named_call, name='Add::penultimate')
   def penultimate(self, inputs):
     penultimates = [b.penultimate(inputs) for b in self.bnns]
     return jnp.sum(jnp.stack(penultimates, axis=-1), axis=-1)
@@ -147,6 +151,7 @@ def distributions(self):
         'bnn_weights': dirichlet_lib.Dirichlet(concentration=concentration)
     }
 
+  @functools.partial(jax.named_call, name='WeightedSum::penultimate')
   def penultimate(self, inputs):
     penultimates = [
         b.penultimate(inputs) * self.bnn_weights[0, i]
@@ -154,6 +159,7 @@ def penultimate(self, inputs):
     ]
     return jnp.sum(jnp.stack(penultimates, axis=-1), axis=-1)
 
+  @functools.partial(jax.named_call, name='WeightedSum::__call__')
   def __call__(self, inputs, deterministic=True):
     return jnp.sum(
         jnp.stack(
@@ -217,6 +223,7 @@ def distributions(self):
         }
     }
 
+  @functools.partial(jax.named_call, name='Multiply::__call__')
   def __call__(self, inputs, deterministic=True):
     penultimates = [b.penultimate(inputs) for b in self.bnns]
     return self.dense(jnp.prod(jnp.stack(penultimates, axis=-1), axis=-1))
@@ -235,6 +242,7 @@ def setup(self):
     assert len(self.bnns) == 2
     super().setup()
 
+  @jax.named_call
   def __call__(self, inputs, deterministic=True):
     time = inputs[..., self.change_index, jnp.newaxis]
     y = (time - self.change_point) / self.slope
@@ -274,6 +282,7 @@ def setup(self):
     assert len(self.time_series_xs) >= 2
     super().setup()
 
+  @functools.partial(jax.named_call, name='LearnableChangePoint::__call__')
   def __call__(self, inputs, deterministic=True):
     time = inputs[..., self.change_index, jnp.newaxis]
     y = (time - self.change_point) / self.change_slope
diff --git a/tensorflow_probability/python/experimental/autobnn/training_util.py b/tensorflow_probability/python/experimental/autobnn/training_util.py
index 1ff680ce4a..44d344788a 100644
--- a/tensorflow_probability/python/experimental/autobnn/training_util.py
+++ b/tensorflow_probability/python/experimental/autobnn/training_util.py
@@ -75,6 +75,7 @@ def log_density(params, *, seed=None):
         inverse_log_det_jacobian=ildj)
 
 
+@jax.named_call
 def fit_bnn_map(
     net: bnn.BNN,
     seed: jax.Array,
@@ -125,6 +126,7 @@ def _filter_stuck_chains(params):
   return jax.tree_map(lambda x: x[best_two], params)
 
 
+@jax.named_call
 def fit_bnn_vi(
     net: bnn.BNN,
     seed: jax.Array,
@@ -148,6 +150,7 @@ def fit_bnn_vi(
   return params, {'loss': loss}
 
 
+@jax.named_call
 def fit_bnn_mcmc(
     net: bnn.BNN,
     seed: jax.Array,
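
A minimal sketch of the pattern the patch applies, assuming only the public jax.named_call API: the decorator wraps a function so the operations traced from it carry a human-readable name, which then appears in jaxprs, HLO metadata, and profiler timelines. The function names and shapes below are invented for illustration and are not part of the patch.

# Illustrative sketch only -- `expensive_block`, `objective`, and the shapes
# are invented; the patch itself only adds decorators to existing autobnn code.
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.named_call, name='Example::expensive_block')
def expensive_block(x):
  # Everything traced from this function is grouped under the name above.
  return jnp.tanh(x @ x.T).sum()


@jax.jit
def objective(x):
  # When this jitted function is profiled (e.g. with jax.profiler.trace),
  # time spent here is attributed to 'Example::expensive_block'.
  return expensive_block(x) + 0.1 * jnp.mean(x)


print(objective(jnp.ones((256, 256))))

Note the naming convention in the patch: methods are wrapped with functools.partial(jax.named_call, name='Class::method') so the owning class is visible in the trace, while module-level functions such as fit_bnn_map use a bare @jax.named_call, since the function's own name is already unambiguous.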