
Commit d91f0d2

FIX: Don't prevent oneDAL usage with non-PSD elasticnet/lasso (uxlfoundation#2713)
* remove unneeded filter for d4p support
* remove unnecessary copying
* more clear test
* mention that multi-target is supported, add tests for it
* missing word
1 parent ac91982 commit d91f0d2
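The filter removed here had forced `ElasticNet.predict` and `Lasso.predict` off the oneDAL path whenever the input had more features than samples (the situation behind the "non-PSD" wording in the title: `X.T @ X` is then rank-deficient, hence not positive-definite). As a minimal sketch of what the change enables (illustrative, not part of the commit; assumes a daal4py build where these patched estimators are available):

```python
from sklearn.datasets import make_regression

from daal4py.sklearn.linear_model import ElasticNet

# Wide data: more features (20) than samples (10). predict() on such
# inputs previously fell back to stock scikit-learn; after this fix it
# remains eligible for the oneDAL code path.
X, y = make_regression(n_samples=10, n_features=20, random_state=123)

model = ElasticNet(tol=1e-7, max_iter=int(1e4)).fit(X, y)
print(model.predict(X).shape)  # (10,)
```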

3 files changed: +162, -25 lines

daal4py/sklearn/linear_model/_coordinate_descent.py

Lines changed: 1 addition & 21 deletions
```diff
@@ -177,11 +177,7 @@ def _daal4py_fit_enet(self, X, y_, check_input):
     inputArgument = np.zeros((n_rows, n_cols), dtype=_fptype)
     for i in range(n_rows):
         inputArgument[i][0] = self.intercept_ if (n_rows == 1) else self.intercept_[i]
-        inputArgument[i][1:] = (
-            self.coef_[:].copy(order="C")
-            if (n_rows == 1)
-            else self.coef_[i, :].copy(order="C")
-        )
+        inputArgument[i][1:] = self.coef_[:] if (n_rows == 1) else self.coef_[i, :]
     cd_solver.setup(inputArgument)
     doUse_condition = self.copy_X is False or (
         self.fit_intercept and _normalize and self.copy_X
@@ -695,9 +691,6 @@ def predict(self, X):
         _X = check_array(
             X, accept_sparse=["csr", "csc", "coo"], dtype=[np.float64, np.float32]
         )
-        good_shape_for_daal = (
-            True if _X.ndim <= 1 else True if _X.shape[0] >= _X.shape[1] else False
-        )
 
         _patching_status = PatchingConditionsChain(
             "sklearn.linear_model.ElasticNet.predict"
@@ -706,11 +699,6 @@ def predict(self, X):
             [
                 (hasattr(self, "daal_model_"), "oneDAL model was not trained."),
                 (not sp.issparse(_X), "X is sparse. Sparse input is not supported."),
-                (
-                    good_shape_for_daal,
-                    "The shape of X does not satisfy oneDAL requirements: "
-                    "number of features > number of samples.",
-                ),
             ]
         )
         _patching_status.write_log()
@@ -808,20 +796,12 @@ def predict(self, X):
         _X = check_array(
             X, accept_sparse=["csr", "csc", "coo"], dtype=[np.float64, np.float32]
         )
-        good_shape_for_daal = (
-            True if _X.ndim <= 1 else True if _X.shape[0] >= _X.shape[1] else False
-        )
 
         _patching_status = PatchingConditionsChain("sklearn.linear_model.Lasso.predict")
         _dal_ready = _patching_status.and_conditions(
             [
                 (hasattr(self, "daal_model_"), "oneDAL model was not trained."),
                 (not sp.issparse(_X), "X is sparse. Sparse input is not supported."),
-                (
-                    good_shape_for_daal,
-                    "The shape of X does not satisfy oneDAL requirements: "
-                    "number of features > number of samples.",
-                ),
             ]
         )
         _patching_status.write_log()
```
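For readers of the diff, the deleted `good_shape_for_daal` expression is a verbose boolean; it reduces to the one-liner below (an equivalent simplification for illustration, not code from the repository):

```python
def good_shape_for_daal(X) -> bool:
    # Equivalent to the removed expression: accept 1-D input, or 2-D input
    # with at least as many samples (rows) as features (columns).
    return X.ndim <= 1 or X.shape[0] >= X.shape[1]
```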
Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@

```python
# ==============================================================================
# Copyright contributors to the oneDAL project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import warnings

import numpy as np
import pytest
from sklearn.datasets import make_regression
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import ElasticNet as _sklElasticnet
from sklearn.linear_model import Lasso as _sklLasso

from daal4py.sklearn.linear_model import ElasticNet, Lasso


def fn_lasso(model, X, y, lambda_):
    resid = y - model.predict(X)
    fn_ssq = resid.reshape(-1) @ resid.reshape(-1)
    fn_l1 = np.abs(model.coef_).sum()
    return fn_ssq + lambda_ * fn_l1


@pytest.mark.parametrize("nrows", [10, 20])
@pytest.mark.parametrize("ncols", [10, 20])
@pytest.mark.parametrize("n_targets", [1, 2])
@pytest.mark.parametrize("fit_intercept", [False, True])
@pytest.mark.parametrize("positive", [False, True])
@pytest.mark.parametrize("l1_ratio", [0.0, 1.0, 0.5])
def test_enet_is_correct(nrows, ncols, n_targets, fit_intercept, positive, l1_ratio):
    X, y = make_regression(
        n_samples=nrows, n_features=ncols, n_targets=n_targets, random_state=123
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ConvergenceWarning)
        model_d4p = ElasticNet(
            fit_intercept=fit_intercept,
            positive=positive,
            l1_ratio=l1_ratio,
            tol=1e-7,
            max_iter=int(1e4),
        ).fit(X, y)
        model_skl = _sklElasticnet(
            fit_intercept=fit_intercept,
            positive=positive,
            l1_ratio=l1_ratio,
            tol=1e-7,
            max_iter=int(1e4),
        ).fit(X, y)

    # Note: lasso is not guaranteed to have a unique global optimum.
    # If the coefficients do not match, this makes another check on
    # the optimality of the function values instead. It checks that
    # the result from daal4py is no worse than 2% off scikit-learn's.

    tol = 1e-6 if n_targets == 1 else 1e-5
    try:
        np.testing.assert_allclose(model_d4p.coef_, model_skl.coef_, atol=tol, rtol=tol)
    except AssertionError as e:
        if l1_ratio != 1:
            raise e
        fn_d4p = fn_lasso(model_d4p, X, y, model_d4p.alpha)
        fn_skl = fn_lasso(model_skl, X, y, model_skl.alpha)
        assert fn_d4p <= fn_skl * 1.02

    if fit_intercept:
        np.testing.assert_allclose(
            model_d4p.intercept_, model_skl.intercept_, atol=tol, rtol=tol
        )

    if positive:
        assert np.all(model_d4p.coef_ >= 0)


@pytest.mark.parametrize("nrows", [10, 20])
@pytest.mark.parametrize("ncols", [10, 20])
@pytest.mark.parametrize("n_targets", [1, 2])
@pytest.mark.parametrize("fit_intercept", [False, True])
@pytest.mark.parametrize("positive", [False, True])
@pytest.mark.parametrize("alpha", [1e-2, 1e2])
def test_lasso_is_correct(nrows, ncols, n_targets, fit_intercept, positive, alpha):
    X, y = make_regression(
        n_samples=nrows, n_features=ncols, n_targets=n_targets, random_state=123
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ConvergenceWarning)
        model_d4p = Lasso(
            fit_intercept=fit_intercept,
            positive=positive,
            alpha=alpha,
            tol=1e-7,
            max_iter=int(1e4),
        ).fit(X, y)
        model_skl = _sklLasso(
            fit_intercept=fit_intercept,
            positive=positive,
            alpha=alpha,
            tol=1e-7,
            max_iter=int(1e4),
        ).fit(X, y)

    tol = 1e-4 if alpha < 1 else (1e-6 if n_targets == 1 else 1e-5)
    try:
        np.testing.assert_allclose(model_d4p.coef_, model_skl.coef_, atol=tol, rtol=tol)
        if fit_intercept:
            np.testing.assert_allclose(
                model_d4p.intercept_, model_skl.intercept_, atol=tol, rtol=tol
            )
    except AssertionError as e:
        fn_d4p = fn_lasso(model_d4p, X, y, model_d4p.alpha)
        fn_skl = fn_lasso(model_skl, X, y, model_skl.alpha)
        assert fn_d4p <= fn_skl * 1.02

    if positive:
        assert np.all(model_d4p.coef_ >= 0)


@pytest.mark.parametrize("n_targets", [1, 2])
def test_warm_start(n_targets):
    X, y = make_regression(
        n_samples=20, n_features=10, n_targets=n_targets, random_state=123
    )
    X1 = X[:10]
    y1 = y[:10]
    X2 = X[10:]
    y2 = y[10:]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ConvergenceWarning)
        model_d4p = ElasticNet(
            warm_start=True,
            tol=1e-7,
            max_iter=int(1e4),
        ).fit(X1, y1)
    coefs_ref = model_d4p.coef_.copy()
    intercept_ref = model_d4p.intercept_.copy()

    model_d4p.set_params(max_iter=1)
    model_d4p.fit(X2, y2)

    model_from_scratch = ElasticNet(tol=1e-7, max_iter=int(1e4)).fit(X2, y2)

    diff_ref = np.linalg.norm(model_d4p.coef_ - coefs_ref)
    diff_from_scratch = np.linalg.norm(model_d4p.coef_ - model_from_scratch.coef_)

    assert diff_ref < diff_from_scratch
```
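A note on the `fn_lasso` helper in the tests above: it scores an unscaled lasso objective, `||y - Xw||^2 + lambda * ||w||_1`, whereas scikit-learn's `Lasso` documents the scaled form `(1 / (2 * n_samples)) * ||y - Xw||^2 + alpha * ||w||_1`. The comparison is still a fair head-to-head check because both fitted models are scored with the same helper on the same data. For anyone who prefers matching the documented objective, a hedged sketch of the scaled variant (not part of the commit):

```python
import numpy as np


def fn_lasso_sklearn_scaled(model, X, y, alpha):
    # scikit-learn's documented Lasso objective:
    #   (1 / (2 * n_samples)) * ||y - X w||_2^2 + alpha * ||w||_1
    resid = (y - model.predict(X)).reshape(-1)
    return (resid @ resid) / (2 * X.shape[0]) + alpha * np.abs(model.coef_).sum()
```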

doc/sources/algorithms.rst

Lines changed: 4 additions & 4 deletions
```diff
@@ -117,26 +117,26 @@ Regression
      - All parameters are supported except:
 
        - ``sample_weight`` != `None`
-       - ``positive`` = `True`
+       - ``positive`` = `True` (this is supported through the class :obj:`sklearn.linear_model.ElasticNet`)
      - Only dense data is supported.
    * - :obj:`sklearn.linear_model.Ridge`
      - All parameters are supported except:
 
        - ``solver`` != `'auto'`
        - ``sample_weight`` != `None`
-       - ``positive`` = `True`
+       - ``positive`` = `True` (this is supported through the class :obj:`sklearn.linear_model.ElasticNet`)
        - ``alpha`` must be a scalar
      - Only dense data is supported.
    * - :obj:`sklearn.linear_model.ElasticNet`
      - All parameters are supported except:
 
        - ``sample_weight`` != `None`
-     - Multi-output and sparse data are not supported, `#observations` should be >= `#features`.
+     - Sparse data is not supported.
    * - :obj:`sklearn.linear_model.Lasso`
      - All parameters are supported except:
 
        - ``sample_weight`` != `None`
-     - Multi-output and sparse data are not supported, `#observations` should be >= `#features`.
+     - Sparse data is not supported.
 
 Clustering
 **********
```
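Following the table update and the commit note that multi-target input is now supported (and tested above), a minimal usage sketch with the daal4py estimator; shapes follow scikit-learn's multi-output convention (illustrative, not from the commit):

```python
from sklearn.datasets import make_regression

from daal4py.sklearn.linear_model import ElasticNet

# Two regression targets: y has shape (20, 2).
X, y = make_regression(n_samples=20, n_features=10, n_targets=2, random_state=123)

model = ElasticNet(tol=1e-7, max_iter=int(1e4)).fit(X, y)
print(model.coef_.shape)       # (2, 10): one row of coefficients per target
print(model.predict(X).shape)  # (20, 2)
```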
