Skip to content

Commit

Permalink
Make test_seed(sampler_type='integer') work in 32 bits on JAX.
Browse files Browse the repository at this point in the history
This is a pain to test automatically. I manually disabled `jax_enable_x64`
temporarily.

PiperOrigin-RevId: 611549461
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Feb 29, 2024
1 parent f8a1b82 commit ae11c4c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
4 changes: 3 additions & 1 deletion tensorflow_probability/python/internal/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,10 @@ def get_integer_seed(seed):
if isinstance(seed, six.integer_types):
return seed % (2**31)
seed = sanitize_seed(seed)
# maxval is exclusive, so technically this doesn't generate all possible
# non-negative integers, but it's good enough for our purposes.
integer_seed = tf.random.stateless_uniform(
shape=[], seed=seed, minval=0, maxval=2**31, dtype=tf.int32)
shape=[], seed=seed, minval=0, maxval=2**31 - 1, dtype=tf.int32)
if JAX_MODE:
# This function isn't ever used in a jit context, so we can eagerly convert
# it to an integer to simplify caller's code.
Expand Down
15 changes: 11 additions & 4 deletions tensorflow_probability/python/internal/test_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,20 @@ def _maybe_jax(x):


@test_util.test_all_tf_execution_regimes
class SeedSettingTest(test_util.TestCase):
class SeedTest(test_util.TestCase):

def testTypeCorrectness(self):
assert isinstance(test_util.test_seed_stream(), SeedStream)
assert isinstance(
self.assertIsInstance(test_util.test_seed_stream(), SeedStream)
self.assertIsInstance(
test_util.test_seed_stream(hardcoded_seed=7), SeedStream)
assert isinstance(test_util.test_seed_stream(salt='foo'), SeedStream)
self.assertIsInstance(test_util.test_seed_stream(salt='foo'), SeedStream)

self.assertIsInstance(test_util.test_seed(sampler_type='integer'), int)
if not JAX_MODE:
self.assertIsInstance(test_util.test_seed(sampler_type='stateful'), int)
self.assertIsInstance(
test_util.test_seed(sampler_type='stateless'), tf.Tensor
)

def testSameness(self):
with flagsaver.flagsaver(vary_seed=False):
Expand Down

0 comments on commit ae11c4c

Please sign in to comment.