Skip to content

Commit

Permalink
Call convert_to_tensor in tf2jax.cast().
Browse files Browse the repository at this point in the history
The intention is to get code like `tf2jax.cast(tfp.util.DeferredTensor)` to work correctly. As it is now, it calls `DeferredTensor.__array__`, which is only intended to work with numpy at the moment. It's not clear whether `jnp.array(DT)` should bypass numpy or not.

That aside, calling convert_to_tensor at the beginning of functions is pretty standard, so this change is probably a good idea either way.

PiperOrigin-RevId: 398363966
  • Loading branch information
SiegeLordEx authored and jburnim committed Sep 30, 2021
1 parent 36c6167 commit 7089d1a
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def batch_jacobian(self, target, source, # pylint: disable=unused-argument


def _cast(x, dtype):
x = np.asarray(x)
x = convert_to_tensor(x)
if (np.issubdtype(x.dtype, np.complexfloating) and
not np.issubdtype(dtype, np.complexfloating)):
x = np.real(x)
Expand Down

0 comments on commit 7089d1a

Please sign in to comment.