diff --git a/tensorflow_probability/python/internal/test_util.py b/tensorflow_probability/python/internal/test_util.py index ae17d5a9cc..ef10557afd 100644 --- a/tensorflow_probability/python/internal/test_util.py +++ b/tensorflow_probability/python/internal/test_util.py @@ -1546,12 +1546,16 @@ def test_seed(hardcoded_seed=None, def clone_seed(seed): """Clone a seed: this is useful for JAX's experimental key reuse checking.""" - # TODO(b/328085305): switch to standard clone API when possible. if JAX_MODE: import jax # pylint: disable=g-import-not-at-top - return jax.random.wrap_key_data( - jax.random.key_data(seed), impl=jax.random.key_impl(seed) - ) + if hasattr(jax.random, 'clone'): + # jax v0.4.26 or later + return jax.random.clone(seed) + else: + # older jax versions + return jax.random.wrap_key_data( + jax.random.key_data(seed), impl=jax.random.key_impl(seed) + ) return seed