From f32c8d4e3ffb44cdaf8c086ab5577bbc8d489516 Mon Sep 17 00:00:00 2001 From: vanderplas Date: Mon, 11 Mar 2024 11:35:01 -0700 Subject: [PATCH] minimize_stateless: avoid reusing initialization seed Discovered by running tests with `jax_enable_key_reuse_checks=True`. PiperOrigin-RevId: 614737610 --- tensorflow_probability/python/math/minimize.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow_probability/python/math/minimize.py b/tensorflow_probability/python/math/minimize.py index 001ee1f5a0..e130c94477 100644 --- a/tensorflow_probability/python/math/minimize.py +++ b/tensorflow_probability/python/math/minimize.py @@ -138,6 +138,9 @@ def run_jitted_minimize(): seed_is_none = seed is None if not seed_is_none: seed = samplers.sanitize_seed(seed, salt='minimize') + init_seed, seed = samplers.split_seed(seed, n=2) + else: + init_seed = None if not return_full_length_trace: # Augment trace to record convergence info, so we can truncate it later. @@ -153,7 +156,7 @@ def run_jitted_minimize(): initial_optimizer_state) = optimizer_step_fn( parameters=initial_parameters, optimizer_state=initial_optimizer_state, - seed=seed) + seed=init_seed) initial_convergence_criterion_state = () if convergence_criterion is not None: