Commit 359f4da

ENH Add FISTA solver (#91)
Co-authored-by: Badr MOUFAD <[email protected]>
1 parent f9ee2e5 commit 359f4da

6 files changed: +310, -9 lines

doc/api.rst (+1)

@@ -70,6 +70,7 @@ Solvers
    :toctree: generated/

    AndersonCD
+   FISTA
    GramCD
    GroupBCD
    MultiTaskBCD

skglm/datafits/single_task.py (+52, -8)

@@ -4,6 +4,7 @@
 from numba import float64

 from skglm.datafits.base import BaseDatafit
+from skglm.utils import spectral_norm


 class Quadratic(BaseDatafit):
@@ -22,6 +23,10 @@ class Quadratic(BaseDatafit):
         The coordinatewise gradient Lipschitz constants. Equal to
         norm(X, axis=0) ** 2 / n_samples.

+    global_lipschitz : float
+        Global Lipschitz constant. Equal to
+        norm(X, ord=2) ** 2 / n_samples.
+
     Note
     ----
     The class is jit compiled at fit time using Numba compiler.
@@ -35,6 +40,7 @@ def get_spec(self):
         spec = (
             ('Xty', float64[:]),
             ('lipschitz', float64[:]),
+            ('global_lipschitz', float64),
         )
         return spec

@@ -44,14 +50,18 @@ def params_to_dict(self):
     def initialize(self, X, y):
         self.Xty = X.T @ y
         n_features = X.shape[1]
+        self.global_lipschitz = norm(X, ord=2) ** 2 / len(y)
         self.lipschitz = np.zeros(n_features, dtype=X.dtype)
         for j in range(n_features):
             self.lipschitz[j] = (X[:, j] ** 2).sum() / len(y)

-    def initialize_sparse(
-            self, X_data, X_indptr, X_indices, y):
+    def initialize_sparse(self, X_data, X_indptr, X_indices, y):
         n_features = len(X_indptr) - 1
         self.Xty = np.zeros(n_features, dtype=X_data.dtype)
+
+        self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
+        self.global_lipschitz /= len(y)
+
         self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
         for j in range(n_features):
             nrm2 = 0.
@@ -111,6 +121,10 @@ class Logistic(BaseDatafit):
         The coordinatewise gradient Lipschitz constants. Equal to
         norm(X, axis=0) ** 2 / (4 * n_samples).

+    global_lipschitz : float
+        Global Lipschitz constant. Equal to
+        norm(X, ord=2) ** 2 / (4 * n_samples).
+
     Note
     ----
     The class is jit compiled at fit time using Numba compiler.
@@ -123,6 +137,7 @@ def __init__(self):
     def get_spec(self):
         spec = (
             ('lipschitz', float64[:]),
+            ('global_lipschitz', float64),
         )
         return spec

@@ -140,9 +155,14 @@ def raw_hessian(self, y, Xw):

     def initialize(self, X, y):
         self.lipschitz = (X ** 2).sum(axis=0) / (len(y) * 4)
+        self.global_lipschitz = norm(X, ord=2) ** 2 / (len(y) * 4)

     def initialize_sparse(self, X_data, X_indptr, X_indices, y):
         n_features = len(X_indptr) - 1
+
+        self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
+        self.global_lipschitz /= 4 * len(y)
+
         self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
         for j in range(n_features):
             Xj = X_data[X_indptr[j]:X_indptr[j+1]]
@@ -187,6 +207,11 @@ class QuadraticSVC(BaseDatafit):
     ----------
     lipschitz : array, shape (n_features,)
         The coordinatewise gradient Lipschitz constants.
+        Equal to norm(yXT, axis=0) ** 2.
+
+    global_lipschitz : float
+        Global Lipschitz constant. Equal to
+        norm(yXT, ord=2) ** 2.

     Note
     ----
@@ -200,6 +225,7 @@ def __init__(self):
     def get_spec(self):
         spec = (
             ('lipschitz', float64[:]),
+            ('global_lipschitz', float64),
         )
         return spec

@@ -209,12 +235,16 @@ def params_to_dict(self):
     def initialize(self, yXT, y):
         n_features = yXT.shape[1]
         self.lipschitz = np.zeros(n_features, dtype=yXT.dtype)
+        self.global_lipschitz = norm(yXT, ord=2) ** 2
         for j in range(n_features):
             self.lipschitz[j] = norm(yXT[:, j]) ** 2

-    def initialize_sparse(
-            self, yXT_data, yXT_indptr, yXT_indices, y):
+    def initialize_sparse(self, yXT_data, yXT_indptr, yXT_indices, y):
         n_features = len(yXT_indptr) - 1
+
+        self.global_lipschitz = spectral_norm(
+            yXT_data, yXT_indptr, yXT_indices, max(yXT_indices)+1) ** 2
+
         self.lipschitz = np.zeros(n_features, dtype=yXT_data.dtype)
         for j in range(n_features):
             nrm2 = 0.
@@ -264,8 +294,16 @@ class Huber(BaseDatafit):

     Attributes
     ----------
+    delta : float
+        Threshold hyperparameter.
+
     lipschitz : array, shape (n_features,)
-        The coordinatewise gradient Lipschitz constants.
+        The coordinatewise gradient Lipschitz constants. Equal to
+        norm(X, axis=0) ** 2 / n_samples.
+
+    global_lipschitz : float
+        Global Lipschitz constant. Equal to
+        norm(X, ord=2) ** 2 / n_samples.

     Note
     ----
@@ -279,7 +317,8 @@ def __init__(self, delta):
     def get_spec(self):
         spec = (
             ('delta', float64),
-            ('lipschitz', float64[:])
+            ('lipschitz', float64[:]),
+            ('global_lipschitz', float64),
         )
         return spec

@@ -289,12 +328,17 @@ def params_to_dict(self):
     def initialize(self, X, y):
         n_features = X.shape[1]
         self.lipschitz = np.zeros(n_features, dtype=X.dtype)
+        self.global_lipschitz = 0.
         for j in range(n_features):
             self.lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
+            self.global_lipschitz += (X[:, j] ** 2).sum() / len(y)

-    def initialize_sparse(
-            self, X_data, X_indptr, X_indices, y):
+    def initialize_sparse(self, X_data, X_indptr, X_indices, y):
         n_features = len(X_indptr) - 1
+
+        self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
+        self.global_lipschitz /= len(y)
+
         self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
         for j in range(n_features):
             nrm2 = 0.
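The new global_lipschitz attribute documented above can be sanity-checked against the dense formula. A minimal sketch, not part of the commit, assuming the public skglm API also used in the test file further down (compiled_clone, Quadratic, and the initialize_sparse signature shown in this diff):

import numpy as np
from numpy.linalg import norm
from scipy.sparse import csc_matrix

from skglm.datafits import Quadratic
from skglm.utils import compiled_clone

# Small random design, stored both dense and in CSC format.
X = np.random.RandomState(0).randn(20, 30)
y = np.random.RandomState(1).randn(20)
X_sp = csc_matrix(X)

# The sparse path should recover the dense formula
# norm(X, ord=2) ** 2 / n_samples documented in the docstring.
datafit = compiled_clone(Quadratic())
datafit.initialize_sparse(X_sp.data, X_sp.indptr, X_sp.indices, y)

assert np.isclose(datafit.global_lipschitz, norm(X, ord=2) ** 2 / len(y))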

skglm/solvers/__init__.py (+2, -1)

@@ -1,9 +1,10 @@
 from .anderson_cd import AndersonCD
 from .base import BaseSolver
+from .fista import FISTA
 from .gram_cd import GramCD
 from .group_bcd import GroupBCD
 from .multitask_bcd import MultiTaskBCD
 from .prox_newton import ProxNewton


-__all__ = [AndersonCD, BaseSolver, GramCD, GroupBCD, MultiTaskBCD, ProxNewton]
+__all__ = [AndersonCD, BaseSolver, FISTA, GramCD, GroupBCD, MultiTaskBCD, ProxNewton]

skglm/solvers/fista.py (new file, +82)

@@ -0,0 +1,82 @@
+import numpy as np
+from scipy.sparse import issparse
+from skglm.solvers.base import BaseSolver
+from skglm.solvers.common import construct_grad, construct_grad_sparse
+from skglm.utils import _prox_vec
+
+
+class FISTA(BaseSolver):
+    r"""ISTA solver with Nesterov acceleration (FISTA).
+
+    Attributes
+    ----------
+    max_iter : int, default 100
+        Maximum number of iterations.
+
+    tol : float, default 1e-4
+        Tolerance for convergence.
+
+    verbose : bool, default False
+        Amount of verbosity. 0/False is silent.
+
+    References
+    ----------
+    .. [1] Beck, A. and Teboulle M.
+        "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
+        Problems", 2009, SIAM J. Imaging Sci.
+        https://epubs.siam.org/doi/10.1137/080716542
+    """
+
+    def __init__(self, max_iter=100, tol=1e-4, verbose=0):
+        self.max_iter = max_iter
+        self.tol = tol
+        self.verbose = verbose
+
+    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+        p_objs_out = []
+        n_samples, n_features = X.shape
+        all_features = np.arange(n_features)
+        t_new = 1.
+
+        w = w_init.copy() if w_init is not None else np.zeros(n_features)
+        z = w_init.copy() if w_init is not None else np.zeros(n_features)
+        Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)
+
+        if hasattr(datafit, "global_lipschitz"):
+            lipschitz = datafit.global_lipschitz
+        else:
+            # TODO: OR line search
+            raise Exception("Line search is not yet implemented for FISTA solver.")
+
+        for n_iter in range(self.max_iter):
+            t_old = t_new
+            t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
+            w_old = w.copy()
+            if issparse(X):
+                grad = construct_grad_sparse(
+                    X.data, X.indptr, X.indices, y, z, X @ z, datafit, all_features)
+            else:
+                grad = construct_grad(X, y, z, X @ z, datafit, all_features)
+
+            step = 1 / lipschitz
+            z -= step * grad
+            w = _prox_vec(w, z, penalty, step)
+            Xw = X @ w
+            z = w + (t_old - 1.) / t_new * (w - w_old)
+
+            opt = penalty.subdiff_distance(w, grad, all_features)
+            stop_crit = np.max(opt)
+
+            p_obj = datafit.value(y, w, Xw) + penalty.value(w)
+            p_objs_out.append(p_obj)
+            if self.verbose:
+                print(
+                    f"Iteration {n_iter+1}: {p_obj:.10f}, "
+                    f"stopping crit: {stop_crit:.2e}"
+                )
+
+            if stop_crit < self.tol:
+                if self.verbose:
+                    print(f"Stopping criterion max violation: {stop_crit:.2e}")
+                break
+        return w, np.array(p_objs_out), stop_crit
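The solver plugs the datafit's global Lipschitz constant into the classical FISTA updates: step size 1 / global_lipschitz, momentum t_new = (1 + sqrt(1 + 4 * t_old ** 2)) / 2, and extrapolation z = w + (t_old - 1) / t_new * (w - w_old). A minimal usage sketch, mirroring the Lasso setup of the tests below; the data and the choice of alpha are illustrative, not part of the commit:

import numpy as np
from numpy.linalg import norm

from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.solvers import FISTA
from skglm.utils import make_correlated_data, compiled_clone

# Synthetic regression data and a regularization level at alpha_max / 10.
X, y, _ = make_correlated_data(n_samples=50, n_features=60, random_state=0)
alpha = norm(X.T @ y, ord=np.inf) / (10 * len(y))

# Datafit and penalty are JIT-compiled clones; initialize() fills
# global_lipschitz so the solver can take 1 / global_lipschitz steps.
datafit = compiled_clone(Quadratic())
datafit.initialize(X, y)
penalty = compiled_clone(L1(alpha))

w, p_objs, stop_crit = FISTA(max_iter=1000, tol=1e-8).solve(X, y, datafit, penalty)

As in the return statement above, solve yields the coefficients, the sequence of primal objectives, and the last value of the stopping criterion.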

skglm/tests/test_fista.py (new file, +69)

@@ -0,0 +1,69 @@
+import pytest
+
+import numpy as np
+from numpy.linalg import norm
+
+import scipy.sparse
+import scipy.sparse.linalg
+from scipy.sparse import csc_matrix, issparse
+
+from skglm.penalties import L1, IndicatorBox
+from skglm.solvers import FISTA, AndersonCD
+from skglm.datafits import Quadratic, Logistic, QuadraticSVC
+from skglm.utils import make_correlated_data, compiled_clone, spectral_norm
+
+
+random_state = 113
+n_samples, n_features = 50, 60
+
+rng = np.random.RandomState(random_state)
+X, y, _ = make_correlated_data(n_samples, n_features, random_state=rng)
+rng.seed(random_state)
+X_sparse = csc_matrix(X * np.random.binomial(1, 0.5, X.shape))
+y_classif = np.sign(y)
+
+alpha_max = norm(X.T @ y, ord=np.inf) / len(y)
+alpha = alpha_max / 10
+
+tol = 1e-10
+
+
+@pytest.mark.parametrize("X", [X, X_sparse])
+@pytest.mark.parametrize("Datafit, Penalty", [
+    (Quadratic, L1),
+    (Logistic, L1),
+    (QuadraticSVC, IndicatorBox),
+])
+def test_fista_solver(X, Datafit, Penalty):
+    _y = y if isinstance(Datafit, Quadratic) else y_classif
+    datafit = compiled_clone(Datafit())
+    _init = y @ X.T if isinstance(Datafit, QuadraticSVC) else X
+    if issparse(X):
+        datafit.initialize_sparse(_init.data, _init.indptr, _init.indices, _y)
+    else:
+        datafit.initialize(_init, _y)
+    penalty = compiled_clone(Penalty(alpha))
+
+    solver = FISTA(max_iter=1000, tol=tol)
+    w_fista = solver.solve(X, _y, datafit, penalty)[0]
+
+    solver_cd = AndersonCD(tol=tol, fit_intercept=False)
+    w_cd = solver_cd.solve(X, _y, datafit, penalty)[0]
+
+    np.testing.assert_allclose(w_fista, w_cd, atol=1e-7)
+
+
+def test_spectral_norm():
+    n_samples, n_features = 50, 60
+    A_sparse = scipy.sparse.random(n_samples, n_features, density=0.7, format='csc',
+                                   random_state=random_state)
+
+    A_bundles = (A_sparse.data, A_sparse.indptr, A_sparse.indices)
+    spectral_norm_our = spectral_norm(*A_bundles, n_samples=len(y))
+    spectral_norm_sp = scipy.sparse.linalg.svds(A_sparse, k=1)[1]
+
+    np.testing.assert_allclose(spectral_norm_our, spectral_norm_sp)
+
+
+if __name__ == '__main__':
+    pass
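test_spectral_norm checks the spectral_norm helper imported from skglm.utils (the remaining changed file of this commit, presumably skglm/utils.py, is not shown in this view) against scipy.sparse.linalg.svds. For context only, here is a generic power-iteration estimate of the largest singular value, which is the quantity being compared; this is an illustration of the idea, not the helper's actual implementation:

import numpy as np
import scipy.sparse
import scipy.sparse.linalg


def power_iteration_spectral_norm(A, n_iter=300, random_state=0):
    """Estimate ||A||_2 (largest singular value) by power iteration on A.T @ A."""
    rng = np.random.RandomState(random_state)
    v = rng.randn(A.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(n_iter):
        # One power-iteration step on A.T @ A, renormalized to unit length.
        v = A.T @ (A @ v)
        v /= np.linalg.norm(v)
    # With ||v|| = 1, ||A v|| approaches the largest singular value of A.
    return np.linalg.norm(A @ v)


A = scipy.sparse.random(50, 60, density=0.7, format='csc', random_state=27)
print(power_iteration_spectral_norm(A))        # power-iteration estimate
print(scipy.sparse.linalg.svds(A, k=1)[1][0])  # reference value from scipy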
