From 135080b6b1ac5724fc1731b0a9ca6f2010b1aea5 Mon Sep 17 00:00:00 2001 From: vanderplas Date: Mon, 24 Feb 2025 15:05:10 -0800 Subject: [PATCH] Future-proof reference to deprecated pytype_aval_mappings PiperOrigin-RevId: 730610446 --- tensorflow_probability/python/internal/backend/numpy/ops.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow_probability/python/internal/backend/numpy/ops.py b/tensorflow_probability/python/internal/backend/numpy/ops.py index 96e8fa5155..113043468b 100644 --- a/tensorflow_probability/python/internal/backend/numpy/ops.py +++ b/tensorflow_probability/python/internal/backend/numpy/ops.py @@ -676,8 +676,10 @@ def assign_sub(self, value, **_): if JAX_MODE: jax.interpreters.xla.canonicalize_dtype_handlers[NumpyVariable] = ( jax.interpreters.xla.canonicalize_dtype_handlers[onp.ndarray]) - jax.interpreters.xla.pytype_aval_mappings[NumpyVariable] = ( - jax.interpreters.xla.pytype_aval_mappings[onp.ndarray]) + if hasattr(jax.interpreters.xla, 'pytype_aval_mappings'): + # Deprecated in JAX v0.5.0 + jax.interpreters.xla.pytype_aval_mappings[NumpyVariable] = ( + jax.interpreters.xla.pytype_aval_mappings[onp.ndarray]) jax.core.pytype_aval_mappings[NumpyVariable] = ( jax.core.pytype_aval_mappings[onp.ndarray])