From c509bc2cee530d45852764fdd22ab6b3f0c245b6 Mon Sep 17 00:00:00 2001 From: Ian Bolliger Date: Sun, 2 May 2021 19:22:48 -0700 Subject: [PATCH] fix failed score grid search --- dask_ml/model_selection/methods.py | 30 +++++++++------ .../dask_searchcv/test_model_selection.py | 38 +++++++++++++++++-- 2 files changed, 53 insertions(+), 15 deletions(-) diff --git a/dask_ml/model_selection/methods.py b/dask_ml/model_selection/methods.py index 16dd6c501..06f9f5c51 100644 --- a/dask_ml/model_selection/methods.py +++ b/dask_ml/model_selection/methods.py @@ -271,29 +271,35 @@ def fit_transform( return (est, fit_time), Xt -def _score(est, X, y, scorer): +def _score_with_error(est, X, y, scorer, error_score): + try: + out = scorer(est, X) if y is None else scorer(est, X, y) + except Exception: + if error_score == "raise": + raise + else: + out = error_score + return out + + +def _score(est, X, y, scorer, error_score): if est is FIT_FAILURE: return FIT_FAILURE if isinstance(scorer, Mapping): - return {k: v(est, X) if y is None else v(est, X, y) for k, v in scorer.items()} - return scorer(est, X) if y is None else scorer(est, X, y) + return { + k: _score_with_error(est, X, y, v, error_score) for k, v in scorer.items() + } + return _score_with_error(est, X, y, scorer, error_score) def score(est_and_time, X_test, y_test, X_train, y_train, scorer, error_score): est, fit_time = est_and_time start_time = default_timer() - try: - test_score = _score(est, X_test, y_test, scorer) - except Exception: - if error_score == "raise": - raise - else: - score_time = default_timer() - start_time - return fit_time, error_score, score_time, error_score + test_score = _score(est, X_test, y_test, scorer, error_score) score_time = default_timer() - start_time if X_train is None: return fit_time, test_score, score_time - train_score = _score(est, X_train, y_train, scorer) + train_score = _score(est, X_train, y_train, scorer, error_score) return fit_time, test_score, score_time, train_score diff --git a/tests/model_selection/dask_searchcv/test_model_selection.py b/tests/model_selection/dask_searchcv/test_model_selection.py index e0a203932..ef5ab04ba 100644 --- a/tests/model_selection/dask_searchcv/test_model_selection.py +++ b/tests/model_selection/dask_searchcv/test_model_selection.py @@ -22,7 +22,7 @@ from sklearn.ensemble import RandomForestClassifier from sklearn.exceptions import FitFailedWarning, NotFittedError from sklearn.feature_selection import SelectKBest -from sklearn.linear_model import LogisticRegression +from sklearn.linear_model import LogisticRegression, LinearRegression from sklearn.model_selection import ( GridSearchCV, GroupKFold, @@ -445,7 +445,12 @@ def test_pipeline_sub_estimators(): }, ] - gs = GridSearchCV(pipe, param_grid=param_grid, return_train_score=True, cv=3,) + gs = GridSearchCV( + pipe, + param_grid=param_grid, + return_train_score=True, + cv=3, + ) gs.fit(X, y) dgs = dcv.GridSearchCV( pipe, param_grid=param_grid, scheduler="sync", return_train_score=True, cv=3 @@ -687,6 +692,29 @@ def test_estimator_predict_failure(in_pipeline): gs.fit(X, y) +def test_estimator_score_failure(): + X = np.array([[1, 2], [2, 1], [0, 0]]) + + y = 3 * X[:, 0] + 4 * X[:, 1] + cv = LeaveOneOut() + + ols = LinearRegression(fit_intercept=False) + + # mean poisson deviance is undefined when y_hat is 0, so this can be used to test + # when estimator fit succeeds but score fails + regr = dcv.GridSearchCV( + ols, + {"normalize": [False, True]}, + scoring=["neg_mean_squared_error", "neg_mean_poisson_deviance"], + refit=False, + cv=cv, + error_score=-1, + n_jobs=1, + ) + regr.fit(X, y) + assert (regr.cv_results_["split2_test_neg_mean_poisson_deviance"] == [-1, -1]).all() + + def test_pipeline_raises(): X, y = make_classification(n_samples=100, n_features=10, random_state=0) @@ -946,7 +974,11 @@ def test_gridsearch_with_arraylike_fit_param(cache_cv): param_grid = {"foo_param": [0.0001, 0.1]} a = dcv.GridSearchCV( - MockClassifierWithFitParam(), param_grid, cv=3, refit=False, cache_cv=cache_cv, + MockClassifierWithFitParam(), + param_grid, + cv=3, + refit=False, + cache_cv=cache_cv, ) b = GridSearchCV(MockClassifierWithFitParam(), param_grid, cv=3, refit=False)