
Inference-Gym: double precision issue #1993

Open
reubenharry opened this issue Feb 21, 2025 · 4 comments

Comments

@reubenharry

I want to use Inference Gym models with 64-bit precision (for use in Blackjax samplers). I encounter complicated-looking bugs, which I have reduced to the following minimal example (run interactively in Python):

import jax
jax.config.update("jax_enable_x64", True)
import inference_gym.using_jax as gym
gym.targets.Banana()._unnormalized_log_prob(jax.numpy.array([1.0, 1.0]))

This results in the following error. (Apologies if this is a simple mistake on my end.)

Traceback (most recent call last):
File "", line 1, in
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/inference_gym/targets/banana.py", line 133, in _unnormalized_log_prob
return self._banana.log_prob(value)

File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1287, in log_prob
return self._call_log_prob(value, name, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1269, in _call_log_prob
return self._log_prob(value, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py", line 364, in _log_prob
log_prob, _ = self.experimental_local_measure(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py", line 611, in experimental_local_measure
x = self.bijector.inverse(y, **bijector_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 1390, in inverse
return self._call_inverse(y, name, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 1370, in _call_inverse
return self._cache.inverse(y, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/cache_util.py", line 347, in inverse
return self._lookup(y, self._inverse_name, self._forward_name, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/cache_util.py", line 493, in _lookup
self._invoke(input, forward_name, kwargs, attrs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/cache_util.py", line 532, in _invoke
return getattr(self.bijector, fn_name)(input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/masked_autoregressive.py", line 383, in _inverse
bijector = self._bijector_fn(y, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/inference_gym/targets/banana.py", line 88, in bijector_fn
shift = tf.concat(
^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/_utils.py", line 62, in wrap
return new_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_array.py", line 161, in _concat
values = _args_to_matching_arrays(values)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_array.py", line 149, in _args_to_matching_arrays
ret = [ops.convert_to_tensor(arg, dtype) for arg in args_list]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_array.py", line 149, in
ret = [ops.convert_to_tensor(arg, dtype) for arg in args_list]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/_utils.py", line 62, in wrap
return new_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py", line 152, in _convert_to_tensor
raise TypeError(('Tensor conversion requested dtype {} for array with '
TypeError: Tensor conversion requested dtype float32 for array with dtype float64: [-2.97]
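For what it's worth, the failure originates in banana.py's bijector_fn, where tf.concat combines a float32 constant with the float64 input. A minimal stopgap sketch on my end (not a fix; the logdensity_fn wrapper is just for illustration) is to cast at the call boundary:

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import inference_gym.using_jax as gym

target = gym.targets.Banana()

def logdensity_fn(x):
    # Downcast to the float32 the target's internals expect, evaluate,
    # then upcast so float64 sampler code composes with the result.
    return target._unnormalized_log_prob(x.astype(jnp.float32)).astype(jnp.float64)

logdensity_fn(jnp.array([1.0, 1.0]))

This sidesteps the TypeError, but of course the density is still evaluated in float32 internally, which defeats the point of enabling x64.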

@reubenharry (Author)

Note that I don't encounter this issue with every Inference Gym model. For instance, with target = gym.targets.VectorModel(gym.targets.BrownianMotionUnknownScalesMissingMiddleObservations()), I have no problem; see the self-contained version below.
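For reference, here is that working case in full (the jnp.ones test point and the public unnormalized_log_prob call are my choices for illustration):

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import inference_gym.using_jax as gym

target = gym.targets.VectorModel(
    gym.targets.BrownianMotionUnknownScalesMissingMiddleObservations())
# An arbitrary float64 test point with the model's (flattened) event shape.
x = jnp.ones(tuple(target.event_shape))
print(target.unnormalized_log_prob(x))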

@SiegeLordEx (Member)

This will need some dedicated changes to work. It'll likely look like a dtype arg for each target's initializer.
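Presumably usage would then look something like the following (a hypothetical sketch of the proposed, not-yet-implemented API; the dtype kwarg is the assumption here):

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import inference_gym.using_jax as gym

# Hypothetical: a dtype passed at construction, so internal constants
# (e.g. the literal feeding tf.concat in banana.py) are created in float64.
target = gym.targets.Banana(dtype=jnp.float64)
target._unnormalized_log_prob(jnp.array([1.0, 1.0]))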

@reubenharry (Author)

Is this something you'd recommend I try, or are there plans for the inference-gym developers to address this? (Currently this is a pretty serious blocker: without double precision I can't use inference-gym to benchmark samplers in JAX.)

@SiegeLordEx (Member)

I'm almost done making the necessary changes; they should land in a few days.
