Commit 5efeb8f

ENH - jit-compile datafits and penalties inside solver (#270)
Co-authored-by: mathurinm <[email protected]>
1 parent: 9c37cd7

22 files changed: +150, -138 lines
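
In short: `BaseSolver.solve` now JIT-compiles the datafit and penalty it receives, so callers pass plain instances instead of wrapping them in `compiled_clone`. A minimal sketch of the new calling convention (the data, `alpha`, and solver settings are illustrative, and it assumes `AndersonCD` initializes the datafit internally, as the solvers touched in this commit now do):

import numpy as np
from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.solvers import AndersonCD
from skglm.utils.data import make_correlated_data

# toy problem, for illustration only
X, y, _ = make_correlated_data(n_samples=100, n_features=40, random_state=0)
alpha = 0.01 * np.max(np.abs(X.T @ y)) / len(y)

# before this commit: solve(X, y, compiled_clone(Quadratic()), compiled_clone(L1(alpha)))
# after: plain instances, compiled inside solve()
w, p_objs, stop_crit = AndersonCD(fit_intercept=False).solve(X, y, Quadratic(), L1(alpha))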

examples/plot_sparse_recovery.py (+1 -2)

@@ -18,7 +18,6 @@
 from skglm.utils.data import make_correlated_data
 from skglm.solvers import AndersonCD
 from skglm.datafits import Quadratic
-from skglm.utils.jit_compilation import compiled_clone
 from skglm.penalties import L1, MCPenalty, L0_5, L2_3, SCAD

 cmap = plt.get_cmap('tab10')
@@ -74,7 +73,7 @@
 for idx, estimator in enumerate(penalties.keys()):
     print(f'Running {estimator}...')
     estimator_path = solver.path(
-        X, y, compiled_clone(datafit), compiled_clone(penalties[estimator]),
+        X, y, datafit, penalties[estimator],
         alphas=alphas)

     f1_temp = np.zeros(n_alphas)

examples/plot_survival_analysis.py (+5 -6)

@@ -15,6 +15,7 @@
 # Let's first generate synthetic data on which to run the Cox estimator,
 # using ``skglm`` data utils.
 #
+
 from skglm.utils.data import make_dummy_survival_data

 n_samples, n_features = 500, 100
@@ -59,18 +60,16 @@
 # Todo so, we need to combine a Cox datafit and a :math:`\ell_1` penalty
 # and solve the resulting problem using skglm Proximal Newton solver ``ProxNewton``.
 # We set the intensity of the :math:`\ell_1` regularization to ``alpha=1e-2``.
-from skglm.datafits import Cox
 from skglm.penalties import L1
+from skglm.datafits import Cox
 from skglm.solvers import ProxNewton

-from skglm.utils.jit_compilation import compiled_clone
-
 # regularization intensity
 alpha = 1e-2

 # skglm internals: init datafit and penalty
-datafit = compiled_clone(Cox())
-penalty = compiled_clone(L1(alpha))
+datafit = Cox()
+penalty = L1(alpha)

 datafit.initialize(X, y)

@@ -230,7 +229,7 @@
 # We only need to pass in ``use_efron=True`` to the ``Cox`` datafit.

 # ensure using Efron estimate
-datafit = compiled_clone(Cox(use_efron=True))
+datafit = Cox(use_efron=True)
 datafit.initialize(X, y)

 # solve the problem

skglm/estimators.py (+19 -28)

@@ -18,7 +18,6 @@
 from sklearn.utils._param_validation import Interval, StrOptions
 from sklearn.multiclass import OneVsRestClassifier, check_classification_targets

-from skglm.utils.jit_compilation import compiled_clone
 from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD
 from skglm.datafits import (Cox, Quadratic, Logistic, QuadraticSVC,
                             QuadraticMultiTask, QuadraticGroup,)
@@ -102,12 +101,10 @@ def _glm_fit(X, y, model, datafit, penalty, solver):

     n_samples, n_features = X_.shape

-    penalty_jit = compiled_clone(penalty)
-    datafit_jit = compiled_clone(datafit, to_float32=X.dtype == np.float32)
     if issparse(X):
-        datafit_jit.initialize_sparse(X_.data, X_.indptr, X_.indices, y)
+        datafit.initialize_sparse(X_.data, X_.indptr, X_.indices, y)
     else:
-        datafit_jit.initialize(X_, y)
+        datafit.initialize(X_, y)

     # if model.warm_start and hasattr(model, 'coef_') and model.coef_ is not None:
     if solver.warm_start and hasattr(model, 'coef_') and model.coef_ is not None:
@@ -136,7 +133,7 @@ def _glm_fit(X, y, model, datafit, penalty, solver):
                 "The size of the WeightedL1 penalty weights should be n_features, "
                 "expected %i, got %i." % (X_.shape[1], len(penalty.weights)))

-    coefs, p_obj, kkt = solver.solve(X_, y, datafit_jit, penalty_jit, w, Xw)
+    coefs, p_obj, kkt = solver.solve(X_, y, datafit, penalty, w, Xw)
     model.coef_, model.stop_crit_ = coefs[:n_features], kkt
     if y.ndim == 1:
         model.intercept_ = coefs[-1] if fit_intercept else 0.
@@ -440,8 +437,8 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
             The number of iterations along the path. If return_n_iter is set to
             ``True``.
         """
-        penalty = compiled_clone(L1(self.alpha, self.positive))
-        datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
+        penalty = L1(self.alpha, self.positive)
+        datafit = Quadratic()
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
@@ -581,8 +578,8 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
             raise ValueError("The number of weights must match the number of \
                 features. Got %s, expected %s." % (
                 len(weights), X.shape[1]))
-        penalty = compiled_clone(WeightedL1(self.alpha, weights, self.positive))
-        datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
+        penalty = WeightedL1(self.alpha, weights, self.positive)
+        datafit = Quadratic()
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
@@ -744,8 +741,8 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
             The number of iterations along the path. If return_n_iter is set to
             ``True``.
         """
-        penalty = compiled_clone(L1_plus_L2(self.alpha, self.l1_ratio, self.positive))
-        datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
+        penalty = L1_plus_L2(self.alpha, self.l1_ratio, self.positive)
+        datafit = Quadratic()
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
@@ -917,19 +914,17 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
             ``True``.
         """
         if self.weights is None:
-            penalty = compiled_clone(
-                MCPenalty(self.alpha, self.gamma, self.positive)
-            )
+            penalty = MCPenalty(self.alpha, self.gamma, self.positive)
         else:
             if X.shape[1] != len(self.weights):
                 raise ValueError(
                     "The number of weights must match the number of features. "
                     f"Got {len(self.weights)}, expected {X.shape[1]}."
                 )
-            penalty = compiled_clone(
-                WeightedMCPenalty(self.alpha, self.gamma, self.weights, self.positive)
-            )
-        datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
+            penalty = WeightedMCPenalty(
+                self.alpha, self.gamma, self.weights, self.positive)
+
+        datafit = Quadratic()
         solver = AndersonCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
@@ -1369,10 +1364,6 @@ def fit(self, X, y):
         else:
             penalty = L2(self.alpha)

-        # skglm internal: JIT compile classes
-        datafit = compiled_clone(datafit)
-        penalty = compiled_clone(penalty)
-
         # init solver
         if self.l1_ratio == 0.:
             solver = LBFGS(max_iter=self.max_iter, tol=self.tol, verbose=self.verbose)
@@ -1518,14 +1509,14 @@ def fit(self, X, Y):
         if not self.warm_start or not hasattr(self, "coef_"):
             self.coef_ = None

-        datafit_jit = compiled_clone(QuadraticMultiTask(), X.dtype == np.float32)
-        penalty_jit = compiled_clone(L2_1(self.alpha), X.dtype == np.float32)
+        datafit = QuadraticMultiTask()
+        penalty = L2_1(self.alpha)

         solver = MultiTaskBCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
             warm_start=self.warm_start, verbose=self.verbose)
-        W, obj_out, kkt = solver.solve(X, Y, datafit_jit, penalty_jit)
+        W, obj_out, kkt = solver.solve(X, Y, datafit, penalty)

         self.coef_ = W[:X.shape[1], :].T
         self.intercept_ = self.fit_intercept * W[-1, :]
@@ -1573,8 +1564,8 @@ def path(self, X, Y, alphas, coef_init=None, return_n_iter=False, **params):
             The number of iterations along the path. If return_n_iter is set to
             ``True``.
         """
-        datafit = compiled_clone(QuadraticMultiTask(), to_float32=X.dtype == np.float32)
-        penalty = compiled_clone(L2_1(self.alpha))
+        datafit = QuadraticMultiTask()
+        penalty = L2_1(self.alpha)
         solver = MultiTaskBCD(
             self.max_iter, self.max_epochs, self.p0, tol=self.tol,
             ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
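
Nothing changes at the estimator level: `_glm_fit` now hands the plain datafit and penalty straight to `solver.solve`, which compiles them. A hedged sketch of the unchanged public API (random data and hyperparameters are illustrative):

import numpy as np
from skglm import GeneralizedLinearEstimator
from skglm.datafits import Quadratic
from skglm.penalties import MCPenalty
from skglm.solvers import AndersonCD

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 30))
y = X @ rng.standard_normal(30)

# plain, uncompiled instances are now the expected inputs
est = GeneralizedLinearEstimator(
    datafit=Quadratic(),
    penalty=MCPenalty(alpha=0.1, gamma=3.0),
    solver=AndersonCD(),
).fit(X, y)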

skglm/experimental/reweighted.py (+2 -2)

@@ -69,9 +69,9 @@ def fit(self, X, y):
                 f"penalty {self.penalty.__class__.__name__}")

         n_features = X.shape[1]
-        _penalty = compiled_clone(WeightedL1(self.penalty.alpha, np.ones(n_features)))
-        self.datafit = compiled_clone(self.datafit)
+        # we need to compile this as it is not passed to solver.solve:
         self.penalty = compiled_clone(self.penalty)
+        _penalty = WeightedL1(self.penalty.alpha, np.ones(n_features))

         self.loss_history_ = []
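
Note the one exception to the refactor: `self.penalty` is still compiled by hand here because, as the new comment says, it is never passed to `solver.solve`; the reweighting loop uses it directly (presumably to update weights and track the loss history), so it would otherwise never go through the solver's internal compilation.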

skglm/experimental/sqrt_lasso.py (+2 -3)

@@ -6,7 +6,6 @@

 from skglm.penalties import L1
 from skglm.utils.prox_funcs import ST_vec, proj_L2ball, BST
-from skglm.utils.jit_compilation import compiled_clone
 from skglm.datafits.base import BaseDatafit
 from skglm.solvers.prox_newton import ProxNewton

@@ -179,8 +178,8 @@ def path(self, X, y, alphas=None, eps=1e-3, n_alphas=10):
         alphas = np.sort(alphas)[::-1]

         n_features = X.shape[1]
-        sqrt_quadratic = compiled_clone(SqrtQuadratic())
-        l1_penalty = compiled_clone(L1(1.))  # alpha is set along the path
+        sqrt_quadratic = SqrtQuadratic()
+        l1_penalty = L1(1.)  # alpha is set along the path

         coefs = np.zeros((n_alphas, n_features))

skglm/experimental/tests/test_quantile_regression.py (+2 -3)

@@ -6,7 +6,6 @@
 from skglm import GeneralizedLinearEstimator
 from skglm.experimental.pdcd_ws import PDCD_WS
 from skglm.experimental.quantile_regression import Pinball
-from skglm.utils.jit_compilation import compiled_clone

 from skglm.utils.data import make_correlated_data
 from sklearn.linear_model import QuantileRegressor
@@ -23,8 +22,8 @@ def test_PDCD_WS(quantile_level):
     alpha_max = norm(X.T @ (np.sign(y)/2 + (quantile_level - 0.5)), ord=np.inf)
     alpha = alpha_max / 5

-    datafit = compiled_clone(Pinball(quantile_level))
-    penalty = compiled_clone(L1(alpha))
+    datafit = Pinball(quantile_level)
+    penalty = L1(alpha)

     w = PDCD_WS(
         dual_init=np.sign(y)/2 + (quantile_level - 0.5)

skglm/experimental/tests/test_sqrt_lasso.py (+2 -3)

@@ -7,7 +7,6 @@
 from skglm.experimental.sqrt_lasso import (SqrtLasso, SqrtQuadratic,
                                            _chambolle_pock_sqrt)
 from skglm.experimental.pdcd_ws import PDCD_WS
-from skglm.utils.jit_compilation import compiled_clone


 def test_alpha_max():
@@ -70,8 +69,8 @@ def test_PDCD_WS(with_dual_init):

     dual_init = y / norm(y) if with_dual_init else None

-    datafit = compiled_clone(SqrtQuadratic())
-    penalty = compiled_clone(L1(alpha))
+    datafit = SqrtQuadratic()
+    penalty = L1(alpha)

     w = PDCD_WS(dual_init=dual_init).solve(X, y, datafit, penalty)[0]
     clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y)

skglm/solvers/base.py (+31 -2)

@@ -1,5 +1,10 @@
+import warnings
 from abc import abstractmethod, ABC
+
+import numpy as np
+
 from skglm.utils.validation import check_attrs
+from skglm.utils.jit_compilation import compiled_clone


 class BaseSolver(ABC):
@@ -89,8 +94,9 @@ def custom_checks(self, X, y, datafit, penalty):
         """
         pass

-    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None,
-              *, run_checks=True):
+    def solve(
+        self, X, y, datafit, penalty, w_init=None, Xw_init=None, *, run_checks=True
+    ):
         """Solve the optimization problem after validating its compatibility.

         A proxy of ``_solve`` method that implicitly ensures the compatibility
@@ -101,6 +107,29 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None,
         >>> ...
         >>> coefs, obj_out, stop_crit = solver.solve(X, y, datafit, penalty)
         """
+        # TODO check for datafit/penalty being jit-compiled properly
+        # instead of searching for a string
+        if "jitclass" in str(type(datafit)):
+            warnings.warn(
+                "Passing in a compiled datafit is deprecated since skglm v0.5 "
+                "Compilation is now done inside solver."
+                "This will raise an error starting skglm v0.6 onwards."
+            )
+        elif datafit is not None:
+            datafit = compiled_clone(datafit, to_float32=X.dtype == np.float32)
+
+        if "jitclass" in str(type(penalty)):
+            warnings.warn(
+                "Passing in a compiled penalty is deprecated since skglm v0.5 "
+                "Compilation is now done inside solver. "
+                "This will raise an error starting skglm v0.6 onwards."
+            )
+        elif penalty is not None:
+            penalty = compiled_clone(penalty)
+        # TODO add support for bool spec in compiled_clone
+        # currently, doing so break the code
+        # penalty = compiled_clone(penalty, to_float32=X.dtype == np.float32)
+
         if run_checks:
             self._validate(X, y, datafit, penalty)
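
The deprecation check is deliberately crude: a Numba jitclass instance has "jitclass" in its type string, while a plain Python instance does not and simply gets compiled (the TODOs above flag both the string matching and the missing bool-spec support as known limitations). A quick sketch of the two branches (illustrative, not part of the commit):

from skglm.datafits import Quadratic
from skglm.utils.jit_compilation import compiled_clone

plain = Quadratic()
compiled = compiled_clone(plain)

print("jitclass" in str(type(plain)))     # False: solve() compiles it silently
print("jitclass" in str(type(compiled)))  # True: solve() emits the deprecation warning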

skglm/solvers/common.py (+1 -2)

@@ -46,8 +46,7 @@


 @njit
-def dist_fix_point_bcd(
-        w, grad_ws, lipschitz_ws, datafit, penalty, ws):
+def dist_fix_point_bcd(w, grad_ws, lipschitz_ws, datafit, penalty, ws):
     """Compute the violation of the fixed point iterate scheme for BCD.

     Parameters

skglm/solvers/fista.py (+2)

@@ -51,10 +51,12 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)

         if X_is_sparse:
+            datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
             lipschitz = datafit.get_global_lipschitz_sparse(
                 X.data, X.indptr, X.indices, y
             )
         else:
+            datafit.initialize(X, y)
             lipschitz = datafit.get_global_lipschitz(X, y)

         for n_iter in range(self.max_iter):
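
Since `_solve` now initializes the datafit before asking it for a global Lipschitz constant, FISTA can be called directly on plain objects with no manual setup. A minimal sketch (problem sizes and `alpha` are illustrative):

import numpy as np
from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.solvers import FISTA
from skglm.utils.data import make_correlated_data

X, y, _ = make_correlated_data(n_samples=200, n_features=50, random_state=0)
alpha = 0.1 * np.max(np.abs(X.T @ y)) / len(y)

# no compiled_clone and no manual datafit.initialize: the solver handles both
w, p_objs, stop_crit = FISTA(max_iter=500, tol=1e-6).solve(X, y, Quadratic(), L1(alpha))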

skglm/solvers/group_prox_newton.py (+7)

@@ -69,6 +69,13 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         stop_crit = 0.
         p_objs_out = []

+        # TODO: to be isolated in a seperated method
+        is_sparse = issparse(X)
+        if is_sparse:
+            datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
+        else:
+            datafit.initialize(X, y)
+
         for iter in range(self.max_iter):
             grad = _construct_grad(X, y, w, Xw, datafit, all_groups)

skglm/solvers/lbfgs.py (+13 -6)

@@ -38,6 +38,13 @@ def __init__(self, max_iter=50, tol=1e-4, verbose=False):

     def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):

+        # TODO: to be isolated in a seperated method
+        is_sparse = issparse(X)
+        if is_sparse:
+            datafit.initialize_sparse(X.data, X.indptr, X.indices, y)
+        else:
+            datafit.initialize(X, y)
+
         def objective(w):
             Xw = X @ w
             datafit_value = datafit.value(y, w, Xw)
@@ -70,8 +77,7 @@ def callback_post_iter(w_k):

             it = len(p_objs_out)
             print(
-                f"Iteration {it}: {p_obj:.10f}, "
-                f"stopping crit: {stop_crit:.2e}"
+                f"Iteration {it}: {p_obj:.10f}, " f"stopping crit: {stop_crit:.2e}"
             )

         n_features = X.shape[1]
@@ -87,7 +93,7 @@ def callback_post_iter(w_k):
             options=dict(
                 maxiter=self.max_iter,
                 gtol=self.tol,
-                ftol=0.  # set ftol=0. to control convergence using only gtol
+                ftol=0.0,  # set ftol=0. to control convergence using only gtol
             ),
             callback=callback_post_iter,
         )
@@ -97,7 +103,7 @@ def callback_post_iter(w_k):
                 f"`LBFGS` did not converge for tol={self.tol:.3e} "
                 f"and max_iter={self.max_iter}.\n"
                 "Consider increasing `max_iter` and/or `tol`.",
-                category=ConvergenceWarning
+                category=ConvergenceWarning,
             )

         w = result.x
@@ -110,7 +116,8 @@ def custom_checks(self, X, y, datafit, penalty):
     def custom_checks(self, X, y, datafit, penalty):
         # check datafit support sparse data
         check_attrs(
-            datafit, solver=self,
+            datafit,
+            solver=self,
             required_attr=self._datafit_required_attr,
-            support_sparse=issparse(X)
+            support_sparse=issparse(X),
         )
