
Inference-Gym: double precision issue #1993

Closed
@reubenharry

Description


I want to use inference gym models with 64-bit precision (for use in Blackjax samplers). I encounter complicated-looking bugs, which I have reduced to the following minimal example (run interactively in Python):

import jax
jax.config.update("jax_enable_x64", True)                                                                                                                               
import inference_gym.using_jax as gym 
gym.targets.Banana()._unnormalized_log_prob(jax.numpy.array([1.0,1.0]))

results in the following error. (Apologies if this is a simple mistake on my end.)

Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/inference_gym/targets/banana.py", line 133, in _unnormalized_log_prob
return self._banana.log_prob(value)

File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1287, in log_prob
return self._call_log_prob(value, name, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1269, in _call_log_prob
return self._log_prob(value, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py", line 364, in _log_prob
log_prob, _ = self.experimental_local_measure(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py", line 611, in experimental_local_measure
x = self.bijector.inverse(y, **bijector_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 1390, in inverse
return self._call_inverse(y, name, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 1370, in _call_inverse
return self._cache.inverse(y, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/cache_util.py", line 347, in inverse
return self._lookup(y, self._inverse_name, self._forward_name, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/cache_util.py", line 493, in _lookup
self._invoke(input, forward_name, kwargs, attrs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/cache_util.py", line 532, in _invoke
return getattr(self.bijector, fn_name)(input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/masked_autoregressive.py", line 383, in _inverse
bijector = self._bijector_fn(y, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/inference_gym/targets/banana.py", line 88, in bijector_fn
shift = tf.concat(
^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/_utils.py", line 62, in wrap
return new_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_array.py", line 161, in _concat
values = _args_to_matching_arrays(values)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_array.py", line 149, in _args_to_matching_arrays
ret = [ops.convert_to_tensor(arg, dtype) for arg in args_list]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_array.py", line 149, in <listcomp>
ret = [ops.convert_to_tensor(arg, dtype) for arg in args_list]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/_utils.py", line 62, in wrap
return new_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py", line 152, in _convert_to_tensor
raise TypeError(('Tensor conversion requested dtype {} for array with '
TypeError: Tensor conversion requested dtype float32 for array with dtype float64: [-2.97]
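For context, here is a minimal sketch (in plain NumPy, with a hypothetical `convert_to_tensor` stand-in) of what I believe is happening: with `jax_enable_x64`, the input array becomes float64, while the bijector's shift constant (the `-2.97` in the error) stays float32, and TFP's JAX backend refuses the dtype-mismatched conversion instead of silently promoting the way plain NumPy would.

```python
import numpy as np

def convert_to_tensor(value, dtype=None):
    """Hypothetical, simplified stand-in for the strict dtype check in
    TFP's JAX backend (ops.py): unlike plain NumPy, it raises rather than
    casting when the array's dtype does not match the requested dtype."""
    arr = np.asarray(value)
    if dtype is not None and arr.dtype != np.dtype(dtype):
        raise TypeError(
            'Tensor conversion requested dtype {} for array with '
            'dtype {}: {}'.format(np.dtype(dtype).name, arr.dtype.name, value))
    return arr

# With x64 enabled, jax.numpy.array([1.0, 1.0]) is float64, but the concat
# inside bijector_fn requests the float32 dtype of the hard-coded constant:
try:
    convert_to_tensor(np.array([-2.97], dtype=np.float64), dtype=np.float32)
except TypeError as e:
    print(e)  # mirrors the TypeError at the bottom of the traceback
```

If this diagnosis is right, a possible workaround is to pass float32 inputs to the model (losing the point of x64), which suggests the fix belongs in inference_gym/TFP rather than in user code.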
