Skip to content

Commit c509bc2

Browse files
committed
fix failed score grid search
1 parent 0ea276d commit c509bc2

File tree

2 files changed

+53
-15
lines changed

2 files changed

+53
-15
lines changed

dask_ml/model_selection/methods.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -271,29 +271,35 @@ def fit_transform(
271271
return (est, fit_time), Xt
272272

273273

274-
def _score(est, X, y, scorer):
274+
def _score_with_error(est, X, y, scorer, error_score):
275+
try:
276+
out = scorer(est, X) if y is None else scorer(est, X, y)
277+
except Exception:
278+
if error_score == "raise":
279+
raise
280+
else:
281+
out = error_score
282+
return out
283+
284+
285+
def _score(est, X, y, scorer, error_score):
    """Score ``est`` with a single scorer or a mapping of named scorers.

    Returns ``FIT_FAILURE`` unchanged when ``est`` is that sentinel. For a
    ``Mapping`` of scorers the result is a dict with the same keys; otherwise
    the single scorer's result is returned. Every scorer call goes through
    ``_score_with_error`` with the given ``error_score``.
    """
    if est is FIT_FAILURE:
        return FIT_FAILURE
    if not isinstance(scorer, Mapping):
        return _score_with_error(est, X, y, scorer, error_score)
    # Each named scorer is evaluated separately, so one failing metric does
    # not affect the values computed for the others.
    results = {}
    for name, scorer_fn in scorer.items():
        results[name] = _score_with_error(est, X, y, scorer_fn, error_score)
    return results
280293

281294

282295
def score(est_and_time, X_test, y_test, X_train, y_train, scorer, error_score):
    """Score an estimator on the test split and, optionally, the train split.

    ``est_and_time`` is an ``(estimator, fit_time)`` pair. Returns
    ``(fit_time, test_score, score_time)`` when ``X_train`` is None, and
    ``(fit_time, test_score, score_time, train_score)`` otherwise.
    ``score_time`` measures only the test-split scoring call.
    """
    est, fit_time = est_and_time
    tic = default_timer()
    test_score = _score(est, X_test, y_test, scorer, error_score)
    score_time = default_timer() - tic
    result = (fit_time, test_score, score_time)
    if X_train is not None:
        # Train scoring happens after the timer is read, so it is not timed.
        result += (_score(est, X_train, y_train, scorer, error_score),)
    return result
298304

299305

tests/model_selection/dask_searchcv/test_model_selection.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sklearn.ensemble import RandomForestClassifier
2323
from sklearn.exceptions import FitFailedWarning, NotFittedError
2424
from sklearn.feature_selection import SelectKBest
25-
from sklearn.linear_model import LogisticRegression
25+
from sklearn.linear_model import LogisticRegression, LinearRegression
2626
from sklearn.model_selection import (
2727
GridSearchCV,
2828
GroupKFold,
@@ -445,7 +445,12 @@ def test_pipeline_sub_estimators():
445445
},
446446
]
447447

448-
gs = GridSearchCV(pipe, param_grid=param_grid, return_train_score=True, cv=3,)
448+
gs = GridSearchCV(
449+
pipe,
450+
param_grid=param_grid,
451+
return_train_score=True,
452+
cv=3,
453+
)
449454
gs.fit(X, y)
450455
dgs = dcv.GridSearchCV(
451456
pipe, param_grid=param_grid, scheduler="sync", return_train_score=True, cv=3
@@ -687,6 +692,29 @@ def test_estimator_predict_failure(in_pipeline):
687692
gs.fit(X, y)
688693

689694

695+
def test_estimator_score_failure():
    """A scorer that fails after a successful fit is recorded as ``error_score``."""
    X = np.array([[1, 2], [2, 1], [0, 0]])

    y = 3 * X[:, 0] + 4 * X[:, 1]
    cv = LeaveOneOut()

    ols = LinearRegression(fit_intercept=False)

    # With no intercept the model predicts exactly 0 for the sample [0, 0].
    # Mean Poisson deviance is undefined when y_hat is 0, so scoring the split
    # that holds out that sample fails even though the fit succeeds.
    #
    # The grid only needs two candidates. ``copy_X`` is used rather than
    # ``normalize`` because ``normalize`` was removed from LinearRegression in
    # scikit-learn 1.2 (and was ignored when fit_intercept=False anyway).
    regr = dcv.GridSearchCV(
        ols,
        {"copy_X": [False, True]},
        scoring=["neg_mean_squared_error", "neg_mean_poisson_deviance"],
        refit=False,
        cv=cv,
        error_score=-1,
        n_jobs=1,
    )
    regr.fit(X, y)
    assert (regr.cv_results_["split2_test_neg_mean_poisson_deviance"] == [-1, -1]).all()
716+
717+
690718
def test_pipeline_raises():
691719
X, y = make_classification(n_samples=100, n_features=10, random_state=0)
692720

@@ -946,7 +974,11 @@ def test_gridsearch_with_arraylike_fit_param(cache_cv):
946974
param_grid = {"foo_param": [0.0001, 0.1]}
947975

948976
a = dcv.GridSearchCV(
949-
MockClassifierWithFitParam(), param_grid, cv=3, refit=False, cache_cv=cache_cv,
977+
MockClassifierWithFitParam(),
978+
param_grid,
979+
cv=3,
980+
refit=False,
981+
cache_cv=cache_cv,
950982
)
951983
b = GridSearchCV(MockClassifierWithFitParam(), param_grid, cv=3, refit=False)
952984

0 commit comments

Comments
 (0)