From 898cfe9f62ebddae834c98971163703e852e3124 Mon Sep 17 00:00:00 2001 From: vanderplas Date: Thu, 7 Mar 2024 12:24:50 -0800 Subject: [PATCH] TFP: fix key reuse issue in monte_carlo_variational_loss Detected with JAX's enable_key_reuse_checks. We can avoid falling afoul of the reuse checker by splitting only once we determine it to be necessary. This should not have any user-visible change. PiperOrigin-RevId: 613666004 --- tensorflow_probability/python/vi/csiszar_divergence.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow_probability/python/vi/csiszar_divergence.py b/tensorflow_probability/python/vi/csiszar_divergence.py index 925fb38db0..36049c0ec6 100644 --- a/tensorflow_probability/python/vi/csiszar_divergence.py +++ b/tensorflow_probability/python/vi/csiszar_divergence.py @@ -1074,7 +1074,6 @@ def monte_carlo_variational_loss( raise TypeError('`target_log_prob_fn` must be a Python `callable`' 'function.') - sample_seed, target_seed = samplers.split_seed(seed, 2) reparameterization_types = tf.nest.flatten( surrogate_posterior.reparameterization_type) if gradient_estimator is None: @@ -1089,6 +1088,8 @@ def monte_carlo_variational_loss( num_draws=importance_sample_size, num_batch_draws=sample_size, seed=seed) + + sample_seed, target_seed = samplers.split_seed(seed, 2) if gradient_estimator == GradientEstimators.SCORE_FUNCTION: if tf.get_static_value(importance_sample_size) != 1: # TODO(b/213378570): Support score function gradients for