From 0435c3627ea72ad97113981776aeff43c4620bde Mon Sep 17 00:00:00 2001 From: vanderplas Date: Fri, 8 Mar 2024 13:52:32 -0800 Subject: [PATCH] test_util: use jax.random.clone when available PiperOrigin-RevId: 614035309 --- tensorflow_probability/python/internal/test_util.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tensorflow_probability/python/internal/test_util.py b/tensorflow_probability/python/internal/test_util.py index ae17d5a9cc..ef10557afd 100644 --- a/tensorflow_probability/python/internal/test_util.py +++ b/tensorflow_probability/python/internal/test_util.py @@ -1546,12 +1546,16 @@ def test_seed(hardcoded_seed=None, def clone_seed(seed): """Clone a seed: this is useful for JAX's experimental key reuse checking.""" - # TODO(b/328085305): switch to standard clone API when possible. if JAX_MODE: import jax # pylint: disable=g-import-not-at-top - return jax.random.wrap_key_data( - jax.random.key_data(seed), impl=jax.random.key_impl(seed) - ) + if hasattr(jax.random, 'clone'): + # jax v0.4.26 or later + return jax.random.clone(seed) + else: + # older jax versions + return jax.random.wrap_key_data( + jax.random.key_data(seed), impl=jax.random.key_impl(seed) + ) return seed