From 8947615fe5b19cbce00441b2a62b4929bcd2ca15 Mon Sep 17 00:00:00 2001 From: thomaswc Date: Wed, 6 Mar 2024 10:59:35 -0800 Subject: [PATCH] Migrate users of tfp.experimental.substrates.jax to import it as tensorflow_probability.substrates.jax and to use the JAX specific BUILD target. PiperOrigin-RevId: 613275194 --- discussion/examples/TFP_and_Jax.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/discussion/examples/TFP_and_Jax.ipynb b/discussion/examples/TFP_and_Jax.ipynb index 8f58d26aa6..42dfdca5fe 100644 --- a/discussion/examples/TFP_and_Jax.ipynb +++ b/discussion/examples/TFP_and_Jax.ipynb @@ -96,8 +96,8 @@ "source": [ "# Importing the TFP with Jax backend\n", "!pip3 install -q 'tfp-nightly[jax]' tf-nightly-cpu # We (currently) still require TF, but TF's smaller CPU build will work.\n", - "import tensorflow_probability as tfp\n", - "tfp = tfp.experimental.substrates.jax\n", + "import tensorflow_probability.substrates.jax as tfp\n", + "\n", "tf = tfp.tf2jax\n", "\n", "# Standard TFP Imports\n",