
Commit 07c49bb

PABannier, Badr-MOUFAD, QB3, and mathurinm authored and committed
ENH - check datafit + penalty compatibility with solver (scikit-learn-contrib#137)
Co-authored-by: Badr-MOUFAD <[email protected]>
Co-authored-by: Badr MOUFAD <[email protected]>
Co-authored-by: Quentin Bertrand <[email protected]>
Co-authored-by: mathurinm <[email protected]>
Co-authored-by: mathurinm <[email protected]>
1 parent 3260ebd commit 07c49bb

19 files changed: +408, -84 lines

doc/changes/0.4.rst (+1)

@@ -4,6 +4,7 @@ Version 0.4 (in progress)
 -------------------------
 - Add :ref:`GroupLasso Estimator <skglm.GroupLasso>` (PR: :gh:`228`)
 - Add support and tutorial for positive coefficients to :ref:`Group Lasso Penalty <skglm.penalties.WeightedGroupL2>` (PR: :gh:`221`)
+- Check compatibility with datafit and penalty in solver (PR :gh:`137`)
 - Add support to weight samples in the quadratic datafit :ref:`Weighted Quadratic Datafit <skglm.datafit.WeightedQuadratic>` (PR: :gh:`258`)

skglm/__init__.py (+1, -1)

@@ -1,4 +1,4 @@
-__version__ = '0.3.2dev'
+__version__ = '0.4dev'

 from skglm.estimators import (  # noqa F401
     Lasso, WeightedLasso, ElasticNet, MCPRegression, MultiTaskLasso, LinearSVC,

skglm/datafits/single_task.py (-3)

@@ -661,9 +661,6 @@ def value(self, y, w, Xw):
     def gradient_scalar(self, X, y, w, Xw, j):
         return X[:, j] @ (1 - y * np.exp(-Xw)) / len(y)

-    def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
-        pass
-
     def intercept_update_step(self, y, Xw):
         return np.sum(self.raw_grad(y, Xw))
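This removal is enabled by the upfront validation introduced in this PR: rather than shipping a silent `pass` stub, a datafit simply omits `gradient_scalar_sparse`, and any solver that requires sparse support now reports the missing method before solving starts.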

skglm/experimental/pdcd_ws.py (+11, -26)

@@ -5,11 +5,12 @@
 from scipy.sparse import issparse

 from numba import njit
-from skglm.utils.jit_compilation import compiled_clone
+from skglm.solvers import BaseSolver
+
 from sklearn.exceptions import ConvergenceWarning


-class PDCD_WS:
+class PDCD_WS(BaseSolver):
     r"""Primal-Dual Coordinate Descent solver with working sets.

     It solves

@@ -78,6 +79,9 @@ class PDCD_WS:
         https://arxiv.org/abs/2204.07826
     """

+    _datafit_required_attr = ('prox_conjugate',)
+    _penalty_required_attr = ("prox_1d",)
+
     def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None,
                  p0=100, tol=1e-6, verbose=False):
         self.max_iter = max_iter

@@ -87,11 +91,7 @@ def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None,
         self.tol = tol
         self.verbose = verbose

-    def solve(self, X, y, datafit_, penalty_, w_init=None, Xw_init=None):
-        if issparse(X):
-            raise ValueError("Sparse matrices are not yet support in PDCD_WS solver.")
-
-        datafit, penalty = PDCD_WS._validate_init(datafit_, penalty_)
+    def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         n_samples, n_features = X.shape

         # init steps

@@ -196,27 +196,12 @@ def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty,
             if stop_crit_in <= tol_in:
                 break

-    @staticmethod
-    def _validate_init(datafit_, penalty_):
-        # validate datafit
-        missing_attrs = []
-        for attr in ('prox_conjugate', 'subdiff_distance'):
-            if not hasattr(datafit_, attr):
-                missing_attrs.append(f"`{attr}`")
-
-        if len(missing_attrs):
-            raise AttributeError(
-                "Datafit is not compatible with PDCD_WS solver.\n"
-                "Datafit must implement `prox_conjugate` and `subdiff_distance`.\n"
-                f"Missing {' and '.join(missing_attrs)}."
+    def custom_checks(self, X, y, datafit, penalty):
+        if issparse(X):
+            raise ValueError(
+                "Sparse matrices are not yet supported in `PDCD_WS` solver."
             )
-
-        # jit compile classes
-        compiled_datafit = compiled_clone(datafit_)
-        compiled_penalty = compiled_clone(penalty_)
-
-        return compiled_datafit, compiled_penalty


 @njit
 def _scores_primal(X, w, z, penalty, primal_steps, ws):
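Two consequences of this refactor are worth noting. First, `PDCD_WS` no longer calls `compiled_clone` internally, so callers must jit-compile the datafit and penalty themselves (the updated tests below do exactly this). Second, the sparse-input guard moved into `custom_checks`, so it fires before `_solve` does any work. A minimal sketch of the new failure mode, on synthetic data (illustrative only)::

    import numpy as np
    from scipy import sparse

    from skglm.penalties import L1
    from skglm.experimental.pdcd_ws import PDCD_WS
    from skglm.experimental.quantile_regression import Pinball
    from skglm.utils.jit_compilation import compiled_clone

    X_sparse = sparse.random(50, 10, density=0.5, format="csc", random_state=0)
    y = np.random.randn(50)

    # callers are now responsible for compiling the datafit and penalty
    datafit = compiled_clone(Pinball(0.5))
    penalty = compiled_clone(L1(alpha=0.1))

    try:
        PDCD_WS().solve(X_sparse, y, datafit, penalty)
    except ValueError as e:
        print(e)  # Sparse matrices are not yet supported in `PDCD_WS` solver.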

skglm/experimental/tests/test_quantile_regression.py (+5, -1)

@@ -5,6 +5,7 @@
 from skglm.penalties import L1
 from skglm.experimental.pdcd_ws import PDCD_WS
 from skglm.experimental.quantile_regression import Pinball
+from skglm.utils.jit_compilation import compiled_clone

 from skglm.utils.data import make_correlated_data
 from sklearn.linear_model import QuantileRegressor

@@ -21,9 +22,12 @@ def test_PDCD_WS(quantile_level):
     alpha_max = norm(X.T @ (np.sign(y)/2 + (quantile_level - 0.5)), ord=np.inf)
     alpha = alpha_max / 5

+    datafit = compiled_clone(Pinball(quantile_level))
+    penalty = compiled_clone(L1(alpha))
+
     w = PDCD_WS(
         dual_init=np.sign(y)/2 + (quantile_level - 0.5)
-    ).solve(X, y, Pinball(quantile_level), L1(alpha))[0]
+    ).solve(X, y, datafit, penalty)[0]

     clf = QuantileRegressor(
         quantile=quantile_level,

skglm/experimental/tests/test_sqrt_lasso.py (+5, -1)

@@ -7,6 +7,7 @@
 from skglm.experimental.sqrt_lasso import (SqrtLasso, SqrtQuadratic,
                                            _chambolle_pock_sqrt)
 from skglm.experimental.pdcd_ws import PDCD_WS
+from skglm.utils.jit_compilation import compiled_clone


 def test_alpha_max():

@@ -72,7 +73,10 @@ def test_PDCD_WS(with_dual_init):

     dual_init = y / norm(y) if with_dual_init else None

-    w = PDCD_WS(dual_init=dual_init).solve(X, y, SqrtQuadratic(), L1(alpha))[0]
+    datafit = compiled_clone(SqrtQuadratic())
+    penalty = compiled_clone(L1(alpha))
+
+    w = PDCD_WS(dual_init=dual_init).solve(X, y, datafit, penalty)[0]
     clf = SqrtLasso(alpha=alpha, fit_intercept=False, tol=1e-12).fit(X, y)
     np.testing.assert_allclose(clf.coef_, w, atol=1e-6)

skglm/penalties/block_separable.py (-11)

@@ -458,17 +458,6 @@ def prox_1group(self, value, stepsize, g):
         res = ST_vec(value, self.alpha * stepsize * self.weights_features[g])
         return BST(res, self.alpha * stepsize * self.weights_groups[g])

-    def subdiff_distance(self, w, grad_ws, ws):
-        """Compute distance to the subdifferential at ``w`` of negative gradient.
-
-        Refer to :ref:`subdiff_positive_group_lasso` for details of the derivation.
-
-        Note:
-        ----
-        ``grad_ws`` is a stacked array of gradients ``[grad_ws_1, grad_ws_2, ...]``.
-        """
-        raise NotImplementedError("Too hard for now")
-
     def is_penalized(self, n_groups):
         return np.ones(n_groups, dtype=np.bool_)

skglm/penalties/non_separable.py (-10)

@@ -48,13 +48,3 @@ def prox_vec(self, x, stepsize):
         prox[sorted_indices] = prox_SLOPE(abs_x[sorted_indices], alphas * stepsize)

         return np.sign(x) * prox
-
-    def prox_1d(self, value, stepsize, j):
-        raise ValueError(
-            "No coordinate-wise proximal operator for SLOPE. Use `prox_vec` instead."
-        )
-
-    def subdiff_distance(self, w, grad, ws):
-        return ValueError(
-            "No subdifferential distance for SLOPE. Use `opt_strategy='fixpoint'`"
-        )
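With the stubs gone, the declared requirements do the work: FISTA's `_penalty_required_attr = (("prox_1d", "prox_vec"),)` accepts a penalty implementing either method, and `subdiff_distance` is only demanded when `opt_strategy='subdiff'` is requested. A hedged sketch of pairing SLOPE with FISTA under that convention (synthetic data and alpha values are illustrative)::

    import numpy as np
    from skglm.datafits import Quadratic
    from skglm.penalties import SLOPE
    from skglm.solvers import FISTA
    from skglm.utils.jit_compilation import compiled_clone

    X, y = np.random.randn(100, 20), np.random.randn(100)

    datafit = compiled_clone(Quadratic())
    datafit.initialize(X, y)  # precompute quantities used by the gradient

    # decreasing per-coordinate regularization levels (illustrative values)
    penalty = compiled_clone(SLOPE(0.1 * np.linspace(1, 0.5, X.shape[1])))

    # SLOPE provides `prox_vec` but no `subdiff_distance`, hence the
    # fixpoint stopping criterion instead of the default 'subdiff'
    w, p_objs, stop_crit = FISTA(opt_strategy="fixpoint").solve(X, y, datafit, penalty)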

skglm/solvers/anderson_cd.py (+20, -1)

@@ -7,6 +7,7 @@
 )
 from skglm.solvers.base import BaseSolver
 from skglm.utils.anderson import AndersonAcceleration
+from skglm.utils.validation import check_attrs


 class AndersonCD(BaseSolver):

@@ -47,6 +48,9 @@ class AndersonCD(BaseSolver):
         code: https://github.com/mathurinm/andersoncd
     """

+    _datafit_required_attr = ("get_lipschitz", "gradient_scalar")
+    _penalty_required_attr = ("prox_1d",)
+
     def __init__(self, max_iter=50, max_epochs=50_000, p0=10,
                  tol=1e-4, ws_strategy="subdiff", fit_intercept=True,
                  warm_start=False, verbose=0):

@@ -59,7 +63,7 @@ def __init__(self, max_iter=50, max_epochs=50_000, p0=10,
         self.warm_start = warm_start
         self.verbose = verbose

-    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+    def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         if self.ws_strategy not in ("subdiff", "fixpoint"):
             raise ValueError(
                 'Unsupported value for self.ws_strategy:', self.ws_strategy)

@@ -269,6 +273,21 @@ def path(self, X, y, datafit, penalty, alphas=None, w_init=None,
         results += (n_iters,)
         return results

+    def custom_checks(self, X, y, datafit, penalty):
+        # check datafit support sparse data
+        check_attrs(
+            datafit, solver=self,
+            required_attr=self._datafit_required_attr,
+            support_sparse=sparse.issparse(X)
+        )
+
+        # ws strategy
+        if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"):
+            raise AttributeError(
+                "Penalty must implement `subdiff_distance` "
+                "to use ws_strategy='subdiff' in solver AndersonCD."
+            )
+

 @njit
 def _cd_epoch(X, y, w, Xw, lc, datafit, penalty, ws):
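From the caller's side, nothing changes on the happy path: `solve` (inherited from `BaseSolver`) runs `custom_checks` plus the attribute validation, then dispatches to `_solve`. A hedged usage sketch on synthetic data; the `run_checks` keyword comes from the new `BaseSolver.solve` below::

    import numpy as np
    from skglm.datafits import Quadratic
    from skglm.penalties import L1
    from skglm.solvers import AndersonCD
    from skglm.utils.jit_compilation import compiled_clone

    X, y = np.random.randn(100, 40), np.random.randn(100)

    datafit = compiled_clone(Quadratic())
    penalty = compiled_clone(L1(alpha=0.1))

    # validates datafit/penalty compatibility, then solves
    w = AndersonCD(fit_intercept=False).solve(X, y, datafit, penalty)[0]

    # checks can be skipped on repeated calls with an already-validated problem
    w = AndersonCD(fit_intercept=False).solve(
        X, y, datafit, penalty, run_checks=False
    )[0]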

skglm/solvers/base.py (+79, -4)

@@ -1,11 +1,38 @@
-from abc import abstractmethod
+from abc import abstractmethod, ABC
+from skglm.utils.validation import check_attrs


-class BaseSolver():
-    """Base class for solvers."""
+class BaseSolver(ABC):
+    """Base class for solvers.
+
+    Attributes
+    ----------
+    _datafit_required_attr : list
+        List of attributes that must be implemented in Datafit.
+
+    _penalty_required_attr : list
+        List of attributes that must be implemented in Penalty.
+
+    Notes
+    -----
+    For required attributes, if an attribute is given as a list of attributes
+    it means at least one of them should be implemented.
+    For instance, if
+
+        _datafit_required_attr = (
+            "get_global_lipschitz",
+            ("gradient", "gradient_scalar")
+        )
+
+    it means the datafit must implement the methods ``get_global_lipschitz``
+    and (``gradient`` or ``gradient_scalar``).
+    """
+
+    _datafit_required_attr: list
+    _penalty_required_attr: list

     @abstractmethod
-    def solve(self, X, y, datafit, penalty, w_init, Xw_init):
+    def _solve(self, X, y, datafit, penalty, w_init, Xw_init):
         """Solve an optimization problem.

         Parameters

@@ -39,3 +66,51 @@ def solve(self, X, y, datafit, penalty, w_init, Xw_init):
         stop_crit : float
             Value of stopping criterion at convergence.
         """
+
+    def custom_checks(self, X, y, datafit, penalty):
+        """Ensure the solver is suited for the `datafit` + `penalty` problem.
+
+        This method includes extra checks to perform
+        aside from checking attributes compatibility.
+
+        Parameters
+        ----------
+        X : array, shape (n_samples, n_features)
+            Training data.
+
+        y : array, shape (n_samples,)
+            Target values.
+
+        datafit : instance of BaseDatafit
+            Datafit.
+
+        penalty : instance of BasePenalty
+            Penalty.
+        """
+        pass
+
+    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None,
+              *, run_checks=True):
+        """Solve the optimization problem after validating its compatibility.
+
+        A proxy of ``_solve`` method that implicitly ensures the compatibility
+        of ``datafit`` and ``penalty`` with the solver.
+
+        Examples
+        --------
+        >>> ...
+        >>> coefs, obj_out, stop_crit = solver.solve(X, y, datafit, penalty)
+        """
+        if run_checks:
+            self._validate(X, y, datafit, penalty)
+
+        return self._solve(X, y, datafit, penalty, w_init, Xw_init)
+
+    def _validate(self, X, y, datafit, penalty):
+        # execute: `custom_checks` then check attributes
+        self.custom_checks(X, y, datafit, penalty)
+
+        # do not check for sparse support here, make the check at the solver level
+        # some solvers like ProxNewton don't require methods for sparse support
+        check_attrs(datafit, self, self._datafit_required_attr)
+        check_attrs(penalty, self, self._penalty_required_attr)
skglm/solvers/fista.py (+26, -14)

@@ -3,6 +3,7 @@
 from skglm.solvers.base import BaseSolver
 from skglm.solvers.common import construct_grad, construct_grad_sparse
 from skglm.utils.prox_funcs import _prox_vec
+from skglm.utils.validation import check_attrs


 class FISTA(BaseSolver):

@@ -27,6 +28,9 @@ class FISTA(BaseSolver):
         https://epubs.siam.org/doi/10.1137/080716542
     """

+    _datafit_required_attr = ("get_global_lipschitz", ("gradient", "gradient_scalar"))
+    _penalty_required_attr = (("prox_1d", "prox_vec"),)
+
     def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0):
         self.max_iter = max_iter
         self.tol = tol

@@ -35,7 +39,7 @@ def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0):
         self.fit_intercept = False  # needed to be passed to GeneralizedLinearEstimator
         self.warm_start = False

-    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+    def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         p_objs_out = []
         n_samples, n_features = X.shape
         all_features = np.arange(n_features)

@@ -46,19 +50,12 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         z = w_init.copy() if w_init is not None else np.zeros(n_features)
         Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)

-        try:
-            if X_is_sparse:
-                lipschitz = datafit.get_global_lipschitz_sparse(
-                    X.data, X.indptr, X.indices, y
-                )
-            else:
-                lipschitz = datafit.get_global_lipschitz(X, y)
-        except AttributeError as e:
-            sparse_suffix = '_sparse' if X_is_sparse else ''
-
-            raise Exception(
-                "Datafit is not compatible with FISTA solver.\n Datafit must "
-                f"implement `get_global_lipschitz{sparse_suffix}` method") from e
+        if X_is_sparse:
+            lipschitz = datafit.get_global_lipschitz_sparse(
+                X.data, X.indptr, X.indices, y
+            )
+        else:
+            lipschitz = datafit.get_global_lipschitz(X, y)

         for n_iter in range(self.max_iter):
             t_old = t_new

@@ -111,3 +108,18 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
                     print(f"Stopping criterion max violation: {stop_crit:.2e}")
                 break
         return w, np.array(p_objs_out), stop_crit
+
+    def custom_checks(self, X, y, datafit, penalty):
+        # check datafit support sparse data
+        check_attrs(
+            datafit, solver=self,
+            required_attr=self._datafit_required_attr,
+            support_sparse=issparse(X)
+        )
+
+        # optimality check
+        if self.opt_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"):
+            raise AttributeError(
+                "Penalty must implement `subdiff_distance` "
+                "to use `opt_strategy='subdiff'` in Fista solver."
+            )
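Previously, a datafit missing the Lipschitz method only failed mid-solve, through the try/except removed above; now the problem is reported upfront by `custom_checks`. A hedged sketch with a deliberately incomplete datafit (`ToyDatafit` is a hypothetical class for illustration, not part of skglm)::

    import numpy as np
    from scipy import sparse
    from skglm.penalties import L1
    from skglm.solvers import FISTA
    from skglm.utils.jit_compilation import compiled_clone


    class ToyDatafit:
        # dense API only: no `*_sparse` variants are implemented
        def get_global_lipschitz(self, X, y):
            return np.linalg.norm(X, ord=2) ** 2 / len(y)

        def gradient(self, X, y, Xw):
            return X.T @ (Xw - y) / len(y)


    X = sparse.random(40, 8, density=0.5, format="csc", random_state=0)
    y = np.random.randn(40)

    try:
        # with sparse X, `check_attrs` demands the `*_sparse` variants of the
        # required datafit methods, which ToyDatafit does not provide
        FISTA().solve(X, y, ToyDatafit(), compiled_clone(L1(alpha=0.1)))
    except AttributeError as e:
        print(e)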
