diff --git a/tensorflow_probability/python/math/minimize.py b/tensorflow_probability/python/math/minimize.py index 001ee1f5a0..e130c94477 100644 --- a/tensorflow_probability/python/math/minimize.py +++ b/tensorflow_probability/python/math/minimize.py @@ -138,6 +138,9 @@ def run_jitted_minimize(): seed_is_none = seed is None if not seed_is_none: seed = samplers.sanitize_seed(seed, salt='minimize') + init_seed, seed = samplers.split_seed(seed, n=2) + else: + init_seed = None if not return_full_length_trace: # Augment trace to record convergence info, so we can truncate it later. @@ -153,7 +156,7 @@ def run_jitted_minimize(): initial_optimizer_state) = optimizer_step_fn( parameters=initial_parameters, optimizer_state=initial_optimizer_state, - seed=seed) + seed=init_seed) initial_convergence_criterion_state = () if convergence_criterion is not None: