ProbPipe is a Python framework for building scalable probabilistic pipelines with automated uncertainty quantification.
Most workflows for probabilistic inference can be described in terms of distributions, fixed values (data, hyperparameters, covariates), and operations that transform distributions. Implementing these workflows, however, is harder than describing them:
- Algorithmic challenges. There are many possible algorithms for common operations, with varying trade-offs that need to be explored in a problem-specific manner. A posterior could be approximated using different MCMC algorithms, variational inference, or sequential Monte Carlo.
- Representational challenges. Algorithms expect (and produce) specific formats for distributions and fixed values, and those formats are not always compatible with other parts of the workflow. Fixed values may be named parameter vectors, covariate matrices, or structured observations, and different algorithms expect different representations.
ProbPipe addresses these challenges through a single design principle: simplification via abstraction. There are just three core types:
- Distribution: the universal representation of random quantities (priors, posteriors, data-generating processes). A distribution's capabilities are declared via protocols (SupportsSampling, SupportsLogProb, ...), and ProbPipe converts between representations as needed.
- Record: the universal container for non-random structured data (observed datasets, hyperparameters, design matrices). Record is the deterministic counterpart of Distribution.
- WorkflowFunction: any function decorated with @workflow_function. Pass concrete values and it runs normally; pass a Distribution where a concrete value is expected and ProbPipe propagates uncertainty automatically, returning a Distribution over the results. Array-valued inputs (RecordArray, DistributionArray) drive parameter sweeps; all returns are wrapped in the same Record/Distribution contract so pipelines compose cleanly.
Distribution and Record share a single interface for named-field access (fields, select(...), select_all()) and passing components into a WorkflowFunction, so they are interchangeable at call sites. Fields split out of a common parent stay correlated end-to-end, so the two natural shapes — f(p=posterior) and f(**posterior.select_all()) — produce identical outputs. Implementation details (algorithms, data and distribution representations) stay invisible by default, while remaining fully configurable when control is needed.
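Why it matters that fields split from a common parent stay correlated can be illustrated without ProbPipe at all. The sketch below is plain Python and is not ProbPipe's implementation: `joint_posterior_draw`, `propagate`, and `linear_predictor` are hypothetical names, and the toy posterior is invented so that intercept and slope are strongly negatively correlated. It pushes draws through a function twice, once with the original (intercept, slope) pairs and once with the slopes shuffled independently of the intercepts.

```python
import random
import statistics

def joint_posterior_draw(rng):
    """Toy 'posterior' (hypothetical): slope is strongly negatively
    correlated with intercept."""
    intercept = rng.gauss(-1.0, 0.5)
    slope = 2.0 - 0.9 * (intercept + 1.0) + rng.gauss(0.0, 0.05)
    return {"intercept": intercept, "slope": slope}

def propagate(fn, draws):
    """Evaluate fn once per draw; the list of results stands in for a
    distribution over outputs."""
    return [fn(**d) for d in draws]

def linear_predictor(intercept, slope, x=1.0):
    return intercept + slope * x

rng = random.Random(0)
joint = [joint_posterior_draw(rng) for _ in range(5000)]

# Break the pairing: shuffle slopes independently of intercepts.
slopes = [d["slope"] for d in joint]
rng.shuffle(slopes)
independent = [
    {"intercept": d["intercept"], "slope": s} for d, s in zip(joint, slopes)
]

sd_joint = statistics.stdev(propagate(linear_predictor, joint))
sd_indep = statistics.stdev(propagate(linear_predictor, independent))
print(f"sd with joint draws:       {sd_joint:.3f}")
print(f"sd with independent draws: {sd_indep:.3f}")
```

In this toy setup the shuffled version reports a much wider spread, because the negative correlation that cancels most of the uncertainty in the sum has been discarded. Drawing the fields jointly avoids that distortion.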
ProbPipe provides a set of built-in ops, special workflow functions that dispatch based on a distribution's protocols:
- condition_on: condition a model on observed data, automatically selecting the best inference algorithm (or specify one with method=).
- mean, variance, cov, expectation: compute distributional summaries, with automatic Monte Carlo fallback when exact computation is unavailable.
- sample, log_prob: draw samples or evaluate densities through a uniform interface.
- from_distribution: convert between distribution representations via the converter registry.
- predictive_check: built-in prior and posterior predictive checking.
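The idea of an op that dispatches on protocols and falls back to Monte Carlo can be sketched in a few lines of standard-library Python. This is an illustration under stated assumptions, not ProbPipe's code: the two protocols below are local stand-ins that merely reuse the names SupportsMean and SupportsSampling, their method signatures are assumed, and `mean_op`, `ToyNormal`, and `ToySquaredNormal` are hypothetical.

```python
import random
from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsMean(Protocol):          # local stand-in, signature assumed
    def mean(self) -> float: ...

@runtime_checkable
class SupportsSampling(Protocol):      # local stand-in, signature assumed
    def sample(self, rng: random.Random) -> float: ...

class ToyNormal:
    """Has both a closed-form mean and a sampler."""
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale
    def mean(self):
        return self.loc
    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)

class ToySquaredNormal:
    """Square of a standard normal; exposes sampling only."""
    def sample(self, rng):
        return rng.gauss(0.0, 1.0) ** 2

def mean_op(dist, n_draws=20_000, seed=0):
    """Return (value, method): exact if available, else Monte Carlo."""
    if isinstance(dist, SupportsMean):
        return dist.mean(), "exact"
    if isinstance(dist, SupportsSampling):
        rng = random.Random(seed)
        total = sum(dist.sample(rng) for _ in range(n_draws))
        return total / n_draws, "monte_carlo"
    raise TypeError(f"{type(dist).__name__} supports neither mean nor sampling")

exact_value, exact_method = mean_op(ToyNormal(3.0, 1.0))
mc_value, mc_method = mean_op(ToySquaredNormal())
print(exact_method, exact_value)
print(mc_method, round(mc_value, 2))
```

Note that `isinstance` against a `runtime_checkable` protocol only checks that the named methods exist, not their signatures, so a real implementation would likely need stricter validation.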
Logistic regression with named parameters: simulate data, fit the model, and propagate posterior uncertainty through a prediction.
import jax, jax.numpy as jnp
import tensorflow_probability.substrates.jax.glm as tfp_glm
from probpipe import (
    Normal, ProductDistribution, GLMLikelihood, SimpleModel,
    workflow_function, condition_on, mean,
)
# --- Simulate data from a logistic regression ---
beta_true = jnp.array([-1.0, 2.0]) # intercept, slope
x = jax.random.normal(jax.random.PRNGKey(0), shape=(200,))
X = jnp.column_stack([jnp.ones_like(x), x])
likelihood = GLMLikelihood(tfp_glm.Bernoulli(), X, seed=1)
y = likelihood.generate_data(beta_true, 200)
# --- 1. Build a model with named parameters ---
prior = ProductDistribution(
    intercept=Normal(loc=0.0, scale=5.0, name="intercept"),
    slope=Normal(loc=0.0, scale=5.0, name="slope"),
)
model = SimpleModel(prior, likelihood)
# --- 2. Condition on data (auto-selects NUTS) ---
posterior = condition_on(model, y)
draws = posterior.draws() # NumericRecordArray(intercept=..., slope=...)
draws["intercept"].mean() # -0.93 (true: -1.0)
draws["slope"].mean() # 2.18 (true: 2.0)
# --- 3. Propagate uncertainty through a prediction ---
@workflow_function
def predict_prob(intercept, slope, x):
    return jax.nn.sigmoid(intercept + slope * x)
x_new = jnp.linspace(-3, 3, 50)
predictive = predict_prob(**posterior.select('intercept', 'slope'), x=x_new)
# predictive is a Distribution over predicted P(y=1|x) curves

predict_prob is a @workflow_function: ProbPipe samples from the posterior and evaluates the function for each draw, returning the full predictive distribution. The two posterior fields are splatted from a single parent, so ProbPipe draws them jointly and each (intercept, slope) pair stays correlated. Plotting the result:
import numpy as np, matplotlib.pyplot as plt
S = np.array(predictive.samples) # (n_draws, 50)
lo, hi = np.percentile(S, [5, 95], axis=0)
plt.fill_between(np.array(x_new), lo, hi, alpha=0.3, label='90% CI')
plt.plot(np.array(x_new), S.mean(axis=0), lw=2, label='Posterior mean')
true_curve = jax.nn.sigmoid(beta_true[0] + beta_true[1] * x_new)
plt.plot(np.array(x_new), np.array(true_curve), 'k--', label='True')
plt.scatter(np.array(x), np.array(y), s=12, alpha=0.4, label='Data')
plt.xlabel('x'); plt.ylabel('P(y = 1 | x)'); plt.legend(fontsize=8)

- Protocol-based dispatch. A distribution's capabilities are declared via @runtime_checkable protocols (SupportsSampling, SupportsLogProb, SupportsMean, ...). Operations like condition_on and from_distribution use these protocols to auto-select the best algorithm from a pluggable registry. Override with method= when you want control.
- Multiple backends. The inference registry spans TFP (NUTS, HMC, RWMH), nutpie, CmdStan, PyMC (NUTS, ADVI), and simulation-based inference (SMC-ABC via sbijax). Swap backends without changing model code.
- Automatic distribution conversion. A converter registry converts between distribution representations (e.g., MCMC samples to KDE) as needed, using protocol-based dispatch analogous to condition_on.
- JAX-native. Distributions and workflow functions are compatible with JAX (vmap, jit, grad), with built-in support for TFP distributions.
- Provenance tracking. Each distribution records how it was created (algorithm, parents, metadata), enabling full lineage tracing from any result back to its inputs.
- Prefect orchestration. Distribute pipeline steps across machines and CPUs without code changes.
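To make the "pluggable registry with a method= override" idea concrete, here is a minimal sketch in plain Python. It is not ProbPipe's registry: `Method`, `MethodRegistry`, the priority numbers, and the applicability rules are all hypothetical, and the models are stand-in dictionaries of capability flags. Only the method names (nuts, rwmh, smc_abc) echo backends mentioned above.

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class Method:
    name: str
    priority: int                          # higher wins during auto-selection
    is_applicable: Callable[[dict], bool]  # capability check on the model
    run: Callable[[dict], str]             # stand-in for the real algorithm

class MethodRegistry:
    def __init__(self):
        self._methods = {}

    def register(self, method):
        self._methods[method.name] = method

    def dispatch(self, model, method=None):
        if method is not None:
            # Explicit override: look the method up by name and validate it.
            chosen = self._methods.get(method)
            if chosen is None:
                raise KeyError(f"unknown method {method!r}")
            if not chosen.is_applicable(model):
                raise ValueError(f"{method!r} is not applicable to this model")
        else:
            # Auto-selection: highest-priority applicable method.
            candidates = [m for m in self._methods.values() if m.is_applicable(model)]
            if not candidates:
                raise ValueError("no applicable method registered")
            chosen = max(candidates, key=lambda m: m.priority)
        return chosen.run(model)

registry = MethodRegistry()
registry.register(Method(
    "nuts", 100,
    lambda m: m["has_log_prob"] and m["differentiable"],
    lambda m: "ran nuts",
))
registry.register(Method("rwmh", 50, lambda m: m["has_log_prob"], lambda m: "ran rwmh"))
registry.register(Method("smc_abc", 10, lambda m: m["can_simulate"], lambda m: "ran smc_abc"))

smooth = {"has_log_prob": True, "differentiable": True, "can_simulate": True}
simulator_only = {"has_log_prob": False, "differentiable": False, "can_simulate": True}

print(registry.dispatch(smooth))                  # ran nuts
print(registry.dispatch(simulator_only))          # ran smc_abc
print(registry.dispatch(smooth, method="rwmh"))   # ran rwmh
```

Keeping applicability checks next to each registered method means a new backend can be added without touching the dispatch logic, which is one plausible reason for choosing a registry design.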
Requires Python >= 3.12 (tested on 3.12 and 3.13).
git clone https://github.com/TARPS-group/prob-pipe.git
cd prob-pipe
pip install .

Core dependencies: JAX and TensorFlow Probability. ProbPipe uses tfp-nightly, which is the recommended approach for TFP on JAX since stable TFP releases are tied to TensorFlow and often lag behind JAX.
Optional extras:
pip install .[dev] # pytest, jupyter, matplotlib, graphviz
pip install .[prefect] # Prefect orchestration backend
pip install .[stan] # Stan models via BridgeStan + CmdStanPy
pip install .[pymc] # PyMC model integration
pip install .[nutpie] # nutpie Markov chain Monte Carlo (MCMC) sampler

ProbPipe can dispatch Prefect-orchestrated WorkflowFunction tasks to Ray via Prefect-Ray. This is not a native Ray backend: there is no WorkflowKind.RAY, probpipe.ray module, or direct ProbPipe API for ray.remote, ray.put, actors, placement groups, or resource hints.
Install Ray support through Prefect-Ray:
pip install "probpipe[prefect]"
pip install "prefect[ray]"

For local development from this repository:
pip install -e ".[prefect]"
pip install "prefect[ray]"

The local demo uses a persistent Ray head:
prefect server start
ray start --head
python example_scripts/run_ray_demo.py

See the Ray via Prefect guide for setup details, deployment notes, and current support boundaries.
- Getting Started Tutorial: iterative Bayesian model building with ProbPipe
- API Reference: full class and function documentation
See CONTRIBUTING.md for development setup, PR workflow, and guidelines.
