TARPS-group/prob-pipe

ProbPipe

ProbPipe is a Python framework for building scalable probabilistic pipelines with automated uncertainty quantification.

Why ProbPipe?

Most workflows for probabilistic inference can be described in terms of distributions, fixed values (data, hyperparameters, covariates), and operations that transform distributions. Implementing these workflows, however, is harder than describing them:

  1. Algorithmic challenges. Common operations can be carried out by many algorithms, and their trade-offs have to be explored problem by problem. For example, a posterior could be approximated with one of several MCMC algorithms, with variational inference, or with sequential Monte Carlo.
  2. Representational challenges. Algorithms expect (and produce) specific formats for distributions and fixed values, and those formats are not always compatible with other parts of the workflow. Fixed values may be named parameter vectors, covariate matrices, or structured observations, and different algorithms expect different representations.

Simplification via abstraction

ProbPipe addresses these challenges through a single design principle: simplification via abstraction. There are just three core types:

  • Distribution: the universal representation of random quantities (priors, posteriors, data-generating processes). A distribution's capabilities are declared via protocols (SupportsSampling, SupportsLogProb, ...), and ProbPipe converts between representations as needed.
  • Record: the universal container for non-random structured data (observed datasets, hyperparameters, design matrices). Record is the deterministic counterpart of Distribution.
  • WorkflowFunction: any function decorated with @workflow_function. Pass concrete values and it runs normally; pass a Distribution where a concrete value is expected and ProbPipe propagates uncertainty automatically, returning a Distribution over the results. Array-valued inputs (RecordArray, DistributionArray) drive parameter sweeps, and all return values are wrapped in the same Record / Distribution contract so pipelines compose cleanly.
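
To make the WorkflowFunction contract concrete, here is a toy sketch in plain Python and NumPy. It is not ProbPipe's implementation: SampleDistribution and toy_workflow_function are hypothetical stand-ins, and the sketch assumes a distribution represented by equally weighted draws, so propagating uncertainty reduces to evaluating the function once per draw.

```python
import functools

import numpy as np


class SampleDistribution:
    """Hypothetical stand-in for a Distribution represented by equally weighted draws."""

    def __init__(self, samples):
        self.samples = np.asarray(samples)


def toy_workflow_function(fn):
    """Toy sketch of the WorkflowFunction idea; not ProbPipe's implementation."""

    @functools.wraps(fn)
    def wrapper(**kwargs):
        random_args = {k: v for k, v in kwargs.items() if isinstance(v, SampleDistribution)}
        if not random_args:
            # All inputs are concrete: behave like an ordinary function call.
            return fn(**kwargs)
        n_draws = {len(v.samples) for v in random_args.values()}
        if len(n_draws) != 1:
            raise ValueError("all random inputs must have the same number of draws")
        outputs = []
        for i in range(n_draws.pop()):
            # Substitute draw i for each random input; concrete inputs pass through unchanged.
            call_kwargs = {k: (v.samples[i] if k in random_args else v) for k, v in kwargs.items()}
            outputs.append(fn(**call_kwargs))
        # Wrap the per-draw results as a distribution over outputs.
        return SampleDistribution(np.stack(outputs))

    return wrapper


@toy_workflow_function
def predict(slope, x):
    return slope * x


x = np.array([0.0, 1.0, 2.0])
print(predict(slope=2.0, x=x))                            # [0. 2. 4.]
dist = predict(slope=SampleDistribution([1.0, 3.0]), x=x)
print(dist.samples.shape)                                 # (2, 3)
```

The same decorated function serves both call styles; only the type of the argument decides whether the result is a plain value or a distribution over values.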

Distribution and Record share a single interface for named-field access (fields, select(...), select_all()) and for passing components into a WorkflowFunction, so they are interchangeable at call sites. Fields split out of a common parent stay correlated end-to-end, so the two natural call shapes, f(p=posterior) and f(**posterior.select_all()), produce identical outputs. Implementation details (algorithms, data and distribution representations) stay hidden by default but remain fully configurable when you need control.
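
A toy NumPy illustration of why fields split from a common parent must stay aligned. The fields dict stands in, hypothetically, for posterior.select_all(), and the sketch assumes a posterior represented by jointly drawn samples. Calling f(**fields) pairs draw i of a with draw i of b, which matches calling g(p=...) on the parent; re-pairing the draws at random, as independent sampling of each field would, changes the result.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical joint posterior over two fields with correlation -0.9.
cov = np.array([[1.0, -0.9], [-0.9, 1.0]])
draws = rng.multivariate_normal(mean=[0.0, 0.0], cov=cov, size=10_000)

# Split the parent into named fields; row i of each field comes from the same joint draw.
fields = {"a": draws[:, 0], "b": draws[:, 1]}


def f(a, b):
    return a + b


def g(p):
    # Same computation, but taking the whole parent as one argument.
    return p["a"] + p["b"]


aligned = f(**fields)
assert np.array_equal(aligned, g(p=fields))  # the two call shapes agree draw by draw

# Re-pairing draws at random mimics sampling each field independently.
broken = f(fields["a"], rng.permutation(fields["b"]))

print(aligned.var())  # close to 1 + 1 - 2 * 0.9 = 0.2
print(broken.var())   # close to 1 + 1 = 2.0
```

With strongly negatively correlated fields, the variance of a + b is about ten times smaller when the draws stay paired, so losing the pairing would badly overstate the uncertainty.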

Built-in operations

ProbPipe provides a set of built-in ops, which are special workflow functions that dispatch on a distribution's protocols:

  • condition_on: condition a model on observed data, automatically selecting the best inference algorithm (or specify one with method=).
  • mean, variance, cov, expectation: compute distributional summaries, with automatic Monte Carlo fallback when exact computation is unavailable.
  • sample, log_prob: draw samples or evaluate densities through a uniform interface.
  • from_distribution: convert between distribution representations via the converter registry.
  • predictive_check: built-in prior and posterior predictive checking.
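
The Monte Carlo fallback for summaries such as mean can be sketched with Python's typing.Protocol and @runtime_checkable. HasExactMean, CanSample, Gaussian, SquaredGaussian, and toy_mean are hypothetical names, not ProbPipe's protocols or implementation; the sketch assumes exact computation is preferred whenever a distribution provides it.

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class HasExactMean(Protocol):
    def mean(self) -> float: ...


@runtime_checkable
class CanSample(Protocol):
    def sample(self, rng: np.random.Generator, n: int) -> np.ndarray: ...


class Gaussian:
    def __init__(self, loc: float, scale: float):
        self.loc, self.scale = loc, scale

    def mean(self) -> float:
        return self.loc

    def sample(self, rng, n):
        return rng.normal(self.loc, self.scale, size=n)


class SquaredGaussian:
    """Distribution of X**2 for X ~ Gaussian; deliberately offers sampling only."""

    def __init__(self, base: Gaussian):
        self.base = base

    def sample(self, rng, n):
        return self.base.sample(rng, n) ** 2


def toy_mean(dist, n_mc: int = 100_000, seed: int = 0) -> float:
    if isinstance(dist, HasExactMean):
        return dist.mean()  # exact path
    if isinstance(dist, CanSample):
        rng = np.random.default_rng(seed)
        return float(np.mean(dist.sample(rng, n_mc)))  # Monte Carlo fallback
    raise TypeError(f"{type(dist).__name__} supports neither exact means nor sampling")


print(toy_mean(Gaussian(1.0, 2.0)))                   # 1.0
print(toy_mean(SquaredGaussian(Gaussian(1.0, 2.0))))  # approximately 5.0
```

For X ~ Normal(1, 2), E[X^2] = 1^2 + 2^2 = 5, so the fallback estimate lands near 5 while the Gaussian itself takes the exact path.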

Documentation | Getting Started Tutorial | API Reference

Quick Example

Logistic regression with named parameters: simulate data, fit the model, and propagate posterior uncertainty through a prediction.

import jax, jax.numpy as jnp
import tensorflow_probability.substrates.jax.glm as tfp_glm
from probpipe import (
    Normal, ProductDistribution, GLMLikelihood, SimpleModel,
    workflow_function, condition_on, mean,
)

# --- Simulate data from a logistic regression ---
beta_true = jnp.array([-1.0, 2.0])  # intercept, slope
x = jax.random.normal(jax.random.PRNGKey(0), shape=(200,))
X = jnp.column_stack([jnp.ones_like(x), x])

likelihood = GLMLikelihood(tfp_glm.Bernoulli(), X, seed=1)
y = likelihood.generate_data(beta_true, 200)

# --- 1. Build a model with named parameters ---
prior = ProductDistribution(
    intercept=Normal(loc=0.0, scale=5.0, name="intercept"),
    slope=Normal(loc=0.0, scale=5.0, name="slope"),
)
model = SimpleModel(prior, likelihood)

# --- 2. Condition on data (auto-selects NUTS) ---
posterior = condition_on(model, y)
draws = posterior.draws()            # NumericRecordArray(intercept=..., slope=...)
draws["intercept"].mean()            # -0.93  (true: -1.0)
draws["slope"].mean()                #  2.18  (true:  2.0)

# --- 3. Propagate uncertainty through a prediction ---
@workflow_function
def predict_prob(intercept, slope, x):
    return jax.nn.sigmoid(intercept + slope * x)

x_new = jnp.linspace(-3, 3, 50)
predictive = predict_prob(**posterior.select('intercept', 'slope'), x=x_new)
# predictive is a Distribution over predicted P(y=1|x) curves

predict_prob is a @workflow_function: ProbPipe samples from the posterior, evaluates the function on each draw, and returns the full predictive distribution. Because the two posterior fields are splatted from a single parent, ProbPipe draws them jointly, so each (intercept, slope) pair stays correlated. Plotting the result:

import numpy as np, matplotlib.pyplot as plt

S = np.array(predictive.samples)     # (n_draws, 50)
lo, hi = np.percentile(S, [5, 95], axis=0)
plt.fill_between(np.array(x_new), lo, hi, alpha=0.3, label='90% CI')
plt.plot(np.array(x_new), S.mean(axis=0), lw=2, label='Posterior mean')
true_curve = jax.nn.sigmoid(beta_true[0] + beta_true[1] * x_new)
plt.plot(np.array(x_new), np.array(true_curve), 'k--', label='True')
plt.scatter(np.array(x), np.array(y), s=12, alpha=0.4, label='Data')
plt.xlabel('x'); plt.ylabel('P(y = 1 | x)'); plt.legend(fontsize=8)
plt.show()

[Figure: Posterior predictive]

Key Features

  • Protocol-based dispatch. A distribution's capabilities are declared via @runtime_checkable protocols (SupportsSampling, SupportsLogProb, SupportsMean, ...). Operations like condition_on and from_distribution use these protocols to auto-select the best algorithm from a pluggable registry. Override with method= when you want control.
  • Multiple backends. The inference registry spans TFP (NUTS, HMC, RWMH), nutpie, CmdStan, PyMC (NUTS, ADVI), and simulation-based inference (SMC-ABC via sbijax). Swap backends without changing model code.
  • Automatic distribution conversion. A converter registry converts between distribution representations (e.g., MCMC samples to KDE) as needed, using protocol-based dispatch analogous to condition_on.
  • JAX-native. Distributions and workflow functions are compatible with JAX (vmap, jit, grad), with built-in support for TFP distributions.
  • Provenance tracking. Each distribution records how it was created (algorithm, parents, metadata), enabling full lineage tracing from any result back to its inputs.
  • Prefect orchestration. Distribute pipeline steps across machines and CPUs without code changes.
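
Protocol-based dispatch with a pluggable registry and a method= override might look roughly like the following sketch. The protocol names, the registry, and the algorithm stubs are hypothetical; they only illustrate choosing the first registered algorithm whose required protocol the model satisfies, and ProbPipe's actual registry and selection logic may differ.

```python
from typing import Callable, Protocol, runtime_checkable


@runtime_checkable
class HasLogProb(Protocol):
    def log_prob(self, x): ...


@runtime_checkable
class HasSimulator(Protocol):
    def simulate(self, params): ...


# Hypothetical registry: (name, required protocol, implementation), in priority order.
REGISTRY: list[tuple[str, type, Callable]] = []


def register(name: str, protocol: type):
    def decorator(fn: Callable) -> Callable:
        REGISTRY.append((name, protocol, fn))
        return fn
    return decorator


@register("gradient_mcmc", HasLogProb)
def _gradient_mcmc(model, data):
    return f"gradient-based MCMC on {type(model).__name__}"  # placeholder for a real sampler


@register("abc", HasSimulator)
def _abc(model, data):
    return f"ABC on {type(model).__name__}"  # placeholder for simulation-based inference


def toy_condition_on(model, data, method=None):
    for name, protocol, fn in REGISTRY:
        if method is not None and name != method:
            continue  # method= restricts the search to one named entry
        if isinstance(model, protocol):
            return fn(model, data)  # first compatible entry wins
    raise TypeError(f"no registered method ({method or 'auto'}) supports {type(model).__name__}")


class DensityModel:
    def log_prob(self, x):
        return 0.0


class SimulatorModel:
    def simulate(self, params):
        return params


class DensityAndSimulatorModel(DensityModel, SimulatorModel):
    pass


print(toy_condition_on(DensityModel(), data=None))    # gradient-based MCMC on DensityModel
print(toy_condition_on(SimulatorModel(), data=None))  # ABC on SimulatorModel
print(toy_condition_on(DensityAndSimulatorModel(), data=None, method="abc"))  # ABC on DensityAndSimulatorModel
```

Ordering the registry by priority makes automatic selection deterministic, while method= lets the caller pick a specific compatible algorithm without touching the model.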

Installation

Requires Python >= 3.12 (tested on 3.12 and 3.13).

git clone https://github.com/TARPS-group/prob-pipe.git
cd prob-pipe
pip install .

Core dependencies: JAX and TensorFlow Probability. ProbPipe uses tfp-nightly, which is the recommended approach for TFP on JAX since stable TFP releases are tied to TensorFlow and often lag behind JAX.

Optional extras:

pip install ".[dev]"       # pytest, jupyter, matplotlib, graphviz
pip install ".[prefect]"   # Prefect orchestration backend
pip install ".[stan]"      # Stan models via BridgeStan + CmdStanPy
pip install ".[pymc]"      # PyMC model integration
pip install ".[nutpie]"    # nutpie Markov chain Monte Carlo (MCMC) sampler

Ray via Prefect

ProbPipe can dispatch Prefect-orchestrated WorkflowFunction tasks to Ray via Prefect-Ray. This is not a native Ray backend: there is no WorkflowKind.RAY, probpipe.ray module, or direct ProbPipe API for ray.remote, ray.put, actors, placement groups, or resource hints.

Install Ray support through Prefect-Ray:

pip install "probpipe[prefect]"
pip install "prefect[ray]"

For local development from this repository:

pip install -e ".[prefect]"
pip install "prefect[ray]"

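The local demo uses a persistent Ray head. Because prefect server start runs in the foreground, start it in its own terminal and run the other two commands in a second terminal: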

prefect server start
ray start --head
python example_scripts/run_ray_demo.py

See the Ray via Prefect guide for setup details, deployment notes, and current support boundaries.

Next Steps

Contributing

See CONTRIBUTING.md for development setup, PR workflow, and guidelines.

About

A Python framework for creating probabilistic workflows
