Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Call convert_to_tensor in tf2jax.cast().
The intention is to get code like `tf2jax.cast(tfp.util.DeferredTensor)` to work correctly. As it is now, it calls `DeferredTensor.__array__`, which is only intended to work with numpy at the moment. It's not clear whether `jnp.array(DT)` should bypass numpy or not. That aside, calling convert_to_tensor at the beginning of functions is pretty standard, so this change is probably a good idea either way. PiperOrigin-RevId: 398363966
- Loading branch information