diff --git a/tensorflow_probability/python/internal/backend/numpy/ops.py b/tensorflow_probability/python/internal/backend/numpy/ops.py index 96e8fa5155..113043468b 100644 --- a/tensorflow_probability/python/internal/backend/numpy/ops.py +++ b/tensorflow_probability/python/internal/backend/numpy/ops.py @@ -676,8 +676,10 @@ def assign_sub(self, value, **_): if JAX_MODE: jax.interpreters.xla.canonicalize_dtype_handlers[NumpyVariable] = ( jax.interpreters.xla.canonicalize_dtype_handlers[onp.ndarray]) - jax.interpreters.xla.pytype_aval_mappings[NumpyVariable] = ( - jax.interpreters.xla.pytype_aval_mappings[onp.ndarray]) + if hasattr(jax.interpreters.xla, 'pytype_aval_mappings'): + # Deprecated in JAX v0.5.0 + jax.interpreters.xla.pytype_aval_mappings[NumpyVariable] = ( + jax.interpreters.xla.pytype_aval_mappings[onp.ndarray]) jax.core.pytype_aval_mappings[NumpyVariable] = ( jax.core.pytype_aval_mappings[onp.ndarray])