Skip to content

Commit

Permalink
test_util: use jax.random.clone when available
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 614035309
  • Loading branch information
vanderplas authored and tensorflower-gardener committed Mar 8, 2024
1 parent 0af6f41 commit 0435c36
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions tensorflow_probability/python/internal/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,12 +1546,16 @@ def test_seed(hardcoded_seed=None,

def clone_seed(seed):
"""Clone a seed: this is useful for JAX's experimental key reuse checking."""
# TODO(b/328085305): switch to standard clone API when possible.
if JAX_MODE:
import jax # pylint: disable=g-import-not-at-top
return jax.random.wrap_key_data(
jax.random.key_data(seed), impl=jax.random.key_impl(seed)
)
if hasattr(jax.random, 'clone'):
# jax v0.4.26 or later
return jax.random.clone(seed)
else:
# older jax versions
return jax.random.wrap_key_data(
jax.random.key_data(seed), impl=jax.random.key_impl(seed)
)
return seed


Expand Down

0 comments on commit 0435c36

Please sign in to comment.