From eb29f6ecbc937d3f3c1d4d4859706ac6234c0a08 Mon Sep 17 00:00:00 2001 From: siege Date: Mon, 5 Feb 2024 14:59:18 -0800 Subject: [PATCH] Correctly handle 0-length inputs in the numpy substrate tf.linalg.lstsq. PiperOrigin-RevId: 604442726 --- .../python/internal/backend/numpy/linalg_impl.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow_probability/python/internal/backend/numpy/linalg_impl.py b/tensorflow_probability/python/internal/backend/numpy/linalg_impl.py index a7c1e75d64..e69b4986ba 100644 --- a/tensorflow_probability/python/internal/backend/numpy/linalg_impl.py +++ b/tensorflow_probability/python/internal/backend/numpy/linalg_impl.py @@ -359,7 +359,10 @@ def _lstsq(matrix, rhs, l2_regularizer=0.0, fast=True, name=None): if JAX_MODE: import jax # pylint: disable=g-import-not-at-top return jax.vmap(functools.partial(_lstsq, fast=False))(matrix, rhs) - return np.array([_lstsq(mat, r, fast=False) for mat, r in zip(matrix, rhs)]) + res = np.array([_lstsq(mat, r, fast=False) for mat, r in zip(matrix, rhs)]) + if matrix.shape[0] == 0: + res = res.reshape(matrix.shape[:-2] + (matrix.shape[-1], rhs.shape[-1])) + return res rcond = None if JAX_MODE and matrix.dtype == np.float32: rcond = 0. # https://github.com/google/jax/issues/15591