From 7089d1a65dc735130894761b4351941884f95ad3 Mon Sep 17 00:00:00 2001 From: siege Date: Wed, 22 Sep 2021 16:58:30 -0700 Subject: [PATCH] Call convert_to_tensor in tf2jax.cast(). The intention is to get code like `tf2jax.cast(tfp.util.DeferredTensor)` to work correctly. As it is now, it calls `DeferredTensor.__array__`, which is only intended to work with numpy at the moment. It's not clear whether `jnp.array(DT)` should bypass numpy or not. That aside, calling convert_to_tensor at the beginning of functions is pretty standard, so this change is probably a good idea either way. PiperOrigin-RevId: 398363966 --- tensorflow_probability/python/internal/backend/numpy/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_probability/python/internal/backend/numpy/ops.py b/tensorflow_probability/python/internal/backend/numpy/ops.py index c49a459327..75ffb57c52 100644 --- a/tensorflow_probability/python/internal/backend/numpy/ops.py +++ b/tensorflow_probability/python/internal/backend/numpy/ops.py @@ -377,7 +377,7 @@ def batch_jacobian(self, target, source, # pylint: disable=unused-argument def _cast(x, dtype): - x = np.asarray(x) + x = convert_to_tensor(x) if (np.issubdtype(x.dtype, np.complexfloating) and not np.issubdtype(dtype, np.complexfloating)): x = np.real(x)