Inference-Gym: double precision issue #1993
I should note that I do not encounter this issue with all Inference Gym models. For instance, with …
This will need some dedicated changes to work. It'll likely look like a …
Is this something you'd recommend I try, or do the Inference Gym developers plan to address it? It is currently a serious blocker for using Inference Gym to benchmark JAX samplers, since I can't use double precision.
I'm almost done making the necessary changes; they should be ready in a few days.
I want to use Inference Gym models with 64-bit precision (for use with BlackJAX samplers). I run into complicated-looking bugs, which I have reduced to the following minimal example (run interactively in Python):
This results in the following error (apologies if this is a simple mistake on my end):
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/inference_gym/targets/banana.py", line 133, in _unnormalized_log_prob
return self._banana.log_prob(value)
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1287, in log_prob
return self._call_log_prob(value, name, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1269, in _call_log_prob
return self._log_prob(value, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py", line 364, in _log_prob
log_prob, _ = self.experimental_local_measure(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py", line 611, in experimental_local_measure
x = self.bijector.inverse(y, **bijector_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 1390, in inverse
return self._call_inverse(y, name, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 1370, in _call_inverse
return self._cache.inverse(y, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/cache_util.py", line 347, in inverse
return self._lookup(y, self._inverse_name, self._forward_name, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/cache_util.py", line 493, in _lookup
self._invoke(input, forward_name, kwargs, attrs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/cache_util.py", line 532, in _invoke
return getattr(self.bijector, fn_name)(input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/bijectors/masked_autoregressive.py", line 383, in _inverse
bijector = self._bijector_fn(y, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/inference_gym/targets/banana.py", line 88, in bijector_fn
shift = tf.concat(
^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/_utils.py", line 62, in wrap
return new_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_array.py", line 161, in _concat
values = _args_to_matching_arrays(values)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_array.py", line 149, in _args_to_matching_arrays
ret = [ops.convert_to_tensor(arg, dtype) for arg in args_list]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_array.py", line 149, in <listcomp>
ret = [ops.convert_to_tensor(arg, dtype) for arg in args_list]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/_utils.py", line 62, in wrap
return new_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/global/homes/r/reubenh/.conda/envs/jax2024/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py", line 152, in _convert_to_tensor
raise TypeError(('Tensor conversion requested dtype {} for array with '
TypeError: Tensor conversion requested dtype float32 for array with dtype float64: [-2.97]
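The last frames show TFP's JAX backend converting every argument of `tf.concat` to one dtype (float32 here) and raising, rather than promoting, when it meets a float64 array. With `jax_enable_x64` turned on, arrays built from Python floats default to float64, while the Banana target's internal values appear to be float32, which would explain the mismatch. The sketch below reproduces that behaviour in plain JAX without Inference Gym. `strict_concat` and the constant `c32` are hypothetical stand-ins that only mimic the strict check seen in the traceback; they are not TFP's or Inference Gym's actual code.

```python
import jax

# Enable 64-bit mode before any arrays are created (as for double-precision BlackJAX runs).
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def strict_concat(values, axis=0):
    """Hypothetical mimic of the strict dtype check in the traceback.

    Uses the dtype of the first array and raises (instead of promoting)
    when a later array has a different dtype.
    """
    dtype = values[0].dtype
    for v in values[1:]:
        if v.dtype != dtype:
            raise TypeError(
                f"Tensor conversion requested dtype {dtype} "
                f"for array with dtype {v.dtype}"
            )
    return jnp.concatenate(values, axis=axis)


# User input: Python floats become float64 under x64.
x = jnp.array([-2.97, 1.0])

# Hypothetical stand-in for a float32 value held inside the target.
c32 = jnp.zeros(1, dtype=jnp.float32)

# Plain jnp.concatenate promotes float32 + float64 to float64 without complaint.
promoted = jnp.concatenate([c32, x])

# The strict version raises a TypeError like the one in the traceback.
try:
    strict_concat([c32, x])
except TypeError as err:
    print("strict_concat failed:", err)

# Stopgap: cast the input down to float32; this avoids the error but drops double precision.
ok = strict_concat([c32, x.astype(jnp.float32)])
print(promoted.dtype, ok.dtype)
```

Casting inputs to float32 only sidesteps the error; it does not give the double precision this issue asks for. A real fix presumably has to let the target create its internal values in float64 as well.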