Skip to content

Commit 8c75aa3

Browse files
committed
Added func to test/score Sklearn model
1 parent bd2b37e commit 8c75aa3

File tree

2 files changed

+63
-9
lines changed

2 files changed

+63
-9
lines changed

eis_toolkit/prediction/model_utils.py

+44-3
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import pandas as pd
77
from beartype import beartype
8-
from beartype.typing import List, Literal, Optional, Sequence, Tuple, Union
8+
from beartype.typing import Dict, List, Literal, Optional, Sequence, Tuple, Union
99
from scipy import sparse
1010
from sklearn.base import BaseEstimator, is_classifier, is_regressor
1111
from sklearn.metrics import (
@@ -91,14 +91,55 @@ def split_data(
9191

9292

9393
@beartype
94-
def predict(model: Union[BaseEstimator, keras.Model], data: np.ndarray) -> np.ndarray:
94+
def test_model(
    X_test: Union[np.ndarray, pd.DataFrame],
    y_test: Union[np.ndarray, pd.Series],
    model: Union[BaseEstimator, keras.Model],
    metrics: Optional[Sequence[Literal["mse", "rmse", "mae", "r2", "accuracy", "precision", "recall", "f1"]]] = None,
) -> Dict[str, Number]:
    """
    Test and score a trained model.

    TODO: Implement for Keras models.

    Args:
        X_test: Test data.
        y_test: Target labels for test data.
        model: Trained Sklearn classifier or regressor.
        metrics: Metrics to use for scoring the model. Defaults to "accuracy" for a classifier
            and to "mse" otherwise (i.e. whenever `is_classifier(model)` is false).

    Returns:
        Test metric scores as a dictionary, keyed by metric name.

    Raises:
        NonMatchingParameterLengthsException: The number of rows in X_test does not match the
            number of elements in y_test.
    """
    # `.shape[0]` is the row count for both np.ndarray and pd.DataFrame, so no type switch is needed.
    x_size = X_test.shape[0]
    if x_size != y_test.size:
        raise exceptions.NonMatchingParameterLengthsException(
            f"X and y must have the same length: {x_size} != {y_test.size}."
        )

    if metrics is None:
        metrics = ["accuracy"] if is_classifier(model) else ["mse"]

    # Predict once and reuse the predictions for every requested metric.
    y_pred = model.predict(X_test)

    return {metric: _score_model(model, y_test, y_pred, metric) for metric in metrics}
132+
133+
134+
@beartype
135+
def predict(data: Union[np.ndarray, pd.DataFrame], model: Union[BaseEstimator, keras.Model]) -> np.ndarray:
95136
"""
96137
Predict with a trained model.
97138
98139
Args:
140+
data: Data used to make predictions.
99141
model: Trained classifier or regressor. Can be any machine learning model trained with
100142
EIS Toolkit (Sklearn and Keras models).
101-
data: Data used to make predictions.
102143
103144
Returns:
104145
Predictions.

tests/prediction/model_utils_test.py

+19-6
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
save_model,
1515
split_data,
1616
)
17+
from eis_toolkit.prediction.model_utils import test_model as model_test
1718

1819
TEST_DIR = Path(__file__).parent.parent
1920

@@ -50,7 +51,7 @@ def test_train_and_evaluate_with_split():
5051
)
5152

5253
assert isinstance(model, RandomForestClassifier)
53-
assert len(out_metrics) == 4
54+
np.testing.assert_equal(len(out_metrics), 4)
5455

5556

5657
def test_train_and_evaluate_with_kfold_cv():
@@ -60,7 +61,7 @@ def test_train_and_evaluate_with_kfold_cv():
6061
)
6162

6263
assert isinstance(model, RandomForestClassifier)
63-
assert len(out_metrics) == 4
64+
np.testing.assert_equal(len(out_metrics), 4)
6465

6566

6667
def test_train_and_evaluate_with_skfold_cv():
@@ -70,7 +71,7 @@ def test_train_and_evaluate_with_skfold_cv():
7071
)
7172

7273
assert isinstance(model, RandomForestClassifier)
73-
assert len(out_metrics) == 4
74+
np.testing.assert_equal(len(out_metrics), 4)
7475

7576

7677
def test_binary_classification():
@@ -97,7 +98,7 @@ def test_binary_classification():
9798
)
9899

99100
assert isinstance(model, RandomForestClassifier)
100-
assert len(out_metrics) == 4
101+
np.testing.assert_equal(len(out_metrics), 4)
101102

102103

103104
def test_splitting():
@@ -109,6 +110,18 @@ def test_splitting():
109110
np.testing.assert_equal(len(y_test), len(Y_IRIS) * 0.2)
110111

111112

113+
def test_test_model_sklearn():
    """Check that test_model scores a fitted Sklearn classifier as expected."""
    train_features, test_features, train_labels, test_labels = split_data(X_IRIS, Y_IRIS, split_size=0.2)

    fitted_model, _ = _train_and_validate_sklearn_model(
        train_features,
        train_labels,
        model=RF_MODEL,
        validation_method="none",
        metrics=CLF_METRICS,
        random_state=42,
    )

    # No metrics are passed, so the result is looked up under the "accuracy" key.
    scores = model_test(test_features, test_labels, fitted_model)
    accuracy = scores["accuracy"]
    np.testing.assert_equal(accuracy, 1.0)
123+
124+
112125
def test_predict_sklearn():
113126
"""Test that predict works as expected with a Sklearn model."""
114127
X_train, X_test, y_train, y_test = split_data(X_IRIS, Y_IRIS, split_size=0.2)
@@ -117,8 +130,8 @@ def test_predict_sklearn():
117130
X_train, y_train, model=RF_MODEL, validation_method="none", metrics=CLF_METRICS, random_state=42
118131
)
119132

120-
predicted_labels = predict(model, X_test)
121-
assert len(predicted_labels) == len(y_test)
133+
predicted_labels = predict(X_test, model)
134+
np.testing.assert_equal(len(predicted_labels), len(y_test))
122135

123136

124137
def test_save_and_load_model():

0 commit comments

Comments
 (0)