Skip to content

Commit

Permalink
Avoid complex->real casts via jax.numpy.astype
Browse files Browse the repository at this point in the history
This currently issues a warning about implicitly discarding the imaginary part, and it will issue an error in the future.

PiperOrigin-RevId: 658052473
  • Loading branch information
vanderplas authored and tensorflower-gardener committed Jul 31, 2024
1 parent a6f3989 commit 68f4dc0
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tensorflow_probability/python/internal/backend/numpy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ def _default_convert_to_tensor_with_dtype(value, dtype,
arr = np.array(value)
if dtype is not None:
# arr.astype(None) forces conversion to float64
if (np.issubdtype(arr.dtype, np.complexfloating) and
not np.issubdtype(dtype, np.complexfloating)):
arr = arr.real
return arr.astype(dtype)
return arr
elif isinstance(value, complex):
Expand Down

0 comments on commit 68f4dc0

Please sign in to comment.