Description
I want to use Inference Gym models with 64-bit precision (for use in BlackJAX samplers). I am running into complicated-looking errors, which I have reduced to the following minimal example (run interactively in Python):
import jax
jax.config.update("jax_enable_x64", True)
import inference_gym.using_jax as gym
gym.targets.Banana()._unnormalized_log_prob(jax.numpy.array([1.0,1.0]))
This results in the following error (apologies if this is a simple mistake on my end):
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/inference_gym/targets/banana.py", line 133, in _unnormalized_log_prob
return self._banana.log_prob(value)
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1287, in log_prob
return self._call_log_prob(value, name, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1269, in _call_log_prob
return self._log_prob(value, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py", line 364, in _log_prob
log_prob, _ = self.experimental_local_measure(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py", line 611, in experimental_local_measure
x = self.bijector.inverse(y, **bijector_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 1390, in inverse
return self._call_inverse(y, name, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 1370, in _call_inverse
return self._cache.inverse(y, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/cache_util.py", line 347, in inverse
return self._lookup(y, self._inverse_name, self._forward_name, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/cache_util.py", line 493, in _lookup
self._invoke(input, forward_name, kwargs, attrs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/cache_util.py", line 532, in _invoke
return getattr(self.bijector, fn_name)(input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/masked_autoregressive.py", line 383, in _inverse
bijector = self._bijector_fn(y, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/inference_gym/targets/banana.py", line 88, in bijector_fn
shift = tf.concat(
^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/_utils.py", line 62, in wrap
return new_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_array.py", line 161, in _concat
values = _args_to_matching_arrays(values)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_array.py", line 149, in _args_to_matching_arrays
ret = [ops.convert_to_tensor(arg, dtype) for arg in args_list]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_array.py", line 149, in <listcomp>
ret = [ops.convert_to_tensor(arg, dtype) for arg in args_list]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/_utils.py", line 62, in wrap
return new_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py", line 152, in _convert_to_tensor
raise TypeError(('Tensor conversion requested dtype {} for array with '
TypeError: Tensor conversion requested dtype float32 for array with dtype float64: [-2.97]