diff --git a/tensorflow_probability/python/internal/samplers.py b/tensorflow_probability/python/internal/samplers.py index 3cd5aa484d..f8afd26046 100644 --- a/tensorflow_probability/python/internal/samplers.py +++ b/tensorflow_probability/python/internal/samplers.py @@ -14,6 +14,7 @@ # ============================================================================ """Random samplers.""" +import contextlib import hashlib import warnings @@ -48,6 +49,18 @@ SEED_DTYPE = np.uint32 if JAX_MODE else np.int32 +_old_salt = False + + +@contextlib.contextmanager +def enable_old_salt(enable): + global _old_salt + try: + _old_salt = enable + yield + finally: + _old_salt = False + def zeros_seed(): if JAX_MODE: @@ -140,9 +153,9 @@ def sanitize_seed(seed, salt=None, name=None): # discipline of splitting. if salt is not None: - salt = int(hashlib.sha512(str(salt).encode('utf-8')).hexdigest(), 16) % ( - 2**31 - 1 - ) + salt = int(hashlib.sha512(str(salt).encode('utf-8')).hexdigest(), 16) + if not _old_salt: + salt = salt % (2**31 - 1) seed = fold_in(seed, salt) if JAX_MODE: diff --git a/tensorflow_probability/python/internal/samplers_test.py b/tensorflow_probability/python/internal/samplers_test.py index 8c64db4a54..05d380d218 100644 --- a/tensorflow_probability/python/internal/samplers_test.py +++ b/tensorflow_probability/python/internal/samplers_test.py @@ -17,6 +17,7 @@ from absl import flags from absl.testing import parameterized import numpy as np +import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf from tensorflow_probability.python.internal import samplers @@ -40,6 +41,16 @@ def setUp(self): from jax import config # pylint: disable=g-import-not-at-top config.update('jax_default_prng_impl', FLAGS.test_tfp_jax_prng) + @test_util.substrate_disable_stateful_random_test + def test_old_salt(self): + if not tf1.control_flow_v2_enabled(): + self.skipTest('TF2 only.') + with samplers.enable_old_salt(True): + seed = samplers.sanitize_seed(0, salt='nacl') + seed = samplers.sanitize_seed(seed, salt='kcl') + val = samplers.uniform([5], 0, 1000, dtype=tf.int32, seed=seed) + self.assertAllEqual([483, 61, 906, 125, 381], self.evaluate(val)) + def test_new_style_jax_keys(self): if not JAX_MODE: self.skipTest('JAX-only distinction')