From d8fea6e1577b0998d5c64c0d8f01cbadcedf2c88 Mon Sep 17 00:00:00 2001 From: Richang Date: Wed, 10 Dec 2025 16:24:10 +0530 Subject: [PATCH 1/3] Add background sampling handle and docs for pm.sample --- pymc/sampling/mcmc.py | 146 +++++++++++++-------- tests/sampling/test_background_sampling.py | 33 +++++ 2 files changed, 125 insertions(+), 54 deletions(-) create mode 100644 tests/sampling/test_background_sampling.py diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index de341c68cd..c9b084a348 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -20,6 +20,7 @@ import sys import time import warnings +import threading from collections.abc import Callable, Iterator, Mapping, Sequence from typing import ( @@ -91,6 +92,44 @@ Step: TypeAlias = BlockedStep | CompoundStep +class BackgroundSampleHandle: + def __init__(self, target, args=None, kwargs=None): + self._done = threading.Event() + self._result = None + self._exception = None + args = args or () + kwargs = kwargs or {} + + def runner(): + try: + self._result = target(*args, **kwargs) + except Exception as exc: # noqa: BLE001 + self._exception = exc + finally: + self._done.set() + + self._thread = threading.Thread(target=runner, daemon=True) + + def start(self): + self._thread.start() + return self + + def done(self): + return self._done.is_set() + + def result(self, timeout=None): + self._thread.join(timeout=timeout) + if not self._done.is_set(): + raise TimeoutError("Background sampling not finished yet") + if self._exception: + raise self._exception + return self._result + + def exception(self, timeout=None): + self._thread.join(timeout=timeout) + return self._exception + + class SamplingIteratorCallback(Protocol): """Signature of the callable that may be passed to `pm.sample(callable=...)`.""" @@ -439,6 +478,7 @@ def sample( mp_ctx=None, blas_cores: int | None | Literal["auto"] = "auto", compile_kwargs: dict | None = None, + background: bool = False, **kwargs, ) -> InferenceData: ... @@ -472,6 +512,7 @@ def sample( model: Model | None = None, blas_cores: int | None | Literal["auto"] = "auto", compile_kwargs: dict | None = None, + background: bool = False, **kwargs, ) -> MultiTrace: ... @@ -504,6 +545,8 @@ def sample( blas_cores: int | None | Literal["auto"] = "auto", model: Model | None = None, compile_kwargs: dict | None = None, + background: bool = False, + _background_internal: bool = False, **kwargs, ) -> InferenceData | MultiTrace | ZarrTrace: r"""Draw samples from the posterior using the given step methods. @@ -540,7 +583,7 @@ def sample( - "combined": A single progress bar that displays the total progress across all chains. Only timing information is shown. - "split": A separate progress bar for each chain. Only timing information is shown. - - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all + - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across chains. Aggregate sample statistics are also displayed. - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain are also displayed. @@ -618,7 +661,14 @@ def sample( Model to sample from. The model needs to have free random variables. compile_kwargs: dict, optional Dictionary with keyword argument to pass to the functions compiled by the step methods. + You can find a full list of arguments in the docstring of the step methods. + Background mode + ---------------- + - Set ``background=True`` to run sampling in a background thread; this returns a handle. + - The handle supports ``done()``, ``result()``, and ``exception()``. + - Progress bars are suppressed in background mode. + - Currently limited to ``nuts_sampler="pymc"``; other samplers raise ``NotImplementedError``. Returns ------- @@ -629,65 +679,53 @@ def sample( ``ZarrTrace`` instance. Refer to :class:`~pymc.backends.zarr.ZarrTrace` for the benefits this backend provides. - Notes - ----- - Optional keyword arguments can be passed to ``sample`` to be delivered to the - ``step_method``\ s used during sampling. - - For example: - - 1. ``target_accept`` to NUTS: nuts={'target_accept':0.9} - 2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7} - - Note that available step names are: - - ``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``, - ``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``, - ``DEMetropolis``, ``DEMetropolisZ``, ``slice`` - - The NUTS step method has several options including: - - * target_accept : float in [0, 1]. The step size is tuned such that we - approximate this acceptance rate. Higher values like 0.9 or 0.95 often - work better for problematic posteriors. This argument can be passed directly to sample. - * max_treedepth : The maximum depth of the trajectory tree - * step_scale : float, default 0.25 - The initial guess for the step size scaled down by :math:`1/n**(1/4)`, - where n is the dimensionality of the parameter space - - Alternatively, if you manually declare the ``step_method``\ s, within the ``step`` - kwarg, then you can address the ``step_method`` kwargs directly. - e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis, - you could send :: - - step = [ - pm.NUTS([freeRV1, freeRV2], target_accept=0.9), - pm.BinaryGibbsMetropolis([freeRV3], transit_p=0.7), - ] - - You can find a full list of arguments in the docstring of the step methods. - Examples -------- .. code-block:: ipython + """ + if background and not _background_internal: + if nuts_sampler != "pymc": + raise NotImplementedError("background=True currently supports nuts_sampler='pymc' only") + progressbar = False - In [1]: import pymc as pm - ...: n = 100 - ...: h = 61 - ...: alpha = 2 - ...: beta = 2 + # Resolve the model now so the background thread has a concrete model object. + resolved_model = modelcontext(model) - In [2]: with pm.Model() as model: # context management - ...: p = pm.Beta("p", alpha=alpha, beta=beta) - ...: y = pm.Binomial("y", n=n, p=p, observed=h) - ...: idata = pm.sample() + def _run(): + return sample( + draws=draws, + tune=tune, + chains=chains, + cores=cores, + random_seed=random_seed, + progressbar=progressbar, + progressbar_theme=progressbar_theme, + step=step, + var_names=var_names, + nuts_sampler=nuts_sampler, + initvals=initvals, + init=init, + jitter_max_retries=jitter_max_retries, + n_init=n_init, + trace=trace, + discard_tuned_samples=discard_tuned_samples, + compute_convergence_checks=compute_convergence_checks, + keep_warning_stat=keep_warning_stat, + return_inferencedata=return_inferencedata, + idata_kwargs=idata_kwargs, + nuts_sampler_kwargs=nuts_sampler_kwargs, + callback=callback, + mp_ctx=mp_ctx, + blas_cores=blas_cores, + model=resolved_model, + compile_kwargs=compile_kwargs, + background=False, + _background_internal=True, + **kwargs, + ) - In [3]: az.summary(idata, kind="stats") + return BackgroundSampleHandle(target=_run).start() - Out[3]: - mean sd hdi_3% hdi_97% - p 0.609 0.047 0.528 0.699 - """ if "start" in kwargs: if initvals is not None: raise ValueError("Passing both `start` and `initvals` is not supported.") @@ -1735,4 +1773,4 @@ def model_logp_fn(ip: PointType) -> np.ndarray: for initial_point in initial_points ] - return initial_points, step + return initial_points, step \ No newline at end of file diff --git a/tests/sampling/test_background_sampling.py b/tests/sampling/test_background_sampling.py new file mode 100644 index 0000000000..55df949d99 --- /dev/null +++ b/tests/sampling/test_background_sampling.py @@ -0,0 +1,33 @@ +import pymc as pm +import pytest + + +def test_background_sampling_happy_path(): + with pm.Model(): + pm.Normal("x", 0, 1) + handle = pm.sample( + draws=20, + tune=10, + chains=1, + cores=1, + background=True, + progressbar=False, + ) + idata = handle.result() + assert hasattr(idata, "posterior") + assert idata.posterior.sizes["chain"] >= 1 + + +def test_background_sampling_raises(): + with pm.Model(): + pm.Normal("x", 0, sigma=-1) + handle = pm.sample( + draws=10, + tune=5, + chains=1, + cores=1, + background=True, + progressbar=False, + ) + with pytest.raises(Exception): + handle.result() \ No newline at end of file From 7e43b478e704dcd222b2886eec891e71105bed39 Mon Sep 17 00:00:00 2001 From: Richang Date: Wed, 10 Dec 2025 16:52:48 +0530 Subject: [PATCH 2/3] Run pre-commit fixes (end-of-file, license header, ruff fixes, etc.) --- pymc/sampling/mcmc.py | 6 +++--- tests/sampling/test_background_sampling.py | 18 ++++++++++++++++-- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index c9b084a348..e59da7479a 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -18,9 +18,9 @@ import logging import pickle import sys +import threading import time import warnings -import threading from collections.abc import Callable, Iterator, Mapping, Sequence from typing import ( @@ -103,7 +103,7 @@ def __init__(self, target, args=None, kwargs=None): def runner(): try: self._result = target(*args, **kwargs) - except Exception as exc: # noqa: BLE001 + except Exception as exc: self._exception = exc finally: self._done.set() @@ -1773,4 +1773,4 @@ def model_logp_fn(ip: PointType) -> np.ndarray: for initial_point in initial_points ] - return initial_points, step \ No newline at end of file + return initial_points, step diff --git a/tests/sampling/test_background_sampling.py b/tests/sampling/test_background_sampling.py index 55df949d99..72510353ff 100644 --- a/tests/sampling/test_background_sampling.py +++ b/tests/sampling/test_background_sampling.py @@ -1,6 +1,20 @@ -import pymc as pm +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest +import pymc as pm + def test_background_sampling_happy_path(): with pm.Model(): @@ -30,4 +44,4 @@ def test_background_sampling_raises(): progressbar=False, ) with pytest.raises(Exception): - handle.result() \ No newline at end of file + handle.result() From 2565e211cde3c3071426a39e2298ea1a4be47b4f Mon Sep 17 00:00:00 2001 From: Richang Date: Wed, 10 Dec 2025 16:57:39 +0530 Subject: [PATCH 3/3] Add background sampling test to CI matrix --- .github/workflows/tests.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 37a56f8a06..df788f9729 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -85,6 +85,7 @@ jobs: tests/sampling/test_deterministic.py tests/sampling/test_forward.py tests/sampling/test_population.py + tests/sampling/test_background_sampling.py tests/stats/test_convergence.py tests/stats/test_log_density.py tests/distributions/test_distribution.py @@ -190,7 +191,7 @@ jobs: python-version: ["3.11"] test-subset: - tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/test_initial_point.py - - tests/model/test_core.py tests/sampling/test_mcmc.py + - tests/model/test_core.py tests/sampling/test_mcmc.py tests/sampling/test_background_sampling.py - tests/gp/test_cov.py tests/gp/test_gp.py tests/gp/test_mean.py tests/gp/test_util.py tests/ode/test_ode.py tests/ode/test_utils.py tests/smc/test_smc.py tests/sampling/test_parallel.py - tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py tests/step_methods/test_state.py @@ -247,6 +248,7 @@ jobs: - | tests/sampling/test_mcmc.py + tests/sampling/test_background_sampling.py - | tests/backends/test_arviz.py @@ -347,7 +349,7 @@ jobs: floatx: [float32] python-version: ["3.13"] test-subset: - - tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py tests/distributions/test_transform.py + - tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py tests/distributions/test_transform.py tests/sampling/test_background_sampling.py fail-fast: false runs-on: ${{ matrix.os }} env: