Check the types of the kernel matrices rather than the kernels in the MultitaskGaussianProcess distribution, to support Vizier's usage.

PiperOrigin-RevId: 620080901
emilyfertig authored and tensorflower-gardener committed Mar 28, 2024
1 parent 6e86fb2 commit eaee193
Showing 2 changed files with 14 additions and 7 deletions.
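The substance of the change: detect the Independent- and Separable-style cases from the structure of the kernel matrix (a `tf.linalg.LinearOperatorKronecker`) rather than from the kernel's Python class, presumably so that kernels which build the same matrix structure without subclassing the TFP multitask kernel classes still hit the fast paths. A minimal, self-contained sketch of that structural check (the matrices and variable names below are illustrative, not from the commit):

```python
import tensorflow as tf

# Illustrative stand-in for an Independent-style multitask kernel matrix:
# kernel-over-inputs (2x2) kron identity-over-tasks (3x3).
base = tf.linalg.LinearOperatorFullMatrix([[2.0, 0.5], [0.5, 1.0]])
tasks_identity = tf.linalg.LinearOperatorIdentity(num_rows=3)
kernel_matrix = tf.linalg.LinearOperatorKronecker([base, tasks_identity])

# The structural check used by the new code path: no kernel class involved.
is_independent_like = (
    isinstance(kernel_matrix, tf.linalg.LinearOperatorKronecker) and
    isinstance(kernel_matrix.operators[1], tf.linalg.LinearOperatorIdentity))
print(is_independent_like)  # True
```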
First changed file:

@@ -101,7 +101,9 @@ def _compute_flattened_scale(
   observation_noise_variance = tf.convert_to_tensor(observation_noise_variance)
 
   # We can add the observation noise to each block.
-  if isinstance(kernel, multitask_kernel.Independent):
+  # Check if the kernel is tfpke.Independent-like.
+  if (isinstance(kernel_matrix, tf.linalg.LinearOperatorKronecker) and
+      isinstance(kernel_matrix.operators[1], tf.linalg.LinearOperatorIdentity)):
     # The Independent kernel matrix is realized as a kronecker product of the
     # kernel over inputs, and an identity matrix per task (representing
     # independent tasks). Update the diagonal of the first matrix and take the
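The (truncated) comment above relies on a Kronecker identity: adding observation noise to a `K ⊗ I` matrix only perturbs the first factor, so the noise can be folded into the diagonal of the kernel-over-inputs block. A small NumPy check of that identity, with made-up values:

```python
import numpy as np

k = np.array([[2.0, 0.5],
              [0.5, 1.5]])        # kernel matrix over inputs (n = 2)
noise = 0.1                       # observation noise variance
num_tasks = 3

# (K kron I_t) + noise * I_{n*t}  ==  (K + noise * I_n) kron I_t
lhs = np.kron(k, np.eye(num_tasks)) + noise * np.eye(k.shape[0] * num_tasks)
rhs = np.kron(k + noise * np.eye(k.shape[0]), np.eye(num_tasks))
print(np.allclose(lhs, rhs))      # True
```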
@@ -123,7 +125,8 @@ def _compute_flattened_scale(
         operators=[base_kernel_matrix] + kernel_matrix.operators[1:])
     return cholesky_util.cholesky_from_fn(kernel_matrix, cholesky_fn)
 
-  if isinstance(kernel, multitask_kernel.Separable):
+  # Check if the kernel is tfpke.Separable-like.
+  if isinstance(kernel_matrix, tf.linalg.LinearOperatorKronecker):
     # When `kernel_matrix` is a kronecker product, we can compute
     # an eigenvalue decomposition to get a matrix square-root, which will
     # be faster than densifying the kronecker product.
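For the Separable-like branch, the comment appeals to the fact that a Kronecker product's eigendecomposition is assembled from the factors' own eigendecompositions, so a matrix square root of `A ⊗ B + s*I` never requires densifying `A ⊗ B`. A hedged NumPy sketch of that identity (the dense `np.kron(a, b)` below is only used to verify the result):

```python
import numpy as np

a = np.array([[2.0, 0.3], [0.3, 1.0]])   # symmetric PSD factor A
b = np.array([[1.5, 0.2], [0.2, 0.8]])   # symmetric PSD factor B
s = 0.1                                   # added noise on the diagonal

wa, va = np.linalg.eigh(a)
wb, vb = np.linalg.eigh(b)
v = np.kron(va, vb)                       # eigenvectors of A kron B
w = np.kron(wa, wb) + s                   # eigenvalues of A kron B + s*I
sqrt_op = v @ np.diag(np.sqrt(w))         # a square root: sqrt_op @ sqrt_op.T

print(np.allclose(sqrt_op @ sqrt_op.T, np.kron(a, b) + s * np.eye(4)))  # True
```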
Second changed file:

@@ -174,14 +174,18 @@ def _scale_from_precomputed(precomputed_cholesky, kernel):
       f'Unexpected value for `precompute_cholesky`: {precomputed_cholesky}.')
 
 
-def _precomputed_from_scale(observation_scale, kernel):
+def _precomputed_from_scale(observation_scale):
   """Extracts expensive precomputed values."""
   # TODO(b/331842377): Remove these checks and just store `observation_scale`
   # when TFP objects are fully-functional JAX pytrees.
   if isinstance(observation_scale, tf.linalg.LinearOperatorLowerTriangular):
     return {'tril': {'chol_tril': observation_scale.tril}}
-  if isinstance(kernel, multitask_kernel.Independent):
+  # Check tfpke.Independent-like.
+  if (isinstance(observation_scale, tf.linalg.LinearOperatorKronecker)):
     base_kernel_chol_op = observation_scale.operators[0]
     return {'independent': {'chol_tril': base_kernel_chol_op.tril}}
-  if isinstance(kernel, multitask_kernel.Separable):
+  # Check tfpke.Separable-like.
+  if isinstance(observation_scale, tf.linalg.LinearOperatorComposition):
     kronecker_op, diag_op = observation_scale.operators
     kronecker_orths = [
         {'identity': k.domain_dimension_tensor()}
@@ -190,7 +194,7 @@ def _precomputed_from_scale(observation_scale, kernel):
     return {'separable': {'kronecker_orths': kronecker_orths,
                           'diag': diag_op.diag}}
   # This should not happen.
-  raise ValueError('Unexpected values for kernel and observation_scale.')
+  raise ValueError('Unexpected value for observation_scale.')
 
 
 class MultiTaskGaussianProcessRegressionModel(
@@ -685,7 +689,7 @@ def flattened_conditional_mean_fn(x):
 
     # pylint: disable=protected-access
     mtgprm._precomputed_divisor_matrix_cholesky = (
-        _precomputed_from_scale(observation_scale, kernel))
+        _precomputed_from_scale(observation_scale))
     mtgprm._precomputed_solve_on_observation = solve_on_observations
     # pylint: enable=protected-access
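With the `kernel` argument dropped, `_precomputed_from_scale` dispatches purely on the `LinearOperator` type of `observation_scale`: lower-triangular, Kronecker, or composition. A standalone sketch of that dispatch; the `describe` helper and the example operators are hypothetical stand-ins, not TFP internals:

```python
import tensorflow as tf

def describe(observation_scale):
  """Hypothetical helper mirroring the dispatch in `_precomputed_from_scale`."""
  if isinstance(observation_scale, tf.linalg.LinearOperatorLowerTriangular):
    return 'tril'
  if isinstance(observation_scale, tf.linalg.LinearOperatorKronecker):
    return 'independent'
  if isinstance(observation_scale, tf.linalg.LinearOperatorComposition):
    return 'separable'
  raise ValueError('Unexpected value for observation_scale.')

chol = tf.linalg.LinearOperatorLowerTriangular(
    tf.linalg.cholesky([[2.0, 0.5], [0.5, 1.0]]))
kron = tf.linalg.LinearOperatorKronecker(
    [chol, tf.linalg.LinearOperatorIdentity(num_rows=3)])
comp = tf.linalg.LinearOperatorComposition(
    [kron, tf.linalg.LinearOperatorDiag(tf.ones([6]))])

print(describe(chol))  # tril
print(describe(kron))  # independent
print(describe(comp))  # separable
```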
