inference gym / numpyro compatibility #1986

Open
kylejcaron opened this issue Jan 30, 2025 · 1 comment

@kylejcaron

I'm looking to use Inference Gym targets in NumPyro, but I'm running into issues, I believe because the Inference Gym model's `__init__` uses NumPy arrays, which causes tracer conversion errors in NumPyro/JAX.

Any ideas how to get around this? I can't tell how the `inference_gym.using_jax` module works, but I was hoping it would make the targets initialize their arrays as JAX arrays rather than NumPy arrays.

```python
import jax
import numpyro
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS
from inference_gym import using_jax as gym

class Banana(dist.Distribution):
    arg_constraints = {"ndims": dist.constraints.positive_integer, "curvature": dist.constraints.real}
    support = dist.constraints.real_vector
    pytree_data_fields = ("ndims", "curvature")

    def __init__(self, ndims, curvature):
        self.ndims = ndims
        self.curvature = curvature
        # Under NUTS, `curvature` is a JAX tracer here; the gym target's
        # __init__ passes it to np.sqrt, which raises (see the traceback below)
        self.gym_dist = gym.targets.Banana(ndims=ndims, curvature=curvature)
        super().__init__(event_shape=(ndims,))

    def sample(self, key, sample_shape=()):
        return self.gym_dist.sample(seed=key, sample_shape=sample_shape)

    def log_prob(self, value):
        return self.gym_dist._unnormalized_log_prob(value)


# Works: curvature is a concrete Python float here
samples = Banana(ndims=3, curvature=0.03).sample(jax.random.PRNGKey(0), sample_shape=(100,))

def model(X):
    curvature = numpyro.sample("curvature", dist.Beta(1, 30))
    return numpyro.sample("obs", Banana(ndims=3, curvature=curvature), obs=X)

kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), X=samples)  # fails with TracerArrayConversionError
```

Here's the full traceback

Traceback

File "/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_28885/3778384098.py", line 37, in <module>
mcmc.run(jax.random.PRNGKey(0), X=samples)
File "/.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py", line 702, in run
states_flat, last_state = partial_map_fn(map_args)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py", line 465, in _single_chain_mcmc
new_init_state = self.sampler.init(
^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py", line 749, in init
init_params = self._init_state(
^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py", line 693, in _init_state
) = initialize_model(
^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/numpyro/infer/util.py", line 713, in initialize_model
(init_params, pe, grad), is_valid = find_valid_initial_params(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/numpyro/infer/util.py", line 447, in find_valid_initial_params
(init_params, pe, z_grad), is_valid = _find_valid_params(
^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/numpyro/infer/util.py", line 433, in _find_valid_params
_, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/numpyro/infer/util.py", line 417, in body_fn
pe, z_grad = value_and_grad(potential_fn)(params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/api.py", line 468, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/api.py", line 1975, in _vjp
out_primals, vjp = ad.vjp(flat_fun, primals_flat)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 252, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 237, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 574, in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
return self.f_transformed(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 587, in trace_to_subjaxpr_nounits
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 616, in _trace_to_subjaxpr_nounits
ans = f(*in_args)
^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/api_util.py", line 72, in flatten_fun
ans = f(*py_args, **py_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 78, in jvpfun
out_primals, out_tangents = f(tag, primals, tangents)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 115, in jvp_subtrace
ans = f(*in_tracers)
^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/api_util.py", line 88, in flatten_fun_nokwargs
ans = f(*py_args)
^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/api_util.py", line 292, in _argnums_partial
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/numpyro/infer/util.py", line 299, in potential_energy
log_joint, model_trace = log_density(
^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/numpyro/infer/util.py", line 70, in log_density
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/numpyro/handlers.py", line 186, in get_trace
self(*args, **kwargs)
File "/.venv/lib/python3.11/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
[Previous line repeated 3 more times]
File "/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_28885/3778384098.py", line 31, in model
return numpyro.sample("obs", Banana(ndims=3, curvature=curvature), obs=X)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/numpyro/distributions/distribution.py", line 100, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_28885/3778384098.py", line 17, in __init__
self.gym_dist = gym.targets.Banana(ndims=ndims, curvature=curvature)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/inference_gym/targets/banana.py", line 116, in __init__
[10.] + [np.sqrt(1. + 2 * curvature**2 * 10.**4)] +
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 692, in __array__
raise TracerArrayConversionError(self)
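The failing frame is the `np.sqrt` call in `banana.py`. Under `value_and_grad(potential_fn)` (visible earlier in the traceback), `curvature` is a JAX tracer, and NumPy can only operate on concrete arrays. The minimal sketch below reproduces the same failure without NumPyro or Inference Gym. `numpy_std` and `jax_std` are hypothetical helper names that reuse the expression from the traceback:

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_std(curvature):
    # Same expression as banana.py line 116, via NumPy: under jax.grad,
    # `curvature` is a tracer and np.sqrt tries to convert it to an ndarray
    return np.sqrt(1. + 2 * curvature**2 * 10.**4)


def jax_std(curvature):
    # jnp.sqrt is traceable, so jax.grad can differentiate through it
    return jnp.sqrt(1. + 2 * curvature**2 * 10.**4)


try:
    jax.grad(numpy_std)(0.03)
    numpy_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_failed = True

grad_value = jax.grad(jax_std)(0.03)
print(numpy_failed, float(grad_value))
```

The only difference between the two functions is `np.sqrt` vs `jnp.sqrt`. The NumPy version raises the same `TracerArrayConversionError` as the traceback, while the JAX version differentiates normally.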

@SiegeLordEx
Member

SiegeLordEx commented Jan 30, 2025

While this particular issue is fixable in principle (see below), the core issue is that Inference Gym targets are not intended to be used this way. The targets are high-level constructs whose initializers can do IO and other things that are incompatible with jitted computation. They're not "distributions" in the sense of being building blocks for constructing larger probabilistic models.

If you want a local fix, edit the line from the traceback (line 116 of /.venv/lib/python3.11/site-packages/inference_gym/targets/banana.py) to look like:

```python
ground_truth_standard_deviation=tf.constant(
    [10.] + [tf.sqrt(1. + 2 * curvature**2 * 10.**4)] +
    [1.] * (ndims - 2)),
```
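The patch works because it keeps `curvature` inside traceable ops instead of handing it to NumPy. This assumes, without having checked the source rewrite, that `inference_gym.using_jax` maps `tf` onto a JAX backend so that `tf.sqrt` dispatches to a JAX op. The sketch below writes the same ground-truth vector directly in `jax.numpy` and jits it with `curvature` traced and `ndims` static. `ground_truth_std` is a hypothetical helper, not an Inference Gym function:

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=0)
def ground_truth_std(ndims, curvature):
    # ndims is static (it fixes the output shape); curvature may be a tracer
    second = jnp.sqrt(1. + 2 * curvature**2 * 10.**4)
    return jnp.concatenate([
        jnp.array([10.]),
        jnp.reshape(second, (1,)),
        jnp.ones(ndims - 2),
    ])


print(ground_truth_std(3, 0.03))  # second entry is sqrt(19), about 4.359
# Also differentiable with respect to curvature
print(jax.grad(lambda c: ground_truth_std(3, c)[1])(0.03))
```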

That's just a local workaround. It's unclear what the proper solution would look like because there are a lot of corner cases.
