Commit ef247d3

Added wrappers for splitting data and predicting (tests included)
1 parent 9becd2b commit ef247d3

2 files changed: +86 -6 lines changed

eis_toolkit/prediction/model_utils.py (+58 -5)
@@ -1,10 +1,12 @@
+from numbers import Number
 from pathlib import Path

 import joblib
 import numpy as np
 import pandas as pd
 from beartype import beartype
-from beartype.typing import Literal, Optional, Sequence, Tuple, Union
+from beartype.typing import List, Literal, Optional, Sequence, Tuple, Union
+from scipy import sparse
 from sklearn.base import BaseEstimator, is_classifier, is_regressor
 from sklearn.metrics import (
     accuracy_score,
@@ -16,6 +18,7 @@
     recall_score,
 )
 from sklearn.model_selection import KFold, LeaveOneOut, StratifiedKFold, train_test_split
+from tensorflow import keras

 from eis_toolkit import exceptions

@@ -52,6 +55,58 @@ def load_model(path: Path) -> BaseEstimator:
     return joblib.load(path)


+@beartype
+def split_data(
+    *data: Union[np.ndarray, pd.DataFrame, sparse._csr.csr_matrix, List[Number]],
+    split_size: float = 0.2,
+    random_state: Optional[int] = 42,
+    shuffle: bool = True,
+) -> List[Union[np.ndarray, pd.DataFrame, sparse._csr.csr_matrix, List[Number]]]:
+    """
+    Split data into two parts.
+
+    For more guidance, read the documentation of sklearn.model_selection.train_test_split:
+    (https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html).
+
+    Args:
+        *data: Data to be split. Multiple datasets can be given as input (for example X and y),
+            but they need to have the same length. Each dataset is split into two parts and all
+            parts are returned (for example X_train, X_test, y_train, y_test).
+        split_size: The proportion of the second part of the split. Typically this is the size of the
+            test/validation part; the first part gets the complementary proportion. For example, if
+            split_size = 0.2, the first part will have 80% of the data and the second part 20%.
+            Defaults to 0.2.
+        random_state: Seed for random number generation. Defaults to 42.
+        shuffle: If data is shuffled before splitting. Defaults to True.
+
+    Returns:
+        List containing splits of inputs (two outputs per input).
+    """
+
+    if not (0 < split_size < 1):
+        raise exceptions.InvalidParameterValueException("Split size must be more than 0 and less than 1.")
+
+    split_data = train_test_split(*data, test_size=split_size, random_state=random_state, shuffle=shuffle)
+
+    return split_data
+
+
+@beartype
+def predict(model: Union[BaseEstimator, keras.Model], data: np.ndarray) -> np.ndarray:
+    """
+    Predict with a trained model.
+
+    Args:
+        model: Trained classifier or regressor. Can be any machine learning model trained with
+            EIS Toolkit (Sklearn and Keras models).
+        data: Data used to make predictions.
+
+    Returns:
+        Predictions.
+    """
+    result = model.predict(data)
+    return result
+
+
 @beartype
 def _train_and_validate_sklearn_model(
     X: Union[np.ndarray, pd.DataFrame],
@@ -80,8 +135,6 @@ def _train_and_validate_sklearn_model(
     )
     if cv_folds < 2:
         raise exceptions.InvalidParameterValueException("Number of cross-validation folds must be at least 2.")
-    if not (0 < split_size < 1):
-        raise exceptions.InvalidParameterValueException("Split size must be more than 0 and less than 1.")

     # Approach 1: No validation
     if validation_method == NO_VALIDATION:
@@ -92,8 +145,8 @@

     # Approach 2: Validation with splitting data once
     elif validation_method == SPLIT:
-        X_train, X_valid, y_train, y_valid = train_test_split(
-            X, y, test_size=split_size, random_state=random_state, shuffle=True
+        X_train, X_valid, y_train, y_valid = split_data(
+            X, y, split_size=split_size, random_state=random_state, shuffle=True
         )
         model.fit(X_train, y_train)
         y_pred = model.predict(X_valid)
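
A minimal usage sketch of the two new wrappers together (the RandomForestClassifier and the X / y arrays below are illustrative stand-ins, not part of this commit):

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier

    from eis_toolkit.prediction.model_utils import predict, split_data

    # Illustrative data: 100 samples with 4 features and binary labels.
    X = np.random.rand(100, 4)
    y = np.random.randint(0, 2, size=100)

    # One call splits every positional input into two parts (80/20 here).
    X_train, X_test, y_train, y_test = split_data(X, y, split_size=0.2, random_state=42)

    model = RandomForestClassifier(random_state=42).fit(X_train, y_train)

    # predict() forwards to model.predict, so Sklearn and Keras models share one entry point.
    labels = predict(model, X_test)
    assert len(labels) == len(y_test)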

tests/prediction/model_utils_test.py (+28 -1)
@@ -7,7 +7,13 @@
 from sklearn.ensemble import RandomForestClassifier

 from eis_toolkit import exceptions
-from eis_toolkit.prediction.model_utils import _train_and_validate_sklearn_model, load_model, save_model
+from eis_toolkit.prediction.model_utils import (
+    _train_and_validate_sklearn_model,
+    load_model,
+    predict,
+    save_model,
+    split_data,
+)

 TEST_DIR = Path(__file__).parent.parent

@@ -94,6 +100,27 @@ def test_binary_classification():
     assert len(out_metrics) == 4


+def test_splitting():
+    """Test that split data works as expected."""
+    X_train, X_test, y_train, y_test = split_data(X_IRIS, Y_IRIS, split_size=0.2)
+    np.testing.assert_equal(len(X_train), len(X_IRIS) * 0.8)
+    np.testing.assert_equal(len(y_train), len(Y_IRIS) * 0.8)
+    np.testing.assert_equal(len(X_test), len(X_IRIS) * 0.2)
+    np.testing.assert_equal(len(y_test), len(Y_IRIS) * 0.2)
+
+
+def test_predict_sklearn():
+    """Test that predict works as expected with a Sklearn model."""
+    X_train, X_test, y_train, y_test = split_data(X_IRIS, Y_IRIS, split_size=0.2)
+
+    model, _ = _train_and_validate_sklearn_model(
+        X_train, y_train, model=RF_MODEL, validation_method="none", metrics=CLF_METRICS, random_state=42
+    )
+
+    predicted_labels = predict(model, X_test)
+    assert len(predicted_labels) == len(y_test)
+
+
 def test_save_and_load_model():
     """Test that saving and loading a model works as expected."""
     model_save_path = TEST_DIR.joinpath("data/local/results/saved_rf_model.joblib")
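
The predict wrapper's type hint also admits keras.Model, though only the Sklearn path is covered by the tests above. A hedged sketch of how a Keras model could go through the same wrapper (the network architecture and random data below are illustrative assumptions, not toolkit code):

    import numpy as np
    from tensorflow import keras

    from eis_toolkit.prediction.model_utils import predict, split_data

    # Illustrative data: 100 samples with 4 features and binary labels.
    X = np.random.rand(100, 4)
    y = np.random.randint(0, 2, size=100)

    X_train, X_test, y_train, y_test = split_data(X, y, split_size=0.2)

    # Minimal binary classifier; the architecture is arbitrary for this sketch.
    model = keras.Sequential(
        [
            keras.Input(shape=(4,)),
            keras.layers.Dense(8, activation="relu"),
            keras.layers.Dense(1, activation="sigmoid"),
        ]
    )
    model.compile(optimizer="adam", loss="binary_crossentropy")
    model.fit(X_train, y_train, epochs=5, verbose=0)

    # For a Keras model, predict() returns raw sigmoid outputs rather than class labels.
    probabilities = predict(model, X_test)
    assert len(probabilities) == len(y_test)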
