Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
18512f4
Some progress
manuelgloeckler Mar 18, 2025
f5e8e7f
Some progress
manuelgloeckler Mar 18, 2025
7def57b
Merge branch '1446-add-api-for-guidance' of https://github.com/sbi-de…
manuelgloeckler Mar 18, 2025
ab6b7b4
Moving and renaming. Refactoring IID function to reduce code redundancy
manuelgloeckler Mar 18, 2025
52825d3
basic api on potentials
manuelgloeckler Mar 18, 2025
0237adf
working minimal example
manuelgloeckler Mar 18, 2025
2e7d27b
Docs: Introduce RTD website
michaeldeistler Mar 17, 2025
9b7a14f
Fix tests for new docs
michaeldeistler Mar 18, 2025
45d9ca6
chore: pytest split for ci workflow (#1465)
schroedk Mar 18, 2025
6f891da
Prevent notebook execution upon doc build
michaeldeistler Mar 18, 2025
5592682
Fix broken links on website
michaeldeistler Mar 18, 2025
bd2735f
internal dataclasses but external stringly...
manuelgloeckler Mar 18, 2025
762ed43
Other more general guidance methods
manuelgloeckler Mar 19, 2025
9653d0e
universal guidance working
manuelgloeckler Mar 19, 2025
fe23d43
Specialized child class for interval constriants
manuelgloeckler Mar 19, 2025
4789025
Some progress
manuelgloeckler Mar 18, 2025
7b43583
Moving and renaming. Refactoring IID function to reduce code redundancy
manuelgloeckler Mar 18, 2025
ec7ae06
basic api on potentials
manuelgloeckler Mar 18, 2025
1469881
working minimal example
manuelgloeckler Mar 18, 2025
b5edda9
internal dataclasses but external stringly...
manuelgloeckler Mar 18, 2025
51cde80
Other more general guidance methods
manuelgloeckler Mar 19, 2025
251fc88
universal guidance working
manuelgloeckler Mar 19, 2025
4a5f431
Specialized child class for interval constriants
manuelgloeckler Mar 19, 2025
7b5b2e1
Merge branch '1446-add-api-for-guidance' of https://github.com/sbi-de…
manuelgloeckler Mar 19, 2025
3652eb5
Formatting
manuelgloeckler Mar 19, 2025
48906b1
typing
manuelgloeckler Mar 19, 2025
1eba715
tests
manuelgloeckler Mar 19, 2025
f29d75c
check independet prior, problem
manuelgloeckler Mar 19, 2025
c38db7a
Fix hidden bug in score_utils
manuelgloeckler Mar 20, 2025
68616d4
Format
manuelgloeckler Mar 20, 2025
a75a944
Removing unncessary. Adding tests on basic API
manuelgloeckler Mar 20, 2025
2efb7c7
test lower upper, lower upper where switched
manuelgloeckler Mar 20, 2025
ff289b1
Formats and tests
manuelgloeckler Mar 20, 2025
5fdbb4c
add documentation
manuelgloeckler Mar 20, 2025
3679f3c
Merge remote-tracking branch 'origin/main' into 1446-add-api-for-guid…
manuelgloeckler Sep 5, 2025
5972a29
Port changes
manuelgloeckler Sep 5, 2025
50fde2f
Update to recent changes
manuelgloeckler Sep 8, 2025
5637db4
Tutorial
manuelgloeckler Sep 8, 2025
e50f6a1
Merge remote-tracking branch 'origin/main' into 1446-add-api-for-guid…
manuelgloeckler Sep 8, 2025
0476ab0
froamting
manuelgloeckler Sep 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 206 additions & 19 deletions docs/advanced_tutorials/20_score_based_methods_new_features.ipynb

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions sbi/inference/posteriors/vector_field_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def sample(
Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"]
] = None,
iid_params: Optional[Dict] = None,
guidance_method: Optional[str] = None,
guidance_params: Optional[Dict] = None,
max_sampling_batch_size: int = 10_000,
sample_with: Optional[str] = None,
show_progress_bars: bool = True,
Expand Down Expand Up @@ -190,6 +192,17 @@ def sample(
SCORE_DEFINED and MARGINALS_DEFINED class attributes set to True.
iid_params: Additional parameters passed to the iid method. See the specific
`IIDScoreFunction` child class for details.
guidance_method: Method to guide the diffusion process. If None, no guidance
is used. currently we support `affine_classifier_free`, which allows to
scale and shift the "likelihood" or "prior" score contribution. This can
be used to perform "super" conditioning i.e. shring the variance of the
likelihood. `Universal` can be used to guide the diffusion process with
a general guidance function. `Interval` is an isntance of that where
the guidance function constraints the diffusion process to a given
interval.
guidance_params: Additional parameters passed to the guidance method. See
the specific `ScoreAdaptation` child class for details, specifically
`AffineClassifierFreeCfg`, `UniversalCfg`, and `IntervalCfg`.
max_sampling_batch_size: Maximum batch size for sampling.
sample_with: Sampling method to use - 'ode' or 'sde'. Note that in order to
use the 'sde' sampling method, the vector field estimator must support
Expand All @@ -208,6 +221,8 @@ def sample(
x_is_iid=is_iid,
iid_method=iid_method or self.potential_fn.iid_method,
iid_params=iid_params,
guidance_method=guidance_method,
guidance_params=guidance_params,
)

num_samples = torch.Size(sample_shape).numel()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import functools
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Type, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Distribution

Expand All @@ -21,6 +23,7 @@
from sbi.utils.torchutils import ensure_theta_batched

IID_METHODS = {}
GUIDANCE_METHODS = {}


def get_iid_method(name: str) -> Type["IIDScoreFunction"]:
Expand All @@ -40,6 +43,40 @@ def get_iid_method(name: str) -> Type["IIDScoreFunction"]:
return IID_METHODS[name]


def get_guidance_method(name: str) -> Type["ScoreAdaptation"]:
r"""
Retrieves the guidance method by name.

Args:
name: The name of the guidance method.

Returns:
The guidance method class and its default configuration.
"""
if name not in GUIDANCE_METHODS:
raise NotImplementedError(f"Method {name} for guidance not implemented.")
return GUIDANCE_METHODS[name]


def register_guidance_method(name: str, default_cfg: Optional[Type] = None) -> Callable:
r"""
Registers a guidance method and its default configuration.

Args:
name: The name of the guidance method.
default_cfg: The default configuration class for the guidance method.

Returns:
A decorator function to register the guidance method class.
"""

def decorator(cls: Type["ScoreAdaptation"]) -> Type["ScoreAdaptation"]:
GUIDANCE_METHODS[name] = (cls, default_cfg)
return cls

return decorator


def register_iid_method(name: str) -> Callable:
r"""
Registers an IID method.
Expand All @@ -58,6 +95,230 @@ def decorator(cls: Type["IIDScoreFunction"]) -> Type["IIDScoreFunction"]:
return decorator


class ScoreAdaptation(ABC):
def __init__(
self,
vf_estimator: ConditionalVectorFieldEstimator,
prior: Optional[Distribution],
device: str = "cpu",
):
"""This class manages manipulating the score estimator to impose additional
constraints on the posterior via guidance.

Args:
score_estimator: The score estimator.
prior: The prior distribution.
device: The device on which to evaluate the potential.
"""
self.vf_estimator = vf_estimator
self.prior = prior
self.device = device

@abstractmethod
def __call__(self, input: Tensor, condition: Tensor, time: Optional[Tensor] = None):
pass

def score(self, input: Tensor, condition: Tensor, t: Optional[Tensor] = None):
return self.__call__(input, condition, t)


@dataclass
class AffineClassifierFreeGuidanceCfg:
prior_scale: Union[float, Tensor] = 1.0
prior_shift: Union[float, Tensor] = 0.0
likelihood_scale: Union[float, Tensor] = 1.0
likelihood_shift: Union[float, Tensor] = 0.0


@register_guidance_method("affine_classifier_free", AffineClassifierFreeGuidanceCfg)
class AffineClassifierFreeGuidance(ScoreAdaptation):
def __init__(
self,
vf_estimator: ConditionalVectorFieldEstimator,
prior: Optional[Distribution],
cfg: AffineClassifierFreeGuidanceCfg,
device: str = "cpu",
):
"""This class manages manipulating the score estimator to temper or shift the
prior and likelihood.

This is usually known as classifier-free guidance. And works by decomposing the
posterior score into a prior and likelihood component. These can then be scaled
and shifted to impose change the posterior to p

Args:
score_estimator: The score estimator.
prior: The prior distribution.
cfg: Configuration for the affine classifier-free guidance. This includes
the scale and shift applied to the prior and likelihood contributions.
device: The device on which to evaluate the potential.

References:
- [1] Classifier-Free Diffusion Guidance (2022)
- [2] All-in-one simulation-based inference (2024)

"""

if prior is None:
raise ValueError(
"Prior is required for classifier-free guidance, please"
" provide as least an improper empirical prior."
)

self.prior_scale = torch.tensor(cfg.prior_scale, device=device)
self.prior_shift = torch.tensor(cfg.prior_shift, device=device)
self.likelihood_scale = torch.tensor(cfg.likelihood_scale, device=device)
self.likelihood_shift = torch.tensor(cfg.likelihood_shift, device=device)
super().__init__(vf_estimator, prior, device)

def marginal_prior_score(self, theta: Tensor, time: Tensor):
"""Computes the marginal prior score analyticaly (or approximatly)"""
m = self.vf_estimator.mean_t_fn(time)
std = self.vf_estimator.std_fn(time)
marginal_prior = marginalize(self.prior, m, std) # type: ignore
marginal_prior_score = compute_score(marginal_prior, theta)
return marginal_prior_score

def __call__(self, input: Tensor, condition: Tensor, time: Optional[Tensor] = None):
if time is None:
time = torch.tensor([self.vf_estimator.t_min])

posterior_score = self.vf_estimator(input=input, condition=condition, time=time)
prior_score = self.marginal_prior_score(input, time)
ll_score = posterior_score - prior_score
ll_score_mod = ll_score * self.likelihood_scale + self.likelihood_shift
prior_score_mod = prior_score * self.prior_scale + self.prior_shift

return ll_score_mod + prior_score_mod


@dataclass
class UniversalGuidanceCfg:
guidance_fn: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
guidance_fn_score: Optional[Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]] = (
None
)


@register_guidance_method("universal", UniversalGuidanceCfg)
class UniversalGuidance(ScoreAdaptation):
def __init__(
self,
vf_estimator: ConditionalVectorFieldEstimator,
prior: Optional[Distribution],
cfg: UniversalGuidanceCfg,
device: str = "cpu",
):
"""This class manages manipulating the score estimator using a custom guidance
function.

Args:
score_estimator: The score estimator.
prior: The prior distribution.
cfg: Configuration for the universal guidance.
device: The device on which to evaluate the potential.



References:
- [1] Universal Guidance for Diffusion Models (2022)
"""
self.guidance_fn = cfg.guidance_fn

if cfg.guidance_fn_score is None:

def guidance_fn_score(input, condition, m, std):
with torch.enable_grad():
input = input.detach().clone().requires_grad_(True)
score = torch.autograd.grad(
cfg.guidance_fn(input, condition, m, std).sum(),
input,
create_graph=True,
)[0]
return score

self.guidance_fn_score = guidance_fn_score
else:
self.guidance_fn_score = cfg.guidance_fn_score

super().__init__(vf_estimator, prior, device)

def __call__(self, input: Tensor, condition: Tensor, time: Optional[Tensor] = None):
if time is None:
time = torch.tensor([self.vf_estimator.t_min])
score = self.vf_estimator(input, condition, time)
m = self.vf_estimator.mean_t_fn(time)
std = self.vf_estimator.std_fn(time)

# Tweedie's formula for denoising
denoised_input = (input + std**2 * score) / m
guidance_score = self.guidance_fn_score(denoised_input, condition, m, std)

return score + guidance_score


@dataclass
class IntervalGuidanceCfg:
lower_bound: Optional[Union[float, Tensor]]
upper_bound: Optional[Union[float, Tensor]]
mask: Optional[Tensor] = None
scale_factor: float = 0.5


@register_guidance_method("interval", IntervalGuidanceCfg)
class IntervalGuidance(UniversalGuidance):
def __init__(
self,
vf_estimator: ConditionalVectorFieldEstimator,
prior: Optional[Distribution],
cfg: IntervalGuidanceCfg,
device: str = "cpu",
):
"""Implements interval guidance to constrain parameters within bounds.

Args:
score_estimator: The score estimator.
prior: The prior distribution.
cfg: Configuration specifying the interval bounds.
device: The device on which to evaluate the potential.

References:
- [2] All-in-one simulation-based inference (2024)
"""

def interval_fn(input, condition, m, std):
if cfg.lower_bound is None and cfg.upper_bound is None:
raise ValueError(
"At least one of lower_bound or upper_bound is required. Otherwise"
" the guidance function has no effect."
)

scale = cfg.scale_factor / (m**2 * std**2)
upper_bound = (
F.logsigmoid(-scale * (input - cfg.upper_bound))
if cfg.upper_bound is not None
else torch.zeros_like(input)
)
lower_bound = (
F.logsigmoid(scale * (input - cfg.lower_bound))
if cfg.lower_bound is not None
else torch.zeros_like(input)
)
out = upper_bound + lower_bound
if cfg.mask is not None:
if cfg.mask.shape != out.shape:
cfg.mask = cfg.mask.unsqueeze(0).expand_as(out)
out = torch.where(cfg.mask, out, torch.zeros_like(out))
return out

super().__init__(
vf_estimator,
prior,
UniversalGuidanceCfg(guidance_fn=interval_fn),
device=device,
)


class IIDScoreFunction(ABC):
def __init__(
self,
Expand Down Expand Up @@ -217,39 +478,16 @@ def __call__(
base_score = self.vector_field_estimator.score(inputs, conditions, time)

# Compute the prior score
prior_score = self.prior_score_weight_fn(time) * self.prior_score_fn(inputs)

prior_score = self.prior_score_weight_fn(time) * compute_score(
self.prior, inputs
)

# Accumulate
score = (1 - N) * prior_score + base_score.sum(-2, keepdim=True)

return score

def prior_score_fn(self, theta: Tensor) -> Tensor:
r"""
Computes the score of the prior distribution.

Args:
theta: The parameters at which to evaluate the prior score.

Returns:
The computed prior score.
"""
# NOTE The try except is for unifrom priors which do not have a grad, and
# implementations that do not implement the log_prob method.
try:
with torch.enable_grad():
theta = theta.detach().clone().requires_grad_(True)
prior_log_prob = self.prior.log_prob(theta)
prior_score = torch.autograd.grad(
prior_log_prob,
theta,
grad_outputs=torch.ones_like(prior_log_prob),
create_graph=True,
)[0].detach()
except Exception:
prior_score = torch.zeros_like(theta)
return prior_score


class BaseGaussCorrectedScoreFunction(IIDScoreFunction):
def __init__(
Expand Down Expand Up @@ -735,6 +973,24 @@ def marginal_denoising_posterior_precision_est_fn(
return denoising_posterior_precision


def compute_score(p: Distribution, inputs: Tensor):
# NOTE The try except is for unifrom priors which do not have a grad, and
# implementations that do not implement the log_prob method.
try:
with torch.enable_grad():
inputs = inputs.detach().clone().requires_grad_(True)
log_prob = p.log_prob(inputs)
score = torch.autograd.grad(
log_prob,
inputs,
grad_outputs=torch.ones_like(log_prob),
create_graph=True,
)[0].detach()
except Exception:
score = torch.zeros_like(inputs)
return score


def ensure_lam_positive_definite(
denoising_prior_precision: torch.Tensor,
denoising_posterior_precision: torch.Tensor,
Expand Down
Loading