diff --git a/tensorflow_probability/python/internal/backend/numpy/ops.py b/tensorflow_probability/python/internal/backend/numpy/ops.py index b96b22128b..4cd9ed8135 100644 --- a/tensorflow_probability/python/internal/backend/numpy/ops.py +++ b/tensorflow_probability/python/internal/backend/numpy/ops.py @@ -286,6 +286,9 @@ def _default_convert_to_tensor_with_dtype(value, dtype, arr = np.array(value) if dtype is not None: # arr.astype(None) forces conversion to float64 + if (np.issubdtype(arr.dtype, np.complexfloating) and + not np.issubdtype(dtype, np.complexfloating)): + arr = arr.real return arr.astype(dtype) return arr elif isinstance(value, complex):