Skip to content

Commit d6ab8c2

Browse files
authored
FEAT implement sparse group lasso penalty and ws_strategy="fixpoint" for BCD (#267)
1 parent 776afdf commit d6ab8c2

9 files changed

+315
-14
lines changed

examples/plot_group_logistic_regression.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,4 @@
4141

4242
# %%
4343
# Fit check that groups are either all 0 or all non zero
44-
print(clf.coef_.reshape(-1, grp_size))
44+
print(clf.coef_.reshape(-1, grp_size))

examples/plot_logreg_various_penalties.py

-3
Original file line numberDiff line numberDiff line change
@@ -55,23 +55,20 @@
5555
clf_enet.coef_[clf_enet.coef_ != 0],
5656
markerfmt="x",
5757
label="Elastic net coefficients",
58-
use_line_collection=True,
5958
)
6059
plt.setp([m, s], color="#2ca02c")
6160
m, s, _ = plt.stem(
6261
np.where(clf_mcp.coef_.ravel())[0],
6362
clf_mcp.coef_[clf_mcp.coef_ != 0],
6463
markerfmt="x",
6564
label="MCP coefficients",
66-
use_line_collection=True,
6765
)
6866
plt.setp([m, s], color="#ff7f0e")
6967
plt.stem(
7068
np.where(w_star)[0],
7169
w_star[w_star != 0],
7270
label="true coefficients",
7371
markerfmt="bx",
74-
use_line_collection=True,
7572
)
7673

7774
plt.legend(loc="best")

examples/plot_sparse_group_lasso.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""
2+
=================================
3+
Fast Sparse Group Lasso in python
4+
=================================
5+
Scikit-learn is missing a Sparse Group Lasso regression estimator. We show how to
6+
implement one with ``skglm``.
7+
"""
8+
9+
# Author: Mathurin Massias
10+
11+
# %%
12+
import numpy as np
13+
import matplotlib.pyplot as plt
14+
15+
from skglm.solvers import GroupBCD
16+
from skglm.datafits import QuadraticGroup
17+
from skglm import GeneralizedLinearEstimator
18+
from skglm.penalties import WeightedL1GroupL2
19+
from skglm.utils.data import make_correlated_data, grp_converter
20+
21+
n_features = 30
22+
X, y, _ = make_correlated_data(
23+
n_samples=10, n_features=30, random_state=0)
24+
25+
26+
# %%
27+
# Model creation: combination of penalty, datafit and solver.
28+
#
29+
# penalty:
30+
grp_size = 10 # take groups of 10 consecutive features
31+
n_groups = n_features // grp_size
32+
grp_indices, grp_ptr = grp_converter(grp_size, n_features)
33+
n_groups = len(grp_ptr) - 1
34+
weights_g = np.ones(n_groups, dtype=np.float64)
35+
weights_f = 0.5 * np.ones(n_features)
36+
penalty = WeightedL1GroupL2(
37+
alpha=0.5, weights_groups=weights_g,
38+
weights_features=weights_f, grp_indices=grp_indices, grp_ptr=grp_ptr)
39+
40+
# %% Datafit and solver
41+
datafit = QuadraticGroup(grp_ptr, grp_indices)
42+
solver = GroupBCD(ws_strategy="fixpoint", verbose=1, fit_intercept=False, tol=1e-10)
43+
44+
model = GeneralizedLinearEstimator(datafit, penalty, solver=solver)
45+
46+
# %%
47+
# Train the model
48+
clf = GeneralizedLinearEstimator(datafit, penalty, solver)
49+
clf.fit(X, y)
50+
51+
# %%
52+
# Some groups are fully 0, and inside non zero groups,
53+
# some values are 0 too
54+
plt.imshow(clf.coef_.reshape(-1, grp_size) != 0, cmap='Greys')
55+
plt.title("Non zero values (in black) in model coefficients")
56+
plt.ylabel('Group index')
57+
plt.xlabel('Feature index inside group')
58+
plt.xticks(np.arange(grp_size))
59+
plt.yticks(np.arange(n_groups));
60+
61+
# %%

skglm/penalties/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
WeightedL1, IndicatorBox, PositiveConstraint, LogSumPenalty
55
)
66
from .block_separable import (
7-
L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2
7+
L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2, WeightedL1GroupL2
88
)
99

1010
from .non_separable import SLOPE
@@ -14,5 +14,5 @@
1414
BasePenalty,
1515
L1_plus_L2, L0_5, L1, L2, L2_3, MCPenalty, WeightedMCPenalty, SCAD, WeightedL1,
1616
IndicatorBox, PositiveConstraint, L2_05, L2_1, BlockMCPenalty, BlockSCAD,
17-
WeightedGroupL2, SLOPE, LogSumPenalty
17+
WeightedGroupL2, WeightedL1GroupL2, SLOPE, LogSumPenalty
1818
]

skglm/penalties/block_separable.py

+107-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from skglm.penalties.base import BasePenalty
88
from skglm.utils.prox_funcs import (
9-
BST, prox_block_2_05, prox_SCAD, value_SCAD, prox_MCP, value_MCP)
9+
BST, ST_vec, prox_block_2_05, prox_SCAD, value_SCAD, prox_MCP, value_MCP)
1010

1111

1212
class L2_1(BasePenalty):
@@ -382,3 +382,109 @@ def generalized_support(self, w):
382382
gsupp[g] = True
383383

384384
return gsupp
385+
386+
387+
class WeightedL1GroupL2(BasePenalty):
388+
r"""Weighted Group L2 penalty, aka sparse group Lasso.
389+
390+
The penalty reads
391+
392+
.. math::
393+
sum_{g=1}^{n_"groups"} "weights"^1_g ||w_{[g]}|| +
394+
sum_{j=1}^{n_"features"} "weights"^2_j ||w_{j}||
395+
396+
with :math:`w_{[g]}` being the coefficients of the g-th group and
397+
398+
Attributes
399+
----------
400+
alpha : float
401+
The regularization parameter.
402+
403+
weights_groups : array, shape (n_groups,)
404+
The penalization weights of the groups.
405+
406+
weights_features : array, shape (n_features,)
407+
The penalization weights of the features.
408+
409+
grp_indices : array, shape (n_features,)
410+
The group indices stacked contiguously
411+
([grp1_indices, grp2_indices, ...]).
412+
413+
grp_ptr : array, shape (n_groups + 1,)
414+
The group pointers such that two consecutive elements delimit
415+
the indices of a group in ``grp_indices``.
416+
417+
"""
418+
419+
def __init__(
420+
self, alpha, weights_groups, weights_features, grp_ptr, grp_indices):
421+
self.alpha = alpha
422+
self.grp_ptr, self.grp_indices = grp_ptr, grp_indices
423+
self.weights_groups = weights_groups
424+
self.weights_features = weights_features
425+
426+
def get_spec(self):
427+
spec = (
428+
('alpha', float64),
429+
('weights_groups', float64[:]),
430+
('weights_features', float64[:]),
431+
('grp_ptr', int32[:]),
432+
('grp_indices', int32[:]),
433+
)
434+
return spec
435+
436+
def params_to_dict(self):
437+
return dict(alpha=self.alpha, weights_features=self.weights_features,
438+
weights_groups=self.weights_groups, grp_ptr=self.grp_ptr,
439+
grp_indices=self.grp_indices)
440+
441+
def value(self, w):
442+
"""Value of penalty at vector ``w``."""
443+
grp_ptr, grp_indices = self.grp_ptr, self.grp_indices
444+
n_grp = len(grp_ptr) - 1
445+
446+
sum_penalty = 0.
447+
for g in range(n_grp):
448+
grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
449+
w_g = w[grp_g_indices]
450+
451+
sum_penalty += self.weights_groups[g] * norm(w_g)
452+
sum_penalty += np.sum(self.weights_features * np.abs(w))
453+
454+
return self.alpha * sum_penalty
455+
456+
def prox_1group(self, value, stepsize, g):
457+
"""Compute the proximal operator of group ``g``."""
458+
res = ST_vec(value, self.alpha * stepsize * self.weights_features[g])
459+
return BST(res, self.alpha * stepsize * self.weights_groups[g])
460+
461+
def subdiff_distance(self, w, grad_ws, ws):
462+
"""Compute distance to the subdifferential at ``w`` of negative gradient.
463+
464+
Refer to :ref:`subdiff_positive_group_lasso` for details of the derivation.
465+
466+
Note:
467+
----
468+
``grad_ws`` is a stacked array of gradients ``[grad_ws_1, grad_ws_2, ...]``.
469+
"""
470+
raise NotImplementedError("Too hard for now")
471+
472+
def is_penalized(self, n_groups):
473+
return np.ones(n_groups, dtype=np.bool_)
474+
475+
def generalized_support(self, w):
476+
grp_indices, grp_ptr = self.grp_indices, self.grp_ptr
477+
n_groups = len(grp_ptr) - 1
478+
is_penalized = self.is_penalized(n_groups)
479+
480+
gsupp = np.zeros(n_groups, dtype=np.bool_)
481+
for g in range(n_groups):
482+
if not is_penalized[g]:
483+
gsupp[g] = True
484+
continue
485+
486+
grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
487+
if np.any(w[grp_g_indices]):
488+
gsupp[g] = True
489+
490+
return gsupp

skglm/solvers/anderson_cd.py

+2
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
103103
# The intercept is not taken into account in the optimality conditions since
104104
# the derivative w.r.t. to the intercept may be very large. It is not likely
105105
# to change significantly the optimality conditions.
106+
# TODO: MM I don't understand the comment above: the intercept is
107+
# taken into account intercept_opt 6 lines below
106108
if self.ws_strategy == "subdiff":
107109
opt = penalty.subdiff_distance(w[:n_features], grad, all_feats)
108110
elif self.ws_strategy == "fixpoint":

skglm/solvers/common.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import numpy as np
22
from numba import njit
3+
from numpy.linalg import norm
34

45

56
@njit
67
def dist_fix_point_cd(w, grad_ws, lipschitz_ws, datafit, penalty, ws):
7-
"""Compute the violation of the fixed point iterate scheme.
8+
"""Compute the violation of the fixed point iterate scheme for CD.
89
910
Parameters
1011
----------
@@ -44,6 +45,60 @@ def dist_fix_point_cd(w, grad_ws, lipschitz_ws, datafit, penalty, ws):
4445
return dist
4546

4647

48+
@njit
49+
def dist_fix_point_bcd(
50+
w, grad_ws, lipschitz_ws, datafit, penalty, ws):
51+
"""Compute the violation of the fixed point iterate scheme for BCD.
52+
53+
Parameters
54+
----------
55+
w : array, shape (n_features,)
56+
Coefficient vector.
57+
58+
grad_ws : array, shape (ws_size,)
59+
Gradient restricted to the working set.
60+
61+
lipschitz_ws : array, shape (len(ws),)
62+
Coordinatewise gradient Lipschitz constants, restricted to working set.
63+
64+
datafit: instance of BaseDatafit
65+
Datafit.
66+
67+
penalty: instance of BasePenalty
68+
Penalty.
69+
70+
ws : array, shape (len(ws),)
71+
The working set.
72+
73+
Returns
74+
-------
75+
dist : array, shape (n_groups,)
76+
Violation score for every group.
77+
78+
Note:
79+
----
80+
``grad_ws`` is a stacked array of gradients ``[grad_ws_1, grad_ws_2, ...]``.
81+
"""
82+
n_groups = len(penalty.grp_ptr) - 1
83+
dist = np.zeros(n_groups, dtype=w.dtype)
84+
85+
grad_ptr = 0
86+
for idx, g in enumerate(ws):
87+
if lipschitz_ws[idx] == 0.:
88+
continue
89+
grp_g_indices = penalty.grp_indices[penalty.grp_ptr[g]: penalty.grp_ptr[g+1]]
90+
91+
grad_g = grad_ws[grad_ptr: grad_ptr + len(grp_g_indices)]
92+
grad_ptr += len(grp_g_indices)
93+
94+
step_g = 1 / lipschitz_ws[idx]
95+
w_g = w[grp_g_indices]
96+
dist[idx] = norm(
97+
w_g - penalty.prox_1group(w_g - grad_g * step_g, step_g, g)
98+
)
99+
return dist
100+
101+
47102
@njit
48103
def construct_grad(X, y, w, Xw, datafit, ws):
49104
"""Compute the gradient of the datafit restricted to the working set.

skglm/solvers/group_bcd.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from skglm.solvers.base import BaseSolver
66
from skglm.utils.anderson import AndersonAcceleration
77
from skglm.utils.validation import check_group_compatible
8+
from skglm.solvers.common import dist_fix_point_bcd
89

910

1011
class GroupBCD(BaseSolver):
@@ -36,17 +37,22 @@ class GroupBCD(BaseSolver):
3637
Amount of verbosity. 0/False is silent.
3738
"""
3839

39-
def __init__(self, max_iter=1000, max_epochs=100, p0=10, tol=1e-4,
40-
fit_intercept=False, warm_start=False, verbose=0):
40+
def __init__(
41+
self, max_iter=1000, max_epochs=100, p0=10, tol=1e-4, fit_intercept=False,
42+
warm_start=False, ws_strategy="subdiff", verbose=0):
4143
self.max_iter = max_iter
4244
self.max_epochs = max_epochs
4345
self.p0 = p0
4446
self.tol = tol
4547
self.fit_intercept = fit_intercept
4648
self.warm_start = warm_start
49+
self.ws_strategy = ws_strategy
4750
self.verbose = verbose
4851

4952
def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
53+
if self.ws_strategy not in ("subdiff", "fixpoint"):
54+
raise ValueError(
55+
'Unsupported value for self.ws_strategy:', self.ws_strategy)
5056
check_group_compatible(datafit)
5157
check_group_compatible(penalty)
5258

@@ -86,7 +92,14 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
8692
X.data, X.indptr, X.indices, y, w, Xw, datafit, all_groups)
8793
else:
8894
grad = _construct_grad(X, y, w, Xw, datafit, all_groups)
89-
opt = penalty.subdiff_distance(w, grad, all_groups)
95+
96+
if self.ws_strategy == "subdiff":
97+
# MM TODO: AndersonCD passes w[:n_features] here
98+
opt = penalty.subdiff_distance(w, grad, all_groups)
99+
elif self.ws_strategy == "fixpoint":
100+
opt = dist_fix_point_bcd(
101+
w[:n_features], grad, lipschitz, datafit, penalty, all_groups
102+
)
90103

91104
if self.fit_intercept:
92105
intercept_opt = np.abs(datafit.intercept_update_step(y, Xw))
@@ -144,8 +157,15 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
144157
else:
145158
grad_ws = _construct_grad(X, y, w, Xw, datafit, ws)
146159

147-
opt_in = penalty.subdiff_distance(w, grad_ws, ws)
148-
stop_crit_in = np.max(opt_in)
160+
if self.ws_strategy == "subdiff":
161+
# TODO MM: AndersonCD uses w[:n_features] here
162+
opt_ws = penalty.subdiff_distance(w, grad_ws, ws)
163+
elif self.ws_strategy == "fixpoint":
164+
opt_ws = dist_fix_point_bcd(
165+
w, grad_ws, lipschitz[ws], datafit, penalty, ws
166+
)
167+
168+
stop_crit_in = np.max(opt_ws)
149169

150170
if max(self.verbose - 1, 0):
151171
p_obj = datafit.value(y, w, Xw) + penalty.value(w)

0 commit comments

Comments
 (0)