|
22 | 22 | from sklearn.ensemble import RandomForestClassifier
|
23 | 23 | from sklearn.exceptions import FitFailedWarning, NotFittedError
|
24 | 24 | from sklearn.feature_selection import SelectKBest
|
25 |
| -from sklearn.linear_model import LogisticRegression |
| 25 | +from sklearn.linear_model import LogisticRegression, LinearRegression |
26 | 26 | from sklearn.model_selection import (
|
27 | 27 | GridSearchCV,
|
28 | 28 | GroupKFold,
|
@@ -445,7 +445,12 @@ def test_pipeline_sub_estimators():
|
445 | 445 | },
|
446 | 446 | ]
|
447 | 447 |
|
448 |
| - gs = GridSearchCV(pipe, param_grid=param_grid, return_train_score=True, cv=3,) |
| 448 | + gs = GridSearchCV( |
| 449 | + pipe, |
| 450 | + param_grid=param_grid, |
| 451 | + return_train_score=True, |
| 452 | + cv=3, |
| 453 | + ) |
449 | 454 | gs.fit(X, y)
|
450 | 455 | dgs = dcv.GridSearchCV(
|
451 | 456 | pipe, param_grid=param_grid, scheduler="sync", return_train_score=True, cv=3
|
@@ -687,6 +692,29 @@ def test_estimator_predict_failure(in_pipeline):
|
687 | 692 | gs.fit(X, y)
|
688 | 693 |
|
689 | 694 |
|
| 695 | +def test_estimator_score_failure(): |
| 696 | + X = np.array([[1, 2], [2, 1], [0, 0]]) |
| 697 | + |
| 698 | + y = 3 * X[:, 0] + 4 * X[:, 1] |
| 699 | + cv = LeaveOneOut() |
| 700 | + |
| 701 | + ols = LinearRegression(fit_intercept=False) |
| 702 | + |
| 703 | + # mean poisson deviance is undefined when y_hat is 0, so this can be used to test |
| 704 | + # when estimator fit succeeds but score fails |
| 705 | + regr = dcv.GridSearchCV( |
| 706 | + ols, |
| 707 | + {"normalize": [False, True]}, |
| 708 | + scoring=["neg_mean_squared_error", "neg_mean_poisson_deviance"], |
| 709 | + refit=False, |
| 710 | + cv=cv, |
| 711 | + error_score=-1, |
| 712 | + n_jobs=1, |
| 713 | + ) |
| 714 | + regr.fit(X, y) |
| 715 | + assert (regr.cv_results_["split2_test_neg_mean_poisson_deviance"] == [-1, -1]).all() |
| 716 | + |
| 717 | + |
690 | 718 | def test_pipeline_raises():
|
691 | 719 | X, y = make_classification(n_samples=100, n_features=10, random_state=0)
|
692 | 720 |
|
@@ -946,7 +974,11 @@ def test_gridsearch_with_arraylike_fit_param(cache_cv):
|
946 | 974 | param_grid = {"foo_param": [0.0001, 0.1]}
|
947 | 975 |
|
948 | 976 | a = dcv.GridSearchCV(
|
949 |
| - MockClassifierWithFitParam(), param_grid, cv=3, refit=False, cache_cv=cache_cv, |
| 977 | + MockClassifierWithFitParam(), |
| 978 | + param_grid, |
| 979 | + cv=3, |
| 980 | + refit=False, |
| 981 | + cache_cv=cache_cv, |
950 | 982 | )
|
951 | 983 | b = GridSearchCV(MockClassifierWithFitParam(), param_grid, cv=3, refit=False)
|
952 | 984 |
|
|
0 commit comments