Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test functions package #75

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
153 changes: 43 additions & 110 deletions examples/viz_optimizers.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,16 @@
import math
import matplotlib.pyplot as plt
import numpy as np
import torch_optimizer as optim
import torch
from hyperopt import fmin, tpe, hp
import matplotlib.pyplot as plt


plt.style.use('seaborn-white')


def rosenbrock(tensor):
# https://en.wikipedia.org/wiki/Test_functions_for_optimization
x, y = tensor
return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2
import torch_optimizer as optim
from test_functions import Rastrigin, Rosenbrock

plt.style.use("seaborn-white")

def rastrigin(tensor, lib=torch):
# https://en.wikipedia.org/wiki/Test_functions_for_optimization
x, y = tensor
A = 10
f = (
A * 2
+ (x ** 2 - A * lib.cos(x * math.pi * 2))
+ (y ** 2 - A * lib.cos(y * math.pi * 2))
)
return f


def execute_steps(
func, initial_state, optimizer_class, optimizer_config, num_iter=500
):
def execute_steps(func, initial_state, optimizer_class, optimizer_config, num_iter=500):
""" Execute one steps of the optimizer """
x = torch.Tensor(initial_state).requires_grad_(True)
optimizer = optimizer_class([x], **optimizer_config)
steps = []
Expand All @@ -45,116 +26,77 @@ def execute_steps(
return steps


def objective_rastrigin(params):
lr = params['lr']
optimizer_class = params['optimizer_class']
initial_state = (-2.0, 3.5)
minimum = (0, 0)
optimizer_config = dict(lr=lr)
num_iter = 100
steps = execute_steps(
rastrigin, initial_state, optimizer_class, optimizer_config, num_iter
)
return (steps[0][-1] - minimum[0]) ** 2 + (steps[1][-1] - minimum[1]) ** 2


def objective_rosenbrok(params):
lr = params['lr']
optimizer_class = params['optimizer_class']
minimum = (1.0, 1.0)
initial_state = (-2.0, 2.0)
def objective(params, func):
""" Execute objective """
lr = params["lr"]
optimizer_class = params["optimizer_class"]
minimum = func.minimum
initial_state = func.initial_state
optimizer_config = dict(lr=lr)
num_iter = 100
steps = execute_steps(
rosenbrock, initial_state, optimizer_class, optimizer_config, num_iter
func, initial_state, optimizer_class, optimizer_config, num_iter
)
return (steps[0][-1] - minimum[0]) ** 2 + (steps[1][-1] - minimum[1]) ** 2


def plot_rastrigin(grad_iter, optimizer_name, lr):
x = np.linspace(-4.5, 4.5, 250)
y = np.linspace(-4.5, 4.5, 250)
minimum = (0, 0)
def plot(func, grad_iter, optimizer_name, lr):
""" Plot result of a simulation """
x = torch.linspace(func.x_domain[0], func.x_domain[1], func.num_pt)
y = torch.linspace(func.y_domain[0], func.y_domain[1], func.num_pt)
minimum = func.minimum

X, Y = np.meshgrid(x, y)
Z = rastrigin([X, Y], lib=np)
X, Y = torch.meshgrid(x, y)
Z = func([X, Y])

iter_x, iter_y = grad_iter[0, :], grad_iter[1, :]

fig = plt.figure(figsize=(8, 8))

ax = fig.add_subplot(1, 1, 1)
ax.contour(X, Y, Z, 20, cmap='jet')
ax.plot(iter_x, iter_y, color='r', marker='x')
ax.set_title(
f'Rastrigin func: {optimizer_name} with '
f'{len(iter_x)} iterations, lr={lr:.6}'
)
plt.plot(*minimum, 'gD')
plt.plot(iter_x[-1], iter_y[-1], 'rD')
plt.savefig(f'docs/rastrigin_{optimizer_name}.png')


def plot_rosenbrok(grad_iter, optimizer_name, lr):
x = np.linspace(-2, 2, 250)
y = np.linspace(-1, 3, 250)
minimum = (1.0, 1.0)

X, Y = np.meshgrid(x, y)
Z = rosenbrock([X, Y])

iter_x, iter_y = grad_iter[0, :], grad_iter[1, :]

fig = plt.figure(figsize=(8, 8))

ax = fig.add_subplot(1, 1, 1)
ax.contour(X, Y, Z, 90, cmap='jet')
ax.plot(iter_x, iter_y, color='r', marker='x')
ax.contour(X, Y, Z, func.levels, cmap="jet")
ax.plot(iter_x, iter_y, color="r", marker="x")

func_name = func.__name__()
ax.set_title(
f'Rosenbrock func: {optimizer_name} with {len(iter_x)} '
f'iterations, lr={lr:.6}'
f"{func_name} func: {optimizer_name} with {len(iter_x)} "
f"iterations, lr={lr:.6}"
)
plt.plot(*minimum, 'gD')
plt.plot(iter_x[-1], iter_y[-1], 'rD')
plt.savefig(f'docs/rosenbrock_{optimizer_name}.png')
plt.plot(*minimum, "gD")
plt.plot(iter_x[-1], iter_y[-1], "rD")
plt.savefig(f"docs/{func_name}_{optimizer_name}.png")


def execute_experiments(
optimizers, objective, func, plot_func, initial_state, seed=1
):
def execute_experiments(optimizers, func, seed=1):
""" Execute simulation on a list of optimizers using a test function """
seed = seed
for item in optimizers:
optimizer_class, lr_low, lr_hi = item
space = {
'optimizer_class': hp.choice('optimizer_class', [optimizer_class]),
'lr': hp.loguniform('lr', lr_low, lr_hi),
"optimizer_class": hp.choice("optimizer_class", [optimizer_class]),
"lr": hp.loguniform("lr", lr_low, lr_hi),
}
best = fmin(
fn=objective,
fn=lambda x: objective(x, func),
space=space,
algo=tpe.suggest,
max_evals=200,
rstate=np.random.RandomState(seed),
)
print(best['lr'], optimizer_class)
print(best["lr"], optimizer_class)

steps = execute_steps(
func,
initial_state,
optimizer_class,
{'lr': best['lr']},
num_iter=500,
func, func.initial_state, optimizer_class, {"lr": best["lr"]}, num_iter=500,
)
plot_func(steps, optimizer_class.__name__, best['lr'])
plot(func, steps, optimizer_class.__name__, best["lr"])


if __name__ == '__main__':
if __name__ == "__main__":
# python examples/viz_optimizers.py

# Each optimizer has tweaked search space to produce better plots and
# help to converge on better lr faster.
optimizers = [
list_optimizers = [
# Adam based
(optim.AdaBound, -8, 0.3),
(optim.AdaMod, -8, 0.2),
Expand All @@ -168,18 +110,9 @@ def execute_experiments(
(optim.SGDW, -8, -1.5),
(optim.PID, -8, -1.0),
]
execute_experiments(
optimizers,
objective_rastrigin,
rastrigin,
plot_rastrigin,
(-2.0, 3.5),
)

execute_experiments(
optimizers,
objective_rosenbrok,
rosenbrock,
plot_rosenbrok,
(-2.0, 2.0),
)
for test_func in [Rastrigin(), Rosenbrock()]:
print(f"Test function {test_func.__name__()}")
execute_experiments(
list_optimizers, test_func,
)
2 changes: 2 additions & 0 deletions test_functions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .rosenbrock import Rosenbrock
from .rastrigin import Rastrigin
47 changes: 47 additions & 0 deletions test_functions/rastrigin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
import math

from .test_function import TestFunction


class Rastrigin(TestFunction):
r"""Rastrigin test function.

Example:
>>> import test_functions
>>> rastrigin = test_functions.Rastrigin()
>>> x = torch.linspace(
rastrigin.x_domain[0],
rastrigin.x_domain[1],
rastrigin.num_pt
)
>>> y = torch.linspace(
rastrigin.y_domain[0],
rastrigin.y_domain[1],
rastrigin.num_pt
)
>>> Y, X = torch.meshgrid(x, y)
>>> Z = rastrigin([X, Y])

__ https://en.wikipedia.org/wiki/Test_functions_for_optimization
__ https://en.wikipedia.org/wiki/Rastrigin_function
"""

def __init__(self):
super(Rastrigin, self).__init__(
x_domain=(-4.5, 4.5),
y_domain=(-4.5, 4.5),
minimum=(0, 0),
initial_state=(-2.0, 3.5),
levels=20,
)

def __call__(self, tensor, lib=torch):
x, y = tensor
A = 10
f = (
A * 2
+ (x ** 2 - A * lib.cos(x * math.pi * 2))
+ (y ** 2 - A * lib.cos(y * math.pi * 2))
)
return f
40 changes: 40 additions & 0 deletions test_functions/rosenbrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch

from .test_function import TestFunction


class Rosenbrock(TestFunction):
r"""Rosenbrock test function.

Example:
>>> import test_functions
>>> rosenbrock = test_functions.Rosenbrock()
>>> x = torch.linspace(
rosenbrock.x_domain[0],
rosenbrock.x_domain[1],
rosenbrock.num_pt
)
>>> y = torch.linspace(
rosenbrock.y_domain[0],
rosenbrock.y_domain[1],
rosenbrock.num_pt
)
>>> Y, X = torch.meshgrid(x, y)
>>> Z = rosenbrock([X, Y])

__ https://en.wikipedia.org/wiki/Test_functions_for_optimization
__ https://en.wikipedia.org/wiki/Rosenbrock_function
"""

def __init__(self):
super(Rosenbrock, self).__init__(
x_domain=(-2.0, 2.0),
y_domain=(-1.0, 3.0),
minimum=(1.0, 1.0),
initial_state=(-2.0, 2.0),
levels=90,
)

def __call__(self, tensor, lib=torch):
x, y = tensor
return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2
58 changes: 58 additions & 0 deletions test_functions/test_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
class TestFunction:
r"""Super class to implement a test function.

Arguments:
x_domain: iterable of int for the x search domain
y_domain: iterable of int for the y search domain
minimum: global minimum (x0, y0) of the
function for the given domain
initial_state: starting point
num_pt: number of point in the search domain
levels: determines the number and positions of the contour lines / regions.
"""

def __init__(
self,
x_domain: iter,
y_domain: iter,
minimum: iter,
initial_state: iter,
num_pt: int = 250,
levels: int = 50,
):
self._x_domain = x_domain
self._y_domain = y_domain
self._minimum = minimum
self._initial_state = initial_state
self._num_pt = num_pt
self._levels = levels

def __call__(self, tensor, lib):
raise NotImplementedError

def __name__(self):
return self.__class__.__name__

@property
def x_domain(self):
return self._x_domain

@property
def y_domain(self):
return self._y_domain

@property
def num_pt(self):
return self._num_pt

@property
def minimum(self):
return self._minimum

@property
def initial_state(self):
return self._initial_state

@property
def levels(self):
return self._levels