diff --git a/tensorflow_probability/python/experimental/mcmc/windowed_sampling.py b/tensorflow_probability/python/experimental/mcmc/windowed_sampling.py
index ab30e69471..643b409864 100644
--- a/tensorflow_probability/python/experimental/mcmc/windowed_sampling.py
+++ b/tensorflow_probability/python/experimental/mcmc/windowed_sampling.py
@@ -642,9 +642,11 @@ def windowed_adaptive_nuts(n_draws,
       structure should broadcast with `current_state`. For example, if the
       initial state is
       ```
-      {'a': tf.zeros(n_chains),
-       'b': tf.zeros([n_chains, n_features])}
-       ```
+      {
+        'a': tf.zeros(n_chains),
+        'b': tf.zeros([n_chains, n_features]),
+      }
+      ```
       then any of `1.`, `{'a': 1., 'b': 1.}`, or
       `{'a': tf.ones(n_chains), 'b': tf.ones([n_chains, n_features])}` will
       work. Defaults to the dimension of the log density to the 0.25 power.
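
The docstring touched by this hunk says that a scalar, a dict of scalars, or a dict of full-shaped tensors are all acceptable step sizes, because each broadcasts against the state. As a rough illustration only, the sketch below checks that rule on shape tuples in plain Python. It is not how the library implements the check: `broadcasts_with` and `step_size_ok` are hypothetical helpers, `()` stands in for a scalar shape, and `n_chains = 4`, `n_features = 3` are made-up sizes.

```python
from itertools import zip_longest


def broadcasts_with(step_shape, state_shape):
    """Return True if step_shape can broadcast to state_shape (NumPy-style rule)."""
    if len(step_shape) > len(state_shape):
        return False
    # Compare trailing dimensions first; missing leading dimensions count as 1.
    pairs = zip_longest(reversed(step_shape), reversed(state_shape), fillvalue=1)
    return all(s == 1 or s == t for s, t in pairs)


def step_size_ok(step_shapes, state_shapes):
    """Check a step-size structure against a dict of state shapes.

    step_shapes is either a single shape tuple (applied to every state part)
    or a dict with the same keys as state_shapes.
    """
    if not isinstance(step_shapes, dict):
        step_shapes = {key: step_shapes for key in state_shapes}
    if step_shapes.keys() != state_shapes.keys():
        return False
    return all(
        broadcasts_with(step_shapes[key], state_shapes[key])
        for key in state_shapes
    )


n_chains, n_features = 4, 3  # hypothetical sizes, for illustration only
state = {'a': (n_chains,), 'b': (n_chains, n_features)}

scalar_ok = step_size_ok((), state)                       # like `1.`
dict_scalar_ok = step_size_ok({'a': (), 'b': ()}, state)  # like {'a': 1., 'b': 1.}
full_ok = step_size_ok(dict(state), state)                # full-shaped tensors
bad = step_size_ok({'a': (n_features,), 'b': ()}, state)  # mismatched shape
```

The first three checks mirror the three forms named in the docstring. The last one shows a shape that would not broadcast with part `'a'`.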