From 4e1d631669c28f8b09aaf488e2a67759f0a3e365 Mon Sep 17 00:00:00 2001
From: thomaswc
Date: Fri, 26 Apr 2024 12:24:39 -0700
Subject: [PATCH] Add fastgp to tfp.experimental.

PiperOrigin-RevId: 628478279
---
 .../python/experimental/BUILD                 |  2 ++
 .../python/experimental/__init__.py           |  7 ++++
 .../python/experimental/fastgp/BUILD          | 13 --------
 .../python/experimental/fastgp/__init__.py    | 16 +++++++++
 .../python/experimental/fastgp/fast_gp.py     |  2 +-
 .../experimental/fastgp/fast_log_det.py       | 33 +++++++++----------
 .../python/experimental/fastgp/linalg.py      |  5 ++-
 .../experimental/fastgp/preconditioners.py    | 22 ++++++-------
 8 files changed, 54 insertions(+), 46 deletions(-)

diff --git a/tensorflow_probability/python/experimental/BUILD b/tensorflow_probability/python/experimental/BUILD
index f955ce118a..029398427f 100644
--- a/tensorflow_probability/python/experimental/BUILD
+++ b/tensorflow_probability/python/experimental/BUILD
@@ -52,6 +52,7 @@ multi_substrate_py_library(
         "//tensorflow_probability/python/experimental/bijectors",
         "//tensorflow_probability/python/experimental/distribute",
         "//tensorflow_probability/python/experimental/distributions",
+        "//tensorflow_probability/python/experimental/fastgp",
         "//tensorflow_probability/python/experimental/joint_distribution_layers",
         "//tensorflow_probability/python/experimental/linalg",
         "//tensorflow_probability/python/experimental/marginalize",
@@ -70,5 +71,6 @@ multi_substrate_py_library(
         "//tensorflow_probability/python/experimental/vi",
         "//tensorflow_probability/python/internal:all_util",
         "//tensorflow_probability/python/internal:auto_composite_tensor",
+        "//tensorflow_probability/python/internal:lazy_loader",
     ],
 )
diff --git a/tensorflow_probability/python/experimental/__init__.py b/tensorflow_probability/python/experimental/__init__.py
index d316a44d90..7b7d5c895c 100644
--- a/tensorflow_probability/python/experimental/__init__.py
+++ b/tensorflow_probability/python/experimental/__init__.py
@@ -50,9 +50,15 @@
 from tensorflow_probability.python.experimental.util.composite_tensor import as_composite
 from tensorflow_probability.python.experimental.util.composite_tensor import register_composite
 from tensorflow_probability.python.internal import all_util
+from tensorflow_probability.python.internal import lazy_loader
 from tensorflow_probability.python.internal.auto_composite_tensor import auto_composite_tensor
 from tensorflow_probability.python.internal.auto_composite_tensor import AutoCompositeTensor

+# TODO(thomaswc): Figure out why fastgp needs to be lazy_loaded.
+globals()['fastgp'] = lazy_loader.LazyLoader(
+    'fastgp', globals(), 'tensorflow_probability.python.experimental.fastgp'
+)
+

 _allowed_symbols = [
     'auto_batching',
@@ -63,6 +69,7 @@
     'bijectors',
     'distribute',
     'distributions',
+    'fastgp',
     'joint_distribution_layers',
     'linalg',
     'marginalize',
diff --git a/tensorflow_probability/python/experimental/fastgp/BUILD b/tensorflow_probability/python/experimental/fastgp/BUILD
index f27a3bd7f2..1c93eec5da 100644
--- a/tensorflow_probability/python/experimental/fastgp/BUILD
+++ b/tensorflow_probability/python/experimental/fastgp/BUILD
@@ -62,7 +62,6 @@ py_library(
     name = "mbcg",
     srcs = ["mbcg.py"],
     deps = [
-        # jax dep,
     ],
 )

@@ -84,7 +83,6 @@ py_library(
         ":fast_log_det",
         ":mbcg",
         ":preconditioners",
-        # jax dep,
         "//tensorflow_probability/python/bijectors:softplus.jax",
         "//tensorflow_probability/python/distributions:distribution.jax",
         "//tensorflow_probability/python/distributions:gaussian_process_regression_model.jax",
@@ -121,7 +119,6 @@ py_library(
         ":linear_operator_sum",
         ":mbcg",
         ":preconditioners",
-        # jax dep,
         "//tensorflow_probability/python/distributions:distribution.jax",
         "//tensorflow_probability/python/distributions/internal:stochastic_process_util.jax",
         "//tensorflow_probability/python/experimental/psd_kernels:multitask_kernel.jax",
@@ -153,7 +150,6 @@ py_library(
         ":fast_gp",
         ":preconditioners",
         ":schur_complement",
-        # jax dep,
         "//tensorflow_probability/python/bijectors:softplus.jax",
         "//tensorflow_probability/python/distributions/internal:stochastic_process_util.jax",
         "//tensorflow_probability/python/internal:dtype_util.jax",
@@ -180,9 +176,7 @@ py_library(
     srcs = ["linalg.py"],
     deps = [
         ":partial_lanczos",
-        # jax dep,
         # jax:experimental_sparse dep,
-        # jaxtyping dep,
         "//tensorflow_probability/python/internal/backend/jax",
     ],
 )
@@ -205,7 +199,6 @@ py_library(
     srcs = ["partial_lanczos.py"],
     deps = [
         ":mbcg",
-        # jax dep,
         # scipy dep,
         "//tensorflow_probability/python/internal/backend/jax",
     ],
@@ -231,8 +224,6 @@ py_library(
         ":mbcg",
         ":partial_lanczos",
         ":preconditioners",
-        # jax dep,
-        # jaxtyping dep,
         # numpy dep,
         # scipy dep,
     ],
@@ -266,9 +257,6 @@ py_library(
     deps = [
         ":linalg",
         ":linear_operator_sum",
-        # jax dep,
-        # jax:experimental_sparse dep,
-        # jaxtyping dep,
         "//tensorflow_probability/python/internal/backend/jax",
         "//tensorflow_probability/python/math:linalg.jax",
     ],
@@ -292,7 +280,6 @@ py_library(
     srcs = ["schur_complement.py"],
     deps = [
         ":preconditioners",
-        # jax dep,
         "//tensorflow_probability/python/bijectors:softplus.jax",
         "//tensorflow_probability/python/internal:distribution_util.jax",
         "//tensorflow_probability/python/internal:dtype_util.jax",
diff --git a/tensorflow_probability/python/experimental/fastgp/__init__.py b/tensorflow_probability/python/experimental/fastgp/__init__.py
index 828ff72c70..d8240969d3 100644
--- a/tensorflow_probability/python/experimental/fastgp/__init__.py
+++ b/tensorflow_probability/python/experimental/fastgp/__init__.py
@@ -24,9 +24,25 @@
 from tensorflow_probability.python.experimental.fastgp import partial_lanczos
 from tensorflow_probability.python.experimental.fastgp import preconditioners
 from tensorflow_probability.python.experimental.fastgp import schur_complement
+from tensorflow_probability.python.experimental.fastgp.fast_gp import GaussianProcess
+from tensorflow_probability.python.experimental.fastgp.fast_gp import GaussianProcessConfig
+from tensorflow_probability.python.experimental.fastgp.fast_gprm import GaussianProcessRegressionModel
+from tensorflow_probability.python.experimental.fastgp.fast_log_det import get_log_det_algorithm
+from tensorflow_probability.python.experimental.fastgp.fast_log_det import ProbeVectorType
+from tensorflow_probability.python.experimental.fastgp.fast_mtgp import MultiTaskGaussianProcess
+from tensorflow_probability.python.experimental.fastgp.preconditioners import get_preconditioner
+from tensorflow_probability.python.experimental.fastgp.schur_complement import SchurComplement
 from tensorflow_probability.python.internal import all_util

 _allowed_symbols = [
+    'GaussianProcessConfig',
+    'GaussianProcess',
+    'GaussianProcessRegressionModel',
+    'ProbeVectorType',
+    'get_log_det_algorithm',
+    'MultiTaskGaussianProcess',
+    'get_preconditioner',
+    'SchurComplement',
     'fast_log_det',
     'fast_gp',
     'fast_gprm',
diff --git a/tensorflow_probability/python/experimental/fastgp/fast_gp.py b/tensorflow_probability/python/experimental/fastgp/fast_gp.py
index bc741b93ca..32a3e08cbb 100644
--- a/tensorflow_probability/python/experimental/fastgp/fast_gp.py
+++ b/tensorflow_probability/python/experimental/fastgp/fast_gp.py
@@ -72,7 +72,7 @@ class GaussianProcessConfig:
   probe_vector_type: str = 'rademacher'
   # The number of probe vectors to use when estimating the log det.
   num_probe_vectors: int = 35
-  # One of 'slq' (for stochastic Lancos quadrature) or
+  # One of 'slq' (for stochastic Lanczos quadrature) or
   # 'r1', 'r2', 'r3', 'r4', 'r5', or 'r6' for the rational function
   # approximation of the given order.
   log_det_algorithm: str = 'r3'
diff --git a/tensorflow_probability/python/experimental/fastgp/fast_log_det.py b/tensorflow_probability/python/experimental/fastgp/fast_log_det.py
index ae4fc146b8..a1f41997bf 100644
--- a/tensorflow_probability/python/experimental/fastgp/fast_log_det.py
+++ b/tensorflow_probability/python/experimental/fastgp/fast_log_det.py
@@ -22,7 +22,6 @@

 import jax
 import jax.numpy as jnp
-from jaxtyping import Float
 import numpy as np
 from tensorflow_probability.python.experimental.fastgp import mbcg
 from tensorflow_probability.python.experimental.fastgp import partial_lanczos
@@ -159,7 +158,7 @@ def _log_det_rational_approx_with_hutchinson(
     probe_vectors: Array,
     key: jax.Array,
     num_iters: int,
-) -> Float:
+):
   """Approximate log det using a rational function.

   We calculate log det M as the trace of log M, and we approximate the
@@ -295,7 +294,7 @@ def _r1(
     probe_vectors: Array,
     key: jax.Array,
     num_iters: int,
-) -> Float:
+):
   """Approximate log det using a 1st order rational function."""
   return _log_det_rational_approx_with_hutchinson(
       jnp.asarray(R1_SHIFTS, dtype=probe_vectors.dtype),
@@ -315,7 +314,7 @@ def _r2(
     probe_vectors: Array,
     key: jax.Array,
     num_iters: int,
-) -> Float:
+):
   """Approximate log det using a 2nd order rational function."""
   return _log_det_rational_approx_with_hutchinson(
       jnp.asarray(R2_SHIFTS, dtype=probe_vectors.dtype),
@@ -335,7 +334,7 @@ def _r3(
     probe_vectors: Array,
     key: jax.Array,
     num_iters: int,
-) -> Float:
+):
   """Approximate log det using a 3rd order rational function."""
   return _log_det_rational_approx_with_hutchinson(
       jnp.asarray(R3_SHIFTS, dtype=probe_vectors.dtype),
@@ -355,7 +354,7 @@ def _r4(
     probe_vectors: Array,
     key: jax.Array,
     num_iters: int,
-) -> Float:
+):
   """Approximate log det using a 4th order rational function."""
   return _log_det_rational_approx_with_hutchinson(
       jnp.asarray(R4_SHIFTS, dtype=probe_vectors.dtype),
@@ -375,7 +374,7 @@ def _r5(
     probe_vectors: Array,
     key: jax.Array,
     num_iters: int,
-) -> Float:
+):
   """Approximate log det using a 5th order rational function."""
   return _log_det_rational_approx_with_hutchinson(
       jnp.asarray(R5_SHIFTS, dtype=probe_vectors.dtype),
@@ -395,7 +394,7 @@ def _r6(
     probe_vectors: Array,
     key: jax.Array,
     num_iters: int,
-) -> Float:
+):
   """Approximate log det using a 6th order rational function."""
   return _log_det_rational_approx_with_hutchinson(
       jnp.asarray(R6_SHIFTS, dtype=probe_vectors.dtype),
@@ -453,7 +452,7 @@ def r1(
     probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER,
     num_iters: int = 20,
     **unused_kwargs,
-) -> Float:
+):
   """Approximate log det using a 1st order rational function."""
   n = M.shape[-1]
   key1, key2 = jax.random.split(key)
@@ -473,7 +472,7 @@ def r2(
     probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER,
     num_iters: int = 20,
     **unused_kwargs,
-) -> Float:
+):
   """Approximate log det using a 2nd order rational function."""
   n = M.shape[-1]
   key1, key2 = jax.random.split(key)
@@ -493,7 +492,7 @@ def r3(
     probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER,
     num_iters: int = 20,
     **unused_kwargs,
-) -> Float:
+):
   """Approximate log det using a 3rd order rational function."""
   n = M.shape[-1]
   key1, key2 = jax.random.split(key)
@@ -513,7 +512,7 @@ def r4(
     probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER,
     num_iters: int = 20,
     **unused_kwargs,
-) -> Float:
+):
   """Approximate log det using a 4th order rational function."""
   n = M.shape[-1]
   key1, key2 = jax.random.split(key)
@@ -533,7 +532,7 @@ def r5(
     probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER,
     num_iters: int = 20,
     **unused_kwargs,
-) -> Float:
+):
   """Approximate log det using a 5th order rational function."""
   n = M.shape[-1]
   key1, key2 = jax.random.split(key)
@@ -553,7 +552,7 @@ def r6(
     probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER,
     num_iters: int = 20,
     **unused_kwargs,
-) -> Float:
+):
   """Approximate log det using a 6th order rational function."""
   n = M.shape[-1]
   key1, key2 = jax.random.split(key)
@@ -597,7 +596,7 @@ def _stochastic_lanczos_quadrature_log_det(
     unused_key,
     probe_vectors_are_rademacher: bool,
     num_iters: int,
-) -> Float:
+):
   """Fast log det using the alg. from https://arxiv.org/pdf/1809.11165.pdf ."""

   n = M.shape[-1]
@@ -639,7 +638,7 @@ def stochastic_lanczos_quadrature_log_det(
     probe_vector_type: ProbeVectorType = ProbeVectorType.RADEMACHER,
     num_iters: int = 25,
     **unused_kwargs,
-) -> Float:
+):
   """Fast log det using the alg. from https://arxiv.org/pdf/1809.11165.pdf ."""
   n = M.shape[-1]
   num_iters = min(n, num_iters)
@@ -703,7 +702,7 @@ def log_det_taylor_series_with_hutchinson(
     num_probe_vectors: int,
     key: jax.Array,
     num_taylor_series_iterations: int = 10,
-) -> Float:
+):
   """Return an approximation of log det M."""
   # TODO(thomaswc): Consider having this support a batch of LinearOperators.
   n = M.shape[0]
diff --git a/tensorflow_probability/python/experimental/fastgp/linalg.py b/tensorflow_probability/python/experimental/fastgp/linalg.py
index 4c28864bf8..51c64e53bd 100644
--- a/tensorflow_probability/python/experimental/fastgp/linalg.py
+++ b/tensorflow_probability/python/experimental/fastgp/linalg.py
@@ -18,7 +18,6 @@
 import jax
 import jax.experimental.sparse
 import jax.numpy as jnp
-from jaxtyping import Float
 import numpy as np
 from tensorflow_probability.python.experimental.fastgp import partial_lanczos
 from tensorflow_probability.python.internal.backend import jax as tf2jax
@@ -36,7 +35,7 @@ def _matvec(M, x) -> jax.Array:

 def largest_eigenvector(
     M: tf2jax.linalg.LinearOperator, key: jax.Array, num_iters: int = 10
-) -> tuple[Float, Array]:
+):
   """Returns the largest (eigenvalue, eigenvector) of M."""
   n = M.shape[-1]
   v = jax.random.uniform(key, shape=(n,), dtype=M.dtype)
@@ -55,7 +54,7 @@ def make_randomized_truncated_svd(
     rank: int = 20,
     oversampling: int = 10,
     num_iters: int = 4,
-) -> tuple[Float, Array]:
+):
   """Returns approximate SVD for symmetric `M`."""
   # This is based on:
   # N. Halko, P.G. Martinsson, J. A. Tropp
diff --git a/tensorflow_probability/python/experimental/fastgp/preconditioners.py b/tensorflow_probability/python/experimental/fastgp/preconditioners.py
index be80af4729..1c11be48c1 100644
--- a/tensorflow_probability/python/experimental/fastgp/preconditioners.py
+++ b/tensorflow_probability/python/experimental/fastgp/preconditioners.py
@@ -38,9 +38,7 @@
 """

 import jax
-import jax.experimental.sparse
 import jax.numpy as jnp
-from jaxtyping import Float
 from tensorflow_probability.python.experimental.fastgp import linalg
 from tensorflow_probability.python.experimental.fastgp import linear_operator_sum
 from tensorflow_probability.python.internal.backend import jax as tf2jax
@@ -82,7 +80,7 @@ def log_det(self) -> tf2jax.linalg.LinearOperator:
     """The log absolute value of the determinant of the preconditioner."""
     return self.full_preconditioner().log_abs_determinant()

-  def trace_of_inverse_product(self, A: jax.Array) -> Float:
+  def trace_of_inverse_product(self, A: jax.Array):
     """Returns tr( P^(-1) A ) for a n x n, non-batched A."""
     result = self.full_preconditioner().solve(A)
     if isinstance(result, tf2jax.linalg.LinearOperator):
@@ -105,10 +103,10 @@ def full_preconditioner(self) -> tf2jax.linalg.LinearOperator:
   def preconditioned_operator(self) -> tf2jax.linalg.LinearOperator:
     return promote_to_operator(self.M)

-  def log_det(self) -> Float:
+  def log_det(self):
     return 0.0

-  def trace_of_inverse_product(self, A: jax.Array) -> Float:
+  def trace_of_inverse_product(self, A: jax.Array):
     return jnp.trace(A)

   def tree_flatten(self):
@@ -137,10 +135,10 @@ def preconditioned_operator(self) -> tf2jax.linalg.LinearOperator:
         [promote_to_operator(self.M), self.full_preconditioner().inverse()]
     )

-  def log_det(self) -> Float:
+  def log_det(self):
     return jnp.sum(jnp.log(self.d))

-  def trace_of_inverse_product(self, A: jax.Array) -> Float:
+  def trace_of_inverse_product(self, A: jax.Array):
     return jnp.sum(jnp.diag(A) / self.d)

   def tree_flatten(self):
@@ -481,11 +479,11 @@ def preconditioned_operator(self) -> tf2jax.linalg.LinearOperator:
         is_positive_definite=True,
     )

-  def log_det(self) -> Float:
+  def log_det(self):
     """Returns log det(R^T R) = 2 log det R."""
     return 2 * self.right_half().log_abs_determinant()

-  def trace_of_inverse_product(self, A: jax.Array) -> Float:
+  def trace_of_inverse_product(self, A: jax.Array):
     """Returns tr( (R^T R)^(-1) A ) for a n x n, non-batched A."""
     raise NotImplementedError(
         'Subclasses must override trace_of_inverse_product.')
@@ -510,10 +508,10 @@ def full_preconditioner(self) -> tf2jax.linalg.LinearOperator:
         self.d, is_non_singular=True, is_positive_definite=True
     )

-  def log_det(self) -> Float:
+  def log_det(self):
     return jnp.sum(jnp.log(self.d))

-  def trace_of_inverse_product(self, A: jax.Array) -> Float:
+  def trace_of_inverse_product(self, A: jax.Array):
     return jnp.sum(jnp.diag(A) / self.d)

   def tree_flatten(self):
@@ -582,7 +580,7 @@ def __init__(
   def right_half(self) -> tf2jax.linalg.LinearOperator:
     return self.P

-  def trace_of_inverse_product(self, A: jax.Array) -> Float:
+  def trace_of_inverse_product(self, A: jax.Array):
     # We want the trace of (P^T P)^(-1) A
     # = P^(-1) P^(-t) A
     # = [[ B^(-1), - B^(-1) C D^(-1)], [0, D^(-1)]]
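
Note on usage (not part of the patch): a minimal sketch of how the newly exported, lazy-loaded namespace is reached once this change lands. The substrate import path and the assumption that get_log_det_algorithm accepts the config string are illustrative guesses; only the GaussianProcessConfig fields are taken from fast_gp.py above.

# Minimal sketch, assuming the JAX substrate re-exports experimental.fastgp
# the same way the patched experimental/__init__.py above does.
from tensorflow_probability.substrates import jax as tfp

# `fastgp` is registered via lazy_loader.LazyLoader, so the JAX-only module
# is imported on first attribute access rather than at package import time.
fastgp = tfp.experimental.fastgp

# Field names and defaults below appear in fast_gp.py above.
config = fastgp.GaussianProcessConfig(
    probe_vector_type='rademacher',
    num_probe_vectors=35,
    log_det_algorithm='r3',  # 3rd order rational approximation.
)

# Assumption: get_log_det_algorithm maps the config string to one of the
# estimators ('slq', 'r1'..'r6') defined in fast_log_det.py.
log_det_fn = fastgp.get_log_det_algorithm(config.log_det_algorithm)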
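
Also for context (not part of the patch): the _log_det_rational_approx_with_hutchinson docstring above says log det M is computed as the trace of log M, estimated with probe vectors. The snippet below illustrates that Hutchinson-style identity on a dense matrix, forming log M by eigendecomposition purely for demonstration; the patched code instead applies a rational approximation of log via shifted linear solves and never materializes log M. All names here are hypothetical.

import jax
import jax.numpy as jnp

def hutchinson_log_det(m, key, num_probe_vectors=64):
  # log det M = tr(log M) ~= (1/k) sum_i z_i^T (log M) z_i, using Rademacher
  # probe vectors z_i (entries +/-1, so E[z z^T] = I).
  n = m.shape[-1]
  z = 2.0 * jax.random.bernoulli(key, 0.5, (n, num_probe_vectors)) - 1.0
  # Dense log M via eigendecomposition; valid for symmetric positive
  # definite M.
  w, v = jnp.linalg.eigh(m)
  log_m = (v * jnp.log(w)) @ v.T
  return jnp.mean(jnp.sum(z * (log_m @ z), axis=0))

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (50, 50))
m = a @ a.T + 50.0 * jnp.eye(50)  # Well-conditioned SPD test matrix.
print(hutchinson_log_det(m, key), jnp.linalg.slogdet(m)[1])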