-
Notifications
You must be signed in to change notification settings - Fork 45
Add L-BFGS optimizer from pyensmallen #566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
gauravmanmode
wants to merge
19
commits into
optimagic-dev:main
Choose a base branch
from
gauravmanmode:pyensmallen-draft
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
62037db
add pyensmallen lbfgs
gauravmanmode 788b072
Merge branch 'main' into pyensmallen-draft
gauravmanmode de1acd0
add stub
gauravmanmode 622fd2c
ensmallen lbfgs improve
gauravmanmode e9515cc
fix environment problems for windows and mac
gauravmanmode 5b4a660
Add how-to add an algorithm guide and improve documentation of some i…
janosg 1215434
change test names, add docs, use fun_and_jac instead of fun and jac
gauravmanmode fed0f2d
add test
gauravmanmode 9211c60
Merge branch 'main' into pyensmallen-draft
gauravmanmode 22104fa
rename
gauravmanmode 131632d
Merge branch 'main' into pyensmallen-draft
gauravmanmode 0b5f667
Merge branch 'optimagic-dev:main' into pyensmallen-draft
gauravmanmode 86f0222
Update pyensmallen_optimizers.py
gauravmanmode c5db83c
use pyensmallen_experimental for now
gauravmanmode 83c6645
Merge branch 'main' into pyensmallen-draft
gauravmanmode 94ff0f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 52aefde
Merge branch 'main' into pyensmallen-draft
gauravmanmode e3a1c87
add message
gauravmanmode d83433b
change to pyensmallen_experimental
gauravmanmode File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
"""Implement ensmallen optimizers.""" | ||
|
||
from dataclasses import dataclass | ||
from typing import Any | ||
|
||
import numpy as np | ||
from numpy.typing import NDArray | ||
|
||
from optimagic import mark | ||
from optimagic.config import IS_PYENSMALLEN_INSTALLED | ||
from optimagic.optimization.algo_options import ( | ||
CONVERGENCE_FTOL_REL, | ||
CONVERGENCE_GTOL_ABS, | ||
LIMITED_MEMORY_STORAGE_LENGTH, | ||
MAX_LINE_SEARCH_STEPS, | ||
STOPPING_MAXITER, | ||
) | ||
from optimagic.optimization.algorithm import Algorithm, InternalOptimizeResult | ||
from optimagic.optimization.internal_optimization_problem import ( | ||
InternalOptimizationProblem, | ||
) | ||
from optimagic.typing import AggregationLevel, NonNegativeFloat, PositiveInt | ||
|
||
# use pyensmallen_experimental for testing purpose | ||
if IS_PYENSMALLEN_INSTALLED: | ||
import pyensmallen_experimental as pye | ||
|
||
MIN_LINE_SEARCH_STEPS = 1e-20 | ||
"""The minimum step of the line search.""" | ||
MAX_LINE_SEARCH_TRIALS = 50 | ||
"""The maximum number of trials for the line search (before giving up).""" | ||
ARMIJO_CONSTANT = 1e-4 | ||
"""Controls the accuracy of the line search routine for determining the Armijo | ||
condition.""" | ||
WOLFE_CONDITION = 0.9 | ||
"""Parameter for detecting the Wolfe condition.""" | ||
|
||
STEP_SIZE = 0.001 | ||
"""Step size for each iteration.""" | ||
BATCH_SIZE = 32 | ||
"""Step size for each iteration.""" | ||
EXP_DECAY_RATE_FOR_FIRST_MOMENT = 0.9 | ||
"""Exponential decay rate for the first moment estimates.""" | ||
EXP_DECAY_RATE_FOR_WEIGHTED_INF_NORM = 0.999 | ||
"""Exponential decay rate for the first moment estimates.""" | ||
|
||
|
||
@mark.minimizer( | ||
name="ensmallen_lbfgs", | ||
solver_type=AggregationLevel.SCALAR, | ||
is_available=IS_PYENSMALLEN_INSTALLED, | ||
is_global=False, | ||
needs_jac=True, | ||
needs_hess=False, | ||
supports_parallelism=False, | ||
supports_bounds=False, | ||
supports_linear_constraints=False, | ||
supports_nonlinear_constraints=False, | ||
disable_history=False, | ||
) | ||
@dataclass(frozen=True) | ||
class EnsmallenLBFGS(Algorithm): | ||
limited_memory_storage_length: PositiveInt = LIMITED_MEMORY_STORAGE_LENGTH | ||
stopping_maxiter: PositiveInt = STOPPING_MAXITER | ||
armijo_constant: NonNegativeFloat = ARMIJO_CONSTANT # needs review | ||
wolfe_condition: NonNegativeFloat = WOLFE_CONDITION # needs review | ||
convergence_gtol_abs: NonNegativeFloat = CONVERGENCE_GTOL_ABS | ||
convergence_ftol_rel: NonNegativeFloat = CONVERGENCE_FTOL_REL | ||
max_line_search_trials: PositiveInt = MAX_LINE_SEARCH_TRIALS | ||
min_step_for_line_search: NonNegativeFloat = MIN_LINE_SEARCH_STEPS | ||
max_step_for_line_search: NonNegativeFloat = MAX_LINE_SEARCH_STEPS | ||
|
||
def _solve_internal_problem( | ||
self, problem: InternalOptimizationProblem, x0: NDArray[np.float64] | ||
) -> InternalOptimizeResult: | ||
optimizer = pye.L_BFGS( | ||
numBasis=self.limited_memory_storage_length, | ||
maxIterations=self.stopping_maxiter, | ||
armijoConstant=self.armijo_constant, | ||
wolfe=self.wolfe_condition, | ||
minGradientNorm=self.convergence_gtol_abs, | ||
factr=self.convergence_ftol_rel, | ||
maxLineSearchTrials=self.max_line_search_trials, | ||
minStep=self.min_step_for_line_search, | ||
maxStep=self.max_step_for_line_search, | ||
) | ||
|
||
def objective_function( | ||
x: NDArray[np.float64], grad: NDArray[np.float64] | ||
) -> np.float64: | ||
fun, jac = problem.fun_and_jac(x) | ||
grad[:] = jac | ||
return np.float64(fun) | ||
|
||
# Passing a Report class to the optimizer allows us to retrieve additional info | ||
ens_res: dict[str, Any] = dict() | ||
report = pye.Report(resultIn=ens_res, disableOutput=True) | ||
best_x = optimizer.optimize(objective_function, x0, report) | ||
|
||
res = InternalOptimizeResult( | ||
x=best_x, | ||
fun=ens_res["objective_value"], | ||
n_iterations=ens_res["iterations"], | ||
n_fun_evals=ens_res["evaluate_calls"], | ||
n_jac_evals=ens_res["gradient_calls"], | ||
) | ||
|
||
return res | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
"""Tests for pyensmallen optimizers.""" | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import optimagic as om | ||
from optimagic.config import IS_PYENSMALLEN_INSTALLED | ||
from optimagic.optimization.optimize import minimize | ||
|
||
|
||
@pytest.mark.skipif(not IS_PYENSMALLEN_INSTALLED, reason="pyensmallen not installed.") | ||
def test_stop_after_one_iteration(): | ||
algo = om.algos.ensmallen_lbfgs(stopping_maxiter=1) | ||
expected = np.array([0, 0.81742581, 1.63485163, 2.45227744, 3.26970326]) | ||
res = minimize( | ||
fun=lambda x: x @ x, | ||
fun_and_jac=lambda x: (x @ x, 2 * x), | ||
params=np.arange(5), | ||
algorithm=algo, | ||
) | ||
|
||
assert np.allclose(res.x, expected) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.