diff --git a/tensorflow_probability/python/internal/backend/numpy/linalg_impl.py b/tensorflow_probability/python/internal/backend/numpy/linalg_impl.py index a7c1e75d64..e69b4986ba 100644 --- a/tensorflow_probability/python/internal/backend/numpy/linalg_impl.py +++ b/tensorflow_probability/python/internal/backend/numpy/linalg_impl.py @@ -359,7 +359,10 @@ def _lstsq(matrix, rhs, l2_regularizer=0.0, fast=True, name=None): if JAX_MODE: import jax # pylint: disable=g-import-not-at-top return jax.vmap(functools.partial(_lstsq, fast=False))(matrix, rhs) - return np.array([_lstsq(mat, r, fast=False) for mat, r in zip(matrix, rhs)]) + res = np.array([_lstsq(mat, r, fast=False) for mat, r in zip(matrix, rhs)]) + if matrix.shape[0] == 0: + res = res.reshape(matrix.shape[:-2] + (matrix.shape[-1], rhs.shape[-1])) + return res rcond = None if JAX_MODE and matrix.dtype == np.float32: rcond = 0. # https://github.com/google/jax/issues/15591