Added basic SAASBO implementation #569
Conversation
Walkthrough: A new class, `SAASBO`, is added as a subclass of `BayesianOptimization`; it fits its GP hyperparameters with MCMC via Pyro.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant SAASBO
    participant PyroGP
    participant MCMC
    User->>SAASBO: maximize(init_points, n_iter)
    SAASBO->>SAASBO: Probe initial points
    loop Iterative Optimization
        SAASBO->>SAASBO: suggest()
        SAASBO->>PyroGP: _define_gp_model(X, y)
        SAASBO->>MCMC: _fit_gp()
        MCMC-->>SAASBO: Posterior samples
        SAASBO->>SAASBO: Evaluate acquisition over candidates
        SAASBO->>User: Probe suggested point
        SAASBO->>MCMC: Refit GP with updated data
    end
```
Actionable comments posted: 5
🧹 Nitpick comments (3)
bayes_opt/bayesian_optimization.py (3)
`497-500`: **Unintentionally re-seeding `_random_state`**

`self._random_state` is already initialised by the base-class constructor. Re-assigning it here silently discards the seeded RNG (and any state that may have been advanced during the parent initialisation).

```diff
- self._random_state = ensure_rng(random_state)
```

Simply drop this line – you already inherited a valid RNG; a sketch of the intended constructor follows below.
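A minimal sketch of the fixed constructor, assuming the subclass signature shown later in this review (anything beyond that is illustrative):

```python
class SAASBO(BayesianOptimization):
    def __init__(self, f, pbounds, random_state=None, **kwargs):
        # The parent constructor already calls ensure_rng(random_state),
        # so self._random_state is a valid seeded RNG after this line.
        super().__init__(f=f, pbounds=pbounds, random_state=random_state, **kwargs)
        # No second ensure_rng(random_state) call here.
```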
`558-570`: **Heavy MCMC in every call – consider caching / incremental fitting**

`_fit_gp()` reruns a full NUTS chain after every new observation (the `maximize()` loop). For realistic dimensions this is orders of magnitude slower than the classic optimiser and effectively blocks interactivity.

Options:
- Fit only every k iterations, or when the ESS drops below a threshold.
- Use `warm_start=True` on the NUTS kernel to continue chains.
- Offer a `fit_every` parameter (see the sketch after this list).

This is not a correctness bug, but it will surprise users.
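A minimal sketch of the `fit_every` idea; the attribute names (`fit_every`, `_iters_since_fit`, `_mcmc_samples`) are assumptions, not part of the PR:

```python
def _fit_gp_if_needed(self) -> None:
    """Refit the GP only every `fit_every` suggestions (hypothetical helper)."""
    self._iters_since_fit = getattr(self, "_iters_since_fit", 0) + 1
    if (
        getattr(self, "_mcmc_samples", None) is None
        or self._iters_since_fit >= self.fit_every
    ):
        self._fit_gp()  # full NUTS run, as the PR currently does on every call
        self._iters_since_fit = 0
```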
`671-678`: **`ScreenLogger` has no `warning` method & avoid f-string in logging**

Static analysis flags both issues:

```diff
- self.logger.warning(f"Ignored unknown parameters: {params}")
+ # Avoid failing if ScreenLogger lacks a `warning` method
+ if hasattr(self.logger, "warning"):
+     self.logger.warning("Ignored unknown parameters: %s", params)
+ else:
+     warn(f"Ignored unknown parameters: {params}", stacklevel=1)
```

Also add a trailing newline at EOF to silence ruff `W292`.
🧰 Tools

🪛 Ruff (0.11.9)
- 678-678: Logging statement uses f-string (G004)
- 678-678: No newline at end of file; add trailing newline (W292)

🪛 Pylint (3.3.7)
- [convention] 678-678: Final newline missing (C0304)
- [error] 678-678: Instance of 'ScreenLogger' has no 'warning' member (E1101)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (1)
- bayes_opt/bayesian_optimization.py (2 hunks)
🧰 Additional context used

🪛 Pylint (3.3.7)

bayes_opt/bayesian_optimization.py
- [error] 16-16: Unable to import 'torch' (E0401)
- [error] 17-17: Unable to import 'pyro' (E0401)
- [error] 18-18: Unable to import 'numpy' (E0401)
- [error] 19-19: Unable to import 'scipy.optimize' (E0401)
- [error] 20-20: Unable to import 'sklearn.gaussian_process' (E0401)
- [error] 21-21: Unable to import 'sklearn.gaussian_process.kernels' (E0401)
- [error] 22-22: Unable to import 'pyro.distributions' (E0401)
- [error] 23-23: Unable to import 'pyro.infer.mcmc' (E0401)
- [convention] 22-22: Imports from package pyro are not grouped (C0412)
- [convention] 579-579: Line too long (106/100) (C0301)
- [convention] 678-678: Final newline missing (C0304)
- [refactor] 469-469: Too many arguments (12/5) (R0913)
- [refactor] 469-469: Too many positional arguments (12/5) (R0917)
- [error] 474-474: Undefined variable 'Optional' (E0602)
- [error] 477-477: Undefined variable 'Optional' (E0602)
- [convention] 511-511: Argument name "X" doesn't conform to snake_case naming style (C0103)
- [convention] 513-513: Argument name "X" doesn't conform to snake_case naming style (C0103)
- [warning] 511-511: Unused argument 'X' (W0613)
- [warning] 511-511: Unused argument 'y' (W0613)
- [convention] 552-552: Variable name "X" doesn't conform to snake_case naming style (C0103)
- [error] 678-678: Instance of 'ScreenLogger' has no 'warning' member (E1101)

🪛 Ruff (0.11.9)

bayes_opt/bayesian_optimization.py
- 474-474: Undefined name `Optional` (F821)
- 477-477: Undefined name `Optional` (F821)
- 659-659: Missing argument description in the docstring for `set_gp_params`: `**params` (D417)
- 678-678: Logging statement uses f-string (G004)
- 678-678: No newline at end of file; add trailing newline (W292)
```python
# Compute acquisition function for candidates
acq = self._acquisition_function.evaluate(
    candidates=candidates_tensor,
    gp=gp,
    target_space=self._space,
)
acq_values += acq / (self.num_samples // self.thinning)
```
🛠️ Refactor suggestion

**Acquisition function assumes scikit-learn GP interface**

`acquisition.evaluate()` expects a `.predict(X, return_std=True)` method (sklearn-style). `pyro.contrib.gp.models.GPRegression` provides `.forward` / `.model` instead – the call will raise `AttributeError`.

Provide a thin adapter:

```python
class _PyroGPWrapper:
    """Expose a sklearn-style predict() on top of a Pyro GPRegression."""

    def __init__(self, gp):
        self._gp = gp

    def predict(self, X, return_std=True):
        with torch.no_grad():
            # GPRegression.forward returns a (mean, covariance) tuple;
            # with full_cov=False the second element is the marginal variance.
            mean, var = self._gp(X, full_cov=False, noiseless=False)
            mean = mean.detach().cpu().numpy()
            std = var.sqrt().detach().cpu().numpy()
        return (mean, std) if return_std else mean
```

and pass `gp=_PyroGPWrapper(gp)` to `evaluate()`.
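With the adapter in place, the evaluate call from the hunk above would presumably read:

```python
acq = self._acquisition_function.evaluate(
    candidates=candidates_tensor,
    gp=_PyroGPWrapper(gp),  # adapter instead of the raw Pyro GP
    target_space=self._space,
)
```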
🤖 Prompt for AI Agents
In bayes_opt/bayesian_optimization.py around lines 615 to 621, the acquisition
function's evaluate method expects a Gaussian Process object with a
sklearn-style predict method, but the current GP object from
pyro.contrib.gp.models lacks this and will cause an AttributeError. To fix this,
create a thin adapter class _PyroGPWrapper that wraps the pyro GP and implements
a predict method returning mean and standard deviation as numpy arrays, then
pass an instance of this wrapper (gp=_PyroGPWrapper(gp)) to the evaluate call
instead of the raw gp object.
```python
def __init__(
    self,
    f: Callable[..., float] | None,
    pbounds: Mapping[str, tuple[float, float]],
    acquisition_function: AcquisitionFunction | None = None,
    constraint: Optional[NonlinearConstraint] = None,
    random_state: int | RandomState | None = None,
    verbose: int = 2,
    bounds_transformer: Optional[DomainTransformer] = None,
    allow_duplicate_points: bool = False,
```
**`Optional` is not imported → runtime `NameError`**

`Optional` is used in the type signature but never imported. Add it next to the other typing imports, or switch to the PEP-604 union syntax already used elsewhere:

```diff
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional
```

(or replace the two `Optional[...]` occurrences with `| None`).
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```diff
 def __init__(
     self,
     f: Callable[..., float] | None,
     pbounds: Mapping[str, tuple[float, float]],
     acquisition_function: AcquisitionFunction | None = None,
     constraint: Optional[NonlinearConstraint] = None,
     random_state: int | RandomState | None = None,
     verbose: int = 2,
     bounds_transformer: Optional[DomainTransformer] = None,
     allow_duplicate_points: bool = False,
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional
```
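If the PEP-604 route is taken instead, the two annotations would read as follows (a sketch mirroring the unions already used in this signature):

```python
constraint: NonlinearConstraint | None = None,
bounds_transformer: DomainTransformer | None = None,
```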
🧰 Tools

🪛 Ruff (0.11.9)
- 474-474: Undefined name `Optional` (F821)
- 477-477: Undefined name `Optional` (F821)

🪛 Pylint (3.3.7)
- [refactor] 469-469: Too many arguments (12/5) (R0913)
- [refactor] 469-469: Too many positional arguments (12/5) (R0917)
- [error] 474-474: Undefined variable 'Optional' (E0602)
- [error] 477-477: Undefined variable 'Optional' (E0602)
🤖 Prompt for AI Agents
In bayes_opt/bayesian_optimization.py around lines 469 to 478, the type hint uses Optional without importing it, causing a runtime NameError. Fix this by either importing Optional from typing alongside the other imports, or replacing all Optional[...] usages with the PEP-604 union syntax (| None) for consistency and to avoid the import.
```python
for i in range(0, self.num_samples, self.thinning):
    # Extract hyperparameters for this sample
    outputscale = self._mcmc_samples["outputscale"][i]
    noise = self._mcmc_samples["noise"][i]
    lengthscale = self._mcmc_samples["lengthscale"][i]
    mean = self._mcmc_samples["mean"][i]
```
**`lengthscale` key is missing from posterior samples**

The posterior dictionary only contains sites declared with `pyro.sample`. Because `lengthscale` was a plain tensor (see the previous comment), the lookup below raises `KeyError`. After adding `pyro.deterministic("lengthscale", …)` or sampling it directly, keep in mind that deterministic sites are not returned by `get_samples()` by default; use `mcmc.get_samples(group_by_chain=False, include_observed=False, include_deterministic=True)`.

Update the retrieval accordingly:

```diff
- lengthscale = self._mcmc_samples["lengthscale"][i]
+ lengthscale = self._mcmc_samples_deterministic["lengthscale"][i]
```

(or compute `tau * beta` on the fly, as in the sketch below).
Committable suggestion skipped: line range outside the PR's diff.
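A sketch of the on-the-fly alternative, assuming the `tau` and `beta` sites sampled in the model (the loop index `i` is the PR's own):

```python
# Recover per-dimension length scales from sampled sites instead of
# a "lengthscale" key that is absent from the posterior dictionary.
tau = self._mcmc_samples["tau"][i]    # global shrinkage (scalar)
beta = self._mcmc_samples["beta"][i]  # per-dimension scales, shape (dim,)
lengthscale = tau * beta
```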
🤖 Prompt for AI Agents
In bayes_opt/bayesian_optimization.py around lines 594 to 600, the code tries to
access "lengthscale" from the MCMC samples dictionary, but this key is missing
because "lengthscale" was not declared as a sampled or deterministic site in the
Pyro model. To fix this, modify the Pyro model to include "lengthscale" as a
deterministic site using pyro.deterministic or sample it directly. Then, when
retrieving samples, call mcmc.get_samples with include_deterministic=True to
ensure "lengthscale" is included. Alternatively, compute "lengthscale" on the
fly from its components instead of storing it in samples.
```python
def _define_gp_model(self, X: torch.Tensor, y: torch.Tensor) -> Callable:
    """Define the Pyro GP model with a horseshoe prior on length scales."""

    def gp_model(X: torch.Tensor, y: torch.Tensor):
        # Kernel hyperparameters
        outputscale = pyro.sample("outputscale", dist.LogNormal(0.0, 1.0))
        noise = pyro.sample("noise", dist.LogNormal(-2.0, 1.0))

        # Horseshoe prior on length scales for each dimension
        dim = X.shape[1]
        tau = pyro.sample("tau", dist.HalfCauchy(0.1))
        beta = pyro.sample("beta", dist.HalfCauchy(torch.ones(dim)))
        lengthscale = tau * beta

        # Matern 5/2 kernel with horseshoe length scales
        kernel = pyro.contrib.gp.kernels.Matern52(
            input_dim=dim,
            lengthscale=lengthscale,
            variance=outputscale,
        )

        # Define the GP
        gpr = pyro.contrib.gp.models.GPRegression(
            X=X,
            y=y,
            kernel=kernel,
            noise=noise,
        )

        # Sample the mean
        mean = pyro.sample("mean", dist.Normal(0.0, 1.0))
        gpr.mean = mean
        return gpr

    return gp_model
```
**`gp_model` does not produce a valid Pyro model – MCMC will fail**

- Returning `gpr` instead of calling `gpr.model()` means the trace has no likelihood; NUTS cannot compute a log-probability.
- The sample site `lengthscale` is never registered (only the Python variable exists), so it is absent from the posterior.

A minimal working fix:

```diff
- lengthscale = tau * beta
+ lengthscale = pyro.deterministic("lengthscale", tau * beta)
@@
- mean = pyro.sample("mean", dist.Normal(0.0, 1.0))
- gpr.mean = mean
- return gpr
+ mean = pyro.sample("mean", dist.Normal(0.0, 1.0))
+ gpr.mean_function = lambda _x: mean
+
+ # run the GP's own model to create the 'obs' site
+ gpr.model()
```

Without this change `mcmc.run()` raises at runtime.
Committable suggestion skipped: line range outside the PR's diff.
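To make the requirement concrete, here is a self-contained sketch (not the PR's code) of a model in which every quantity of interest is a registered Pyro site and the data enter through an observed likelihood, so NUTS has a log-probability to target. An explicit RBF kernel stands in for the PR's Matern 5/2 to keep the example short:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS

def saas_gp_model(X: torch.Tensor, y: torch.Tensor) -> None:
    n, dim = X.shape
    outputscale = pyro.sample("outputscale", dist.LogNormal(0.0, 1.0))
    noise = pyro.sample("noise", dist.LogNormal(-2.0, 1.0))
    mean = pyro.sample("mean", dist.Normal(0.0, 1.0))

    # Horseshoe prior; the product is registered so it appears in the trace.
    tau = pyro.sample("tau", dist.HalfCauchy(0.1))
    beta = pyro.sample("beta", dist.HalfCauchy(torch.ones(dim)))
    lengthscale = pyro.deterministic("lengthscale", tau * beta)

    # Explicit RBF Gram matrix with per-dimension length scales.
    scaled = (X.unsqueeze(1) - X.unsqueeze(0)) / lengthscale
    K = outputscale * torch.exp(-0.5 * (scaled**2).sum(-1))
    K = K + noise * torch.eye(n)

    # Observed likelihood: this is the site the PR's model never creates.
    pyro.sample(
        "obs",
        dist.MultivariateNormal(mean * torch.ones(n), covariance_matrix=K),
        obs=y,
    )

# Usage sketch: MCMC(NUTS(saas_gp_model), num_samples=256).run(X_train, y_train)
```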
🧰 Tools

🪛 Pylint (3.3.7)
- [convention] 511-511: Argument name "X" doesn't conform to snake_case naming style (C0103)
- [convention] 513-513: Argument name "X" doesn't conform to snake_case naming style (C0103)
- [warning] 511-511: Unused argument 'X' (W0613)
- [warning] 511-511: Unused argument 'y' (W0613)
🤖 Prompt for AI Agents
In bayes_opt/bayesian_optimization.py lines 511 to 545, the gp_model function
incorrectly returns the GPRegression object instead of calling its model method,
causing MCMC to lack a likelihood and fail. Also, the lengthscale variable is
computed but not registered as a sample site, so it is missing from the
posterior. To fix this, replace the return of gpr with a call to gpr.model(),
and explicitly register lengthscale as a pyro sample site so it appears in the
trace. This ensures the model is valid for MCMC sampling and all parameters are
properly tracked.
```python
import torch
import pyro
import numpy as np
from scipy.optimize import NonlinearConstraint
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern
import pyro.distributions as dist
from pyro.infer.mcmc import NUTS, MCMC
```
🛠️ Refactor suggestion

**Wrap heavy optional dependencies with lazy-import guards**

`torch`, `pyro`, and friends are optional and heavyweight. Importing them unconditionally will crash for users who only want the classic GP optimiser and do not have these libraries installed.

```diff
-import torch
-import pyro
-import pyro.distributions as dist
-from pyro.infer.mcmc import NUTS, MCMC
+try:
+    import torch
+    import pyro
+    import pyro.distributions as dist
+    from pyro.infer.mcmc import NUTS, MCMC
+except ModuleNotFoundError as exc:  # pragma: no cover
+    raise ModuleNotFoundError(
+        "SAASBO requires the optional dependencies `torch` and `pyro`. "
+        "Install them with `pip install bayes_opt[saas]`."
+    ) from exc
```

This keeps the original BO class usable while surfacing a clear message for SAAS users.

Committable suggestion skipped: line range outside the PR's diff.
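An alternative sketch that defers the check to first use, so module import stays cheap (the helper name `_require_saas_deps` is hypothetical):

```python
def _require_saas_deps() -> None:
    """Raise a helpful error if SAASBO's optional dependencies are absent."""
    try:
        import pyro  # noqa: F401
        import torch  # noqa: F401
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError(
            "SAASBO requires the optional dependencies `torch` and `pyro`."
        ) from exc
```

Calling this at the top of `SAASBO.__init__` would keep `import bayes_opt` working without torch installed.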
🧰 Tools

🪛 Pylint (3.3.7)
- [error] 16-16: Unable to import 'torch' (E0401)
- [error] 17-17: Unable to import 'pyro' (E0401)
- [error] 18-18: Unable to import 'numpy' (E0401)
- [error] 19-19: Unable to import 'scipy.optimize' (E0401)
- [error] 20-20: Unable to import 'sklearn.gaussian_process' (E0401)
- [error] 21-21: Unable to import 'sklearn.gaussian_process.kernels' (E0401)
- [error] 22-22: Unable to import 'pyro.distributions' (E0401)
- [error] 23-23: Unable to import 'pyro.infer.mcmc' (E0401)
- [convention] 22-22: Imports from package pyro are not grouped (C0412)
🤖 Prompt for AI Agents
In bayes_opt/bayesian_optimization.py around lines 16 to 24, the imports for
torch, pyro, and related heavy dependencies are unconditional, which can cause
crashes for users who do not have these libraries installed but want to use the
classic GP optimizer. To fix this, wrap these imports in try-except blocks or
use lazy-import guards so that the imports are only attempted when needed, and
provide clear error messages if the optional dependencies are missing. This will
keep the original BO class usable without requiring these heavy dependencies.
This PR also adds Pyro as a requirement; it is used for probabilistic sampling with MCMC and the NUTS sampler. The PR adds SAASBO as a child class of the `BayesianOptimization` class. I need feedback on testing, as the SAASBO class requires a large number of variables.

Summary by CodeRabbit
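For reference, a hypothetical usage sketch, assuming `SAASBO` mirrors the `BayesianOptimization` constructor and is exported from the package (neither is confirmed by this PR):

```python
from bayes_opt import SAASBO  # hypothetical import path

def black_box(x, y):
    # Toy objective with a known maximum at (0, 1)
    return -x ** 2 - (y - 1) ** 2 + 1

optimizer = SAASBO(
    f=black_box,
    pbounds={"x": (-2.0, 2.0), "y": (-3.0, 3.0)},
    random_state=1,
)
optimizer.maximize(init_points=5, n_iter=10)  # each iteration refits the GP via NUTS
print(optimizer.max)
```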