
Commit a4362ba

ENH Add SLOPE penalty (#92)
1 parent 92c2a45 commit a4362ba

File tree: 7 files changed, +167 -9 lines

7 files changed

+167
-9
lines changed

.github/workflows/main.yml (+1)

@@ -25,5 +25,6 @@ jobs:
         pip install numpydoc
         pip install .
         pip install statsmodels cvxopt
+        pip install git+https://github.com/jolars/pyslope.git
     - name: Test with pytest
       run: pytest -v skglm/

doc/api.rst (+1)

@@ -43,6 +43,7 @@ Penalties
    WeightedGroupL2
    SCAD
    BlockSCAD
+   SLOPE


 Datafits

skglm/penalties/__init__.py (+3 -1)

@@ -6,9 +6,11 @@
     L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2
 )

+from .non_separable import SLOPE
+

 __all__ = [
     BasePenalty,
     L1_plus_L2, L0_5, L1, L2_3, MCPenalty, SCAD, WeightedL1, IndicatorBox,
-    L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2
+    L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2, SLOPE
 ]

skglm/penalties/non_separable.py (+60)

@@ -0,0 +1,60 @@
+import numpy as np
+from numba import float64
+
+from skglm.penalties.base import BasePenalty
+from skglm.utils import prox_SLOPE
+
+
+class SLOPE(BasePenalty):
+    """Sorted L-One Penalized Estimation (SLOPE) penalty.
+
+    Attributes
+    ----------
+    alphas : array, shape (n_features,)
+        Contains the regularization level for every feature.
+        When ``alphas`` contains a single unique value, ``SLOPE``
+        is equivalent to the ``L1`` penalty.
+
+    References
+    ----------
+    .. [1] M. Bogdan, E. van den Berg, C. Sabatti, W. Su, E. Candes
+        "SLOPE - Adaptive Variable Selection via Convex Optimization",
+        The Annals of Applied Statistics 9 (3): 1103-40
+        https://doi.org/10.1214/15-AOAS842
+    """
+
+    def __init__(self, alphas):
+        self.alphas = alphas
+
+    def get_spec(self):
+        spec = (
+            ('alphas', float64[:]),
+        )
+        return spec
+
+    def params_to_dict(self):
+        return dict(alphas=self.alphas)
+
+    def value(self, w):
+        """Compute the value of SLOPE at w."""
+        # pair the largest |w| with the largest alpha
+        return np.sum(np.sort(np.abs(w)) * self.alphas[::-1])
+
+    def prox_vec(self, x, stepsize):
+        def _prox(_x, _alphas):
+            sign_x = np.sign(_x)
+            _x = np.abs(_x)
+            # prox_SLOPE expects its input sorted in non-increasing order
+            sorted_indices = np.argsort(_x)[::-1]
+            prox = np.zeros_like(_x)
+            prox[sorted_indices] = prox_SLOPE(_x[sorted_indices], _alphas)
+            return prox * sign_x
+        return _prox(x, self.alphas * stepsize)
+
+    def prox_1d(self, value, stepsize, j):
+        raise ValueError(
+            "No coordinate-wise proximal operator for SLOPE. Use `prox_vec` instead."
+        )
+
+    def subdiff_distance(self, w, grad, ws):
+        raise ValueError(
+            "No subdifferential distance for SLOPE. Use `opt_strategy='fixpoint'`"
+        )
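
As the docstring notes, a constant ``alphas`` vector makes SLOPE coincide with the L1 penalty, so ``prox_vec`` must reduce to plain soft-thresholding. A minimal sanity check of that claim, assuming this commit is installed (the random data is illustrative only):

    import numpy as np
    from skglm.penalties import SLOPE

    rng = np.random.default_rng(0)
    x = rng.standard_normal(10)
    alpha, stepsize = 0.5, 0.3

    # with constant alphas, the SLOPE prox is elementwise soft-thresholding
    penalty = SLOPE(np.full(10, alpha))
    expected = np.sign(x) * np.maximum(np.abs(x) - alpha * stepsize, 0.)
    np.testing.assert_allclose(penalty.prox_vec(x, stepsize), expected)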

skglm/solvers/fista.py (+17 -3)

@@ -27,10 +27,13 @@ class FISTA(BaseSolver):
     https://epubs.siam.org/doi/10.1137/080716542
     """

-    def __init__(self, max_iter=100, tol=1e-4, verbose=0):
+    def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0):
         self.max_iter = max_iter
         self.tol = tol
         self.verbose = verbose
+        self.opt_strategy = opt_strategy
+        self.fit_intercept = False  # needed by GeneralizedLinearEstimator
+        self.warm_start = False

     def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         p_objs_out = []
@@ -60,11 +63,22 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):

             step = 1 / lipschitz
             z -= step * grad
-            w = _prox_vec(w, z, penalty, step)
+            if hasattr(penalty, "prox_vec"):
+                w = penalty.prox_vec(z, step)
+            else:
+                w = _prox_vec(w, z, penalty, step)
             Xw = X @ w
             z = w + (t_old - 1.) / t_new * (w - w_old)

-            opt = penalty.subdiff_distance(w, grad, all_features)
+            if self.opt_strategy == "subdiff":
+                opt = penalty.subdiff_distance(w, grad, all_features)
+            elif self.opt_strategy == "fixpoint":
+                opt = np.abs(w - penalty.prox_vec(w - grad / lipschitz, 1 / lipschitz))
+            else:
+                raise ValueError(
+                    "Unknown optimality strategy. Expected "
+                    f"`subdiff` or `fixpoint`. Got {self.opt_strategy}")

             stop_crit = np.max(opt)

             p_obj = datafit.value(y, w, Xw) + penalty.value(w)
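
Since SLOPE defines no coordinate-wise subdifferential distance, the solver must measure optimality via the proximal fixed-point residual instead: w is a solution iff w = prox(w - grad/L, 1/L), which is exactly what `opt_strategy="fixpoint"` computes above. A hedged usage sketch with this commit's API (the `make_correlated_data` helper is the one imported in the tests below; the alphas values are illustrative):

    import numpy as np
    from skglm import GeneralizedLinearEstimator
    from skglm.penalties import SLOPE
    from skglm.solvers import FISTA
    from skglm.utils import make_correlated_data

    X, y, _ = make_correlated_data(n_samples=50, n_features=20, random_state=0)
    alphas = np.linspace(0.1, 0.01, num=20)  # illustrative non-increasing sequence
    est = GeneralizedLinearEstimator(
        penalty=SLOPE(alphas),
        solver=FISTA(max_iter=1000, tol=1e-8, opt_strategy="fixpoint"),
    ).fit(X, y)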

skglm/tests/test_penalties.py (+38 -5)

@@ -6,10 +6,10 @@

 from skglm.datafits import Quadratic, QuadraticMultiTask
 from skglm.penalties import (
-    L1, L1_plus_L2, WeightedL1, MCPenalty, SCAD, IndicatorBox, L0_5, L2_3,
+    L1, L1_plus_L2, WeightedL1, MCPenalty, SCAD, IndicatorBox, L0_5, L2_3, SLOPE,
     L2_1, L2_05, BlockMCPenalty, BlockSCAD)
-from skglm import GeneralizedLinearEstimator
-from skglm.solvers import AndersonCD, MultiTaskBCD
+from skglm import GeneralizedLinearEstimator, Lasso
+from skglm.solvers import AndersonCD, MultiTaskBCD, FISTA
 from skglm.utils import make_correlated_data


@@ -25,6 +25,8 @@
 alpha_max = norm(X.T @ y, ord=np.inf) / n_samples
 alpha = alpha_max / 1000

+tol = 1e-10
+
 penalties = [
     L1(alpha=alpha),
     L1_plus_L2(alpha=alpha, l1_ratio=0.5),
@@ -44,7 +46,6 @@

 @pytest.mark.parametrize('penalty', penalties)
 def test_subdiff_diff(penalty):
-    tol = 1e-10
     # tol=1e-14 is too low when coefs are of order 1. square roots are computed in
     # some penalties and precision is lost
     est = GeneralizedLinearEstimator(
@@ -58,7 +59,6 @@ def test_subdiff_diff(penalty):

 @pytest.mark.parametrize('block_penalty', block_penalties)
 def test_subdiff_diff_block(block_penalty):
-    tol = 1e-10  # see test_subdiff_dist
     est = GeneralizedLinearEstimator(
         datafit=QuadraticMultiTask(),
         penalty=block_penalty,
@@ -68,5 +68,38 @@ def test_subdiff_diff_block(block_penalty):
     assert_array_less(est.stop_crit_, est.solver.tol)


+def test_slope_lasso():
+    # check that when alphas = [alpha, ..., alpha], SLOPE and L1 solutions are equal
+    alphas = np.full(n_features, alpha)
+    est = GeneralizedLinearEstimator(
+        penalty=SLOPE(alphas),
+        solver=FISTA(max_iter=1000, tol=tol, opt_strategy="fixpoint"),
+    ).fit(X, y)
+    lasso = Lasso(alpha, fit_intercept=False, tol=tol).fit(X, y)
+    np.testing.assert_allclose(est.coef_, lasso.coef_, rtol=1e-5)
+
+
+def test_slope():
+    # compare solutions with `pyslope`: https://github.com/jolars/pyslope
+    try:
+        from slope.solvers import pgd_slope  # noqa
+        from slope.utils import lambda_sequence  # noqa
+    except ImportError:
+        pytest.xfail(
+            "This test requires slope to run.\n"
+            "https://github.com/jolars/pyslope")
+
+    q = 0.1
+    alphas = lambda_sequence(
+        X, y, fit_intercept=False, reg=alpha / alpha_max, q=q)
+    ours = GeneralizedLinearEstimator(
+        penalty=SLOPE(alphas),
+        solver=FISTA(max_iter=1000, tol=tol, opt_strategy="fixpoint"),
+    ).fit(X, y)
+    pyslope_out = pgd_slope(
+        X, y, alphas, fit_intercept=False, max_it=1000, gap_tol=tol)
+    np.testing.assert_allclose(ours.coef_, pyslope_out["beta"], rtol=1e-5)
+
+
 if __name__ == "__main__":
     pass

skglm/utils.py (+47)

@@ -559,5 +559,52 @@ def _XT_dot_vec(X_data, X_indptr, X_indices, vec):
     return result


+@njit
+def prox_SLOPE(z, alphas):
+    """Fast computation of the proximal operator of SLOPE.
+
+    Extracted from:
+    https://github.com/agisga/grpSLOPE/blob/master/src/proxSortedL1.c
+
+    Parameters
+    ----------
+    z : array, shape (n_features,)
+        Non-negative coefficient vector sorted in non-increasing order.
+
+    alphas : array, shape (n_features,)
+        Regularization hyperparameters sorted in non-increasing order.
+    """
+    n_features = z.shape[0]
+    x = np.empty(n_features)
+
+    k = 0
+    idx_i = np.empty((n_features,), dtype=np.int64)
+    idx_j = np.empty((n_features,), dtype=np.int64)
+    s = np.empty((n_features,), dtype=np.float64)
+    w = np.empty((n_features,), dtype=np.float64)
+
+    for i in range(n_features):
+        idx_i[k] = i
+        idx_j[k] = i
+        s[k] = z[i] - alphas[i]
+        w[k] = s[k]
+
+        # merge with previous blocks while monotonicity is violated (PAVA)
+        while k > 0 and w[k - 1] <= w[k]:
+            k -= 1
+            idx_j[k] = i
+            s[k] += s[k + 1]
+            w[k] = s[k] / (i - idx_i[k] + 1)
+
+        k += 1
+
+    # clip each block's value at zero and write it back
+    for j in range(k):
+        d = w[j]
+        d = 0 if d < 0 else d
+        for i in range(idx_i[j], idx_j[j] + 1):
+            x[i] = d
+
+    return x
+
+
 if __name__ == '__main__':
     pass
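
The stack-based routine above is the pool-adjacent-violators (PAVA) scheme from the SLOPE paper cited in the penalty's docstring: for non-negative, non-increasingly sorted input, the SLOPE prox equals the non-increasing isotonic regression of z - alphas, clipped at zero. A minimal cross-check under that interpretation, assuming scikit-learn is available:

    import numpy as np
    from sklearn.isotonic import isotonic_regression
    from skglm.utils import prox_SLOPE

    rng = np.random.default_rng(0)
    z = np.sort(rng.uniform(0., 5., size=30))[::-1].copy()       # non-negative, non-increasing
    alphas = np.sort(rng.uniform(0., 2., size=30))[::-1].copy()  # non-increasing

    # reference: non-increasing isotonic regression of z - alphas, clipped at 0
    ref = np.maximum(isotonic_regression(z - alphas, increasing=False), 0.)
    np.testing.assert_allclose(prox_SLOPE(z, alphas), ref)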
