I'm looking to use Inference Gym targets in NumPyro, but I'm running into issues, I believe because there are NumPy arrays in the Inference Gym model's `__init__`, which causes tracer conversion errors in NumPyro/JAX.
Any ideas how to get around this? I can't tell how the `inference_gym.using_jax` module works, but I was hoping it would initialize the arrays as JAX arrays rather than NumPy arrays.
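The original traceback isn't reproduced here, so the following is only a sketch of the likely failure mode, assuming the error is JAX's `TracerArrayConversionError`. Inside `jax.jit` (which NumPyro's samplers rely on), inputs are abstract tracers; calling `np.asarray` on one fails, while the `jax.numpy` equivalent traces fine. The function names are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp


def uses_numpy(x):
    # np.asarray demands a concrete NumPy array; under jit, x is an abstract tracer
    return np.asarray(x) * 2.0


def uses_jnp(x):
    # jnp.asarray keeps the value inside JAX, so it can be traced
    return jnp.asarray(x) * 2.0


def raises_tracer_error():
    """Return True if jitting the NumPy-based function hits a tracer conversion error."""
    try:
        jax.jit(uses_numpy)(jnp.ones(3))
    except jax.errors.TracerArrayConversionError:
        return True
    return False


print(raises_tracer_error())
print(jax.jit(uses_jnp)(jnp.ones(3)))
```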
While this particular issue is fixable in principle (see below), the core issue is that Inference Gym targets are not intended to be used this way. The targets are high-level constructs whose initializers can do IO and other things that are not compatible with jitted computation. They're not "distributions" in the sense of being building blocks for constructing larger probabilistic models.
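A minimal sketch of the usage pattern this implies: build the target once, eagerly and outside any traced code, and pass only its pure log-density function into jitted code. `HypotheticalBananaTarget` is an invented stand-in; its banana-shaped density and `curvature=0.03` are illustrative assumptions, not necessarily Inference Gym's exact definition. With the real library, the analogous object would presumably come from `inference_gym.using_jax` (e.g. `gym.targets.Banana()`), exposing `unnormalized_log_prob`.

```python
import numpy as np
import jax
import jax.numpy as jnp


class HypotheticalBananaTarget:
    """Hypothetical stand-in for a target: __init__ does eager NumPy work,
    while unnormalized_log_prob is a pure JAX function."""

    def __init__(self, curvature=0.03):
        # Eager, non-traceable setup; fine as long as it runs OUTSIDE jit.
        self._curvature = float(np.asarray(curvature))

    def unnormalized_log_prob(self, x):
        # Assumed density: x0 ~ N(0, 10^2), x1 | x0 ~ N(c * (x0^2 - 100), 1)
        x0, x1 = x[..., 0], x[..., 1]
        lp0 = -0.5 * (x0 / 10.0) ** 2
        lp1 = -0.5 * (x1 - self._curvature * (x0 ** 2 - 100.0)) ** 2
        return lp0 + lp1


# Build the target once, before any tracing happens.
target = HypotheticalBananaTarget()


@jax.jit
def log_density(x):
    # Only the pure log-prob computation is traced; the constructor is not.
    return target.unnormalized_log_prob(x)


grad_fn = jax.jit(jax.grad(log_density))

x = jnp.array([0.0, -3.0])  # the mode of the assumed density
print(log_density(x), grad_fn(x))
```

In NumPyro, one option (not tested here) is to hand the negated log density to the sampler directly, e.g. `NUTS(potential_fn=lambda x: -log_density(x))`, and pass `init_params` to `MCMC.run`, instead of constructing the target inside a model function.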
If you want a local fix, edit the source of `/.venv/lib/python3.11/site-packages/inference_gym/targets/banana.py` on that line so it looks like: