diff --git a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process.py b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process.py
index bc03dff245..dea9b61e4a 100644
--- a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process.py
+++ b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process.py
@@ -101,7 +101,9 @@ def _compute_flattened_scale(
   observation_noise_variance = tf.convert_to_tensor(observation_noise_variance)
 
   # We can add the observation noise to each block.
-  if isinstance(kernel, multitask_kernel.Independent):
+  # Check if the kernel is tfpke.Independent-like.
+  if (isinstance(kernel_matrix, tf.linalg.LinearOperatorKronecker) and
+      isinstance(kernel_matrix.operators[1], tf.linalg.LinearOperatorIdentity)):
     # The Independent kernel matrix is realized as a kronecker product of the
     # kernel over inputs, and an identity matrix per task (representing
     # independent tasks). Update the diagonal of the first matrix and take the
@@ -123,7 +125,8 @@ def _compute_flattened_scale(
         operators=[base_kernel_matrix] + kernel_matrix.operators[1:])
     return cholesky_util.cholesky_from_fn(kernel_matrix, cholesky_fn)
 
-  if isinstance(kernel, multitask_kernel.Separable):
+  # Check if the kernel is tfpke.Separable-like.
+  if isinstance(kernel_matrix, tf.linalg.LinearOperatorKronecker):
     # When `kernel_matrix` is a kronecker product, we can compute
     # an eigenvalue decomposition to get a matrix square-root, which will
     # be faster than densifying the kronecker product.
diff --git a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py
index fc7ec6a5de..f9a6fe3bbf 100644
--- a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py
+++ b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py
@@ -174,14 +174,18 @@ def _scale_from_precomputed(precomputed_cholesky, kernel):
       f'Unexpected value for `precompute_cholesky`: {precomputed_cholesky}.')
 
 
-def _precomputed_from_scale(observation_scale, kernel):
+def _precomputed_from_scale(observation_scale):
   """Extracts expensive precomputed values."""
+  # TODO(b/331842377): Remove these checks and just store `observation_scale`
+  # when TFP objects are fully-functional JAX pytrees.
   if isinstance(observation_scale, tf.linalg.LinearOperatorLowerTriangular):
     return {'tril': {'chol_tril': observation_scale.tril}}
-  if isinstance(kernel, multitask_kernel.Independent):
+  # Check tfpke.Independent-like.
+  if isinstance(observation_scale, tf.linalg.LinearOperatorKronecker):
     base_kernel_chol_op = observation_scale.operators[0]
     return {'independent': {'chol_tril': base_kernel_chol_op.tril}}
-  if isinstance(kernel, multitask_kernel.Separable):
+  # Check tfpke.Separable-like.
+  if isinstance(observation_scale, tf.linalg.LinearOperatorComposition):
     kronecker_op, diag_op = observation_scale.operators
     kronecker_orths = [
         {'identity': k.domain_dimension_tensor()}
@@ -190,7 +194,7 @@ def _precomputed_from_scale(observation_scale, kernel):
     return {'separable': {'kronecker_orths': kronecker_orths,
                           'diag': diag_op.diag}}
   # This should not happen.
-  raise ValueError('Unexpected values for kernel and observation_scale.')
+  raise ValueError('Unexpected value for observation_scale.')
 
 
 class MultiTaskGaussianProcessRegressionModel(
@@ -685,7 +689,7 @@ def flattened_conditional_mean_fn(x):
     # pylint: disable=protected-access
     mtgprm._precomputed_divisor_matrix_cholesky = (
-        _precomputed_from_scale(observation_scale, kernel))
+        _precomputed_from_scale(observation_scale))
     mtgprm._precomputed_solve_on_observation = solve_on_observations
     # pylint: enable=protected-access
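Note on the dispatch change: the patch replaces `isinstance` checks against the multitask kernel classes with structural checks on the flattened `LinearOperator` itself, so the helpers no longer need the `kernel` argument (see the TODO referencing b/331842377 about JAX pytree support). Below is a minimal sketch of that structural dispatch, using only public `tf.linalg` operators. `_kernel_matrix_kind` is a hypothetical helper named for illustration; the patch inlines these checks rather than factoring them out.

    import tensorflow as tf

    def _kernel_matrix_kind(kernel_matrix):
      """Hypothetical helper: classifies a flattened multitask kernel matrix."""
      # Independent-like: kernel-over-inputs (x) identity-per-task. This must
      # be tested first, since it is also a LinearOperatorKronecker.
      if (isinstance(kernel_matrix, tf.linalg.LinearOperatorKronecker) and
          isinstance(kernel_matrix.operators[1],
                     tf.linalg.LinearOperatorIdentity)):
        return 'independent'
      # Separable-like: a general kronecker product over inputs and tasks.
      if isinstance(kernel_matrix, tf.linalg.LinearOperatorKronecker):
        return 'separable'
      return 'dense'

    # 4 input locations, 3 tasks.
    k_xx = tf.linalg.LinearOperatorFullMatrix(tf.eye(4) + 0.5)
    independent = tf.linalg.LinearOperatorKronecker(
        [k_xx, tf.linalg.LinearOperatorIdentity(num_rows=3)])
    task_kernel = tf.linalg.LinearOperatorFullMatrix([[1.0, 0.2], [0.2, 1.0]])
    separable = tf.linalg.LinearOperatorKronecker([k_xx, task_kernel])

    print(_kernel_matrix_kind(independent))  # -> independent
    print(_kernel_matrix_kind(separable))    # -> separable

The check ordering mirrors the patch: the Independent-like branch runs before the generic kronecker branch, because an Independent kernel matrix is itself a `LinearOperatorKronecker` (with an identity second factor) and would otherwise be misclassified as Separable-like.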
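The second file applies the same idea to `_precomputed_from_scale`: since the scale operator's structure already identifies the kernel family, the `kernel` argument can be dropped, and only plain tensors (pytree-friendly leaves, per the TODO) are stored. A minimal sketch of the 'tril' case, under the assumption that `_scale_from_precomputed` rebuilds the operator from the same nested dict:

    import tensorflow as tf

    # Precompute: store only the dense cholesky factor, not the operator.
    chol = tf.linalg.cholesky(tf.constant([[4.0, 2.0], [2.0, 5.0]]))
    observation_scale = tf.linalg.LinearOperatorLowerTriangular(chol)
    precomputed = {'tril': {'chol_tril': observation_scale.tril}}

    # Restore: rebuild the operator from the stored tensor.
    restored = tf.linalg.LinearOperatorLowerTriangular(
        precomputed['tril']['chol_tril'])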