diff --git a/tensorflow_probability/python/vi/csiszar_divergence.py b/tensorflow_probability/python/vi/csiszar_divergence.py index 925fb38db0..36049c0ec6 100644 --- a/tensorflow_probability/python/vi/csiszar_divergence.py +++ b/tensorflow_probability/python/vi/csiszar_divergence.py @@ -1074,7 +1074,6 @@ def monte_carlo_variational_loss( raise TypeError('`target_log_prob_fn` must be a Python `callable`' 'function.') - sample_seed, target_seed = samplers.split_seed(seed, 2) reparameterization_types = tf.nest.flatten( surrogate_posterior.reparameterization_type) if gradient_estimator is None: @@ -1089,6 +1088,8 @@ def monte_carlo_variational_loss( num_draws=importance_sample_size, num_batch_draws=sample_size, seed=seed) + + sample_seed, target_seed = samplers.split_seed(seed, 2) if gradient_estimator == GradientEstimators.SCORE_FUNCTION: if tf.get_static_value(importance_sample_size) != 1: # TODO(b/213378570): Support score function gradients for