From ae11c4c5e66744edf7137a423805afe8350083a0 Mon Sep 17 00:00:00 2001 From: siege Date: Thu, 29 Feb 2024 11:59:42 -0800 Subject: [PATCH] Make test_seed(sampler_type='integer') work in 32 bits on JAX. This is a pain to test automatically. I manually disabled `jax_enable_x64` temporarily. PiperOrigin-RevId: 611549461 --- .../python/internal/samplers.py | 4 +++- .../python/internal/test_util_test.py | 15 +++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tensorflow_probability/python/internal/samplers.py b/tensorflow_probability/python/internal/samplers.py index 0ce1ff8caf..cbb076024d 100644 --- a/tensorflow_probability/python/internal/samplers.py +++ b/tensorflow_probability/python/internal/samplers.py @@ -165,8 +165,10 @@ def get_integer_seed(seed): if isinstance(seed, six.integer_types): return seed % (2**31) seed = sanitize_seed(seed) + # maxval is exclusive, so technically this doesn't generate all possible + # non-negative integers, but it's good enough for our purposes. integer_seed = tf.random.stateless_uniform( - shape=[], seed=seed, minval=0, maxval=2**31, dtype=tf.int32) + shape=[], seed=seed, minval=0, maxval=2**31 - 1, dtype=tf.int32) if JAX_MODE: # This function isn't ever used in a jit context, so we can eagerly convert # it to an integer to simplify caller's code. diff --git a/tensorflow_probability/python/internal/test_util_test.py b/tensorflow_probability/python/internal/test_util_test.py index d6c1508df9..4c2511d9ab 100644 --- a/tensorflow_probability/python/internal/test_util_test.py +++ b/tensorflow_probability/python/internal/test_util_test.py @@ -45,13 +45,20 @@ def _maybe_jax(x): @test_util.test_all_tf_execution_regimes -class SeedSettingTest(test_util.TestCase): +class SeedTest(test_util.TestCase): def testTypeCorrectness(self): - assert isinstance(test_util.test_seed_stream(), SeedStream) - assert isinstance( + self.assertIsInstance(test_util.test_seed_stream(), SeedStream) + self.assertIsInstance( test_util.test_seed_stream(hardcoded_seed=7), SeedStream) - assert isinstance(test_util.test_seed_stream(salt='foo'), SeedStream) + self.assertIsInstance(test_util.test_seed_stream(salt='foo'), SeedStream) + + self.assertIsInstance(test_util.test_seed(sampler_type='integer'), int) + if not JAX_MODE: + self.assertIsInstance(test_util.test_seed(sampler_type='stateful'), int) + self.assertIsInstance( + test_util.test_seed(sampler_type='stateless'), tf.Tensor + ) def testSameness(self): with flagsaver.flagsaver(vary_seed=False):