Skip to content

nutpie worker panic under jax backend #188

@fonnesbeck

Description

@fonnesbeck

I'm running into issues with the JAX backend on a model that samples without error under the default backend. Switching to the JAX backend via:

pm.sample(1000, tune=1000, chains=2, nuts_sampler='nutpie', nuts_sampler_kwargs={'backend': 'jax'}, random_seed=RANDOM_SEED)

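results in the following panic in a nutpie worker thread. It is followed by a `RuntimeError` ("All initialization points failed"). That error's underlying cause, shown at the end of the traceback, is an `AttributeError` raised inside the logp function: `module 'jax.lax' has no attribute 'mul_without_zeros'`.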

thread 'nutpie-worker-1' panicked at /home/runner/.cargo/registry/src/index.crates.io-1949cf8c6b5b557f/nuts-rs-0.15.0/src/sampler.rs:635:18:
Could not send sampling results to main thread.: SendError { .. }
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[67], line 32
     28 p = pm.Deterministic('p', pm.math.invlogit(logit_p), dims='obs')
     30 y = pm.Binomial('y', n=pa, p=p, observed=hr)
---> 32 gp_covariate_trace = pm.sample(1000, tune=1000, chains=2, nuts_sampler='nutpie', nuts_sampler_kwargs={'backend': 'jax'}, random_seed=RANDOM_SEED)

File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:809, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    804         raise ValueError(
    805             "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
    806         )
    808     with joined_blas_limiter():
--> 809         return _sample_external_nuts(
    810             sampler=nuts_sampler,
    811             draws=draws,
    812             tune=tune,
    813             chains=chains,
    814             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    815             random_seed=random_seed,
    816             initvals=initvals,
    817             model=model,
    818             var_names=var_names,
    819             progressbar=progress_bool,
    820             idata_kwargs=idata_kwargs,
    821             compute_convergence_checks=compute_convergence_checks,
    822             nuts_sampler_kwargs=nuts_sampler_kwargs,
    823             **kwargs,
    824         )
    826 if exclusive_nuts and not provided_steps:
    827     # Special path for NUTS initialization
    828     if "nuts" in kwargs:

File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:349, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    344 compiled_model = nutpie.compile_pymc_model(
    345     model,
    346     **compile_kwargs,
    347 )
    348 t_start = time.time()
--> 349 idata = nutpie.sample(
    350     compiled_model,
    351     draws=draws,
    352     tune=tune,
    353     chains=chains,
    354     target_accept=target_accept,
    355     seed=_get_seeds_per_chain(random_seed, 1)[0],
    356     progress_bar=progressbar,
    357     **nuts_sampler_kwargs,
    358 )
    359 t_sample = time.time() - t_start
    360 # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
    361 # gather observed and constant data as nutpie.sample() has no access to the PyMC model

File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/nutpie/sample.py:654, in sample(compiled_model, draws, tune, chains, cores, seed, save_warmup, progress_bar, low_rank_modified_mass_matrix, transform_adapt, init_mean, return_raw_trace, blocking, progress_template, progress_style, progress_rate, **kwargs)
    651     return sampler
    653 try:
--> 654     result = sampler.wait()
    655 except KeyboardInterrupt:
    656     result = sampler.abort()

File ~/repos/field_of_play/.pixi/envs/default/lib/python3.12/site-packages/nutpie/sample.py:388, in _BackgroundSampler.wait(self, timeout)
    378 def wait(self, *, timeout=None):
    379     """Wait until sampling is finished and return the trace.
    380 
    381     KeyboardInterrupt will lead to interrupt the waiting.
   (...)
    386     This resumes the sampler in case it had been paused.
    387     """
--> 388     self._sampler.wait(timeout)
    389     results = self._sampler.extract_results()
    390     return self._extract(results)

RuntimeError: All initialization points failed

Caused by:
    Logp function returned error: PyError(PyErr { type: <class 'AttributeError'>, value: AttributeError("module 'jax.lax' has no attribute 'mul_without_zeros'"), traceback: Some(<traceback object at 0x7f813ad75080>) })

Running on the following environment:

Python implementation: CPython
Python version       : 3.12.8
IPython version      : 8.32.0

numpy     : 1.26.4
scipy     : 1.12.0
pymc      : 5.20.1
preliz    : 0.15.0
nutpie    : 0.14.2
pandas    : 2.2.3
pytensor  : 2.27.1
matplotlib: 3.10.0
plotly    : 6.0.0
polars    : 1.24.0
arviz     : 0.20.0

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions