diff --git a/tensorflow_probability/python/experimental/fastgp/fast_gp.py b/tensorflow_probability/python/experimental/fastgp/fast_gp.py index f382271f6c..bc741b93ca 100644 --- a/tensorflow_probability/python/experimental/fastgp/fast_gp.py +++ b/tensorflow_probability/python/experimental/fastgp/fast_gp.py @@ -85,15 +85,11 @@ class GaussianProcess(distribution.Distribution): """Fast, JAX-only implementation of a GP distribution class. See tfd.distributions.GaussianProcess for a description and parameter - documentation. Currently only supports log_prob and posterior_predictive - (the only two methods used by smc.py). + documentation, but note that not all of that class's methods are supported. The default parameters are tuned to give a good time / error trade-off in the n > 15,000 regime where this class gives a substantial speed-up - over tfd.distributions.GaussianProcess. In particular, it is tuned to - give a trade-off in the case where you care about the accuracy of both - log_prob and its derivative. If you care only about log_prob, it is - recommended to use log_det_algorithm='slq' with preconditioner_num_iters=1. + over tfd.distributions.GaussianProcess. """ def __init__( @@ -311,7 +307,26 @@ def get_preconditioner(cov): @jax.named_call def log_prob(self, value, key, is_missing=None) -> Array: - """log P(value | GP).""" + """log P(value | GP). + + Args: + value: `float` or `double` jax.Array. + key: A jax KeyArray. This method uses stochastic methods to quickly + estimate the log probability of `value`, and `key` is needed to + generate the stochasticity. `key` is also used when computing the + derivative of this function. In some circumstances it is acceptable + and in fact even necessary to pass the same value of `key` to multiple + invocations of log_prob; for example if the log_prob is being + optimized by an algorithm that assumes it is deterministic. + is_missing: Optional `bool` jax.Array of shape `[..., e]` where `e` is + the number of index points in each batch. Represents a batch of + Boolean masks. When not `None`, the returned log_prob is for the + *marginal* distribution in which all dimensions with `is_missing==True` + have been marginalized out. + + Returns: + A stochastic approximation to log P(value | GP). + """ empty_sample_batch_shape = value.ndim == 1 if empty_sample_batch_shape: value = value[jnp.newaxis]