Skip to content

Commit 4ba04aa

Browse files
committed
Add model type exception, add classification threshold, return one-dimensional arrays from prediction funcs
1 parent cca25dd commit 4ba04aa

File tree

2 files changed

+35
-11
lines changed

2 files changed

+35
-11
lines changed

eis_toolkit/exceptions.py

+4
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,10 @@ class InvalidDataShapeException(Exception):
3434
"""Exception error for datasets with invalid shapes."""
3535

3636

37+
class InvalidModelTypeException(Exception):
38+
"""Exception error for invalid model type."""
39+
40+
3741
class InvalidParameterValueException(Exception):
3842
"""Exception error class for invalid parameter values."""
3943

eis_toolkit/prediction/machine_learning_predict.py

+31 -11
Original file line number | Diff line number | Diff line change
@@ -2,40 +2,55 @@
22
import pandas as pd
33
from beartype import beartype
44
from beartype.typing import Tuple, Union
5-
from sklearn.base import BaseEstimator
5+
from sklearn.base import BaseEstimator, is_classifier
66
from tensorflow import keras
77

8+
from eis_toolkit.exceptions import InvalidModelTypeException
9+
810

911
@beartype
1012
def predict_classifier(
11-
data: Union[np.ndarray, pd.DataFrame], model: Union[BaseEstimator, keras.Model], include_probabilities: bool = True
13+
data: Union[np.ndarray, pd.DataFrame],
14+
model: Union[BaseEstimator, keras.Model],
15+
classification_threshold: float = 0.5,
16+
include_probabilities: bool = True,
1217
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
1318
"""
14-
Predict with a trained model.
19+
Predict with a trained classifier model.
20+
21+
Only works for binary classification currently.
1522
1623
Args:
1724
data: Data used to make predictions.
1825
model: Trained classifier or regressor. Can be any machine learning model trained with
1926
EIS Toolkit (Sklearn and Keras models).
27+
classification_threshold: Threshold for classifying based on probabilities. Defaults to 0.5.
2028
include_probabilities: If the probability array should be returned too. Defaults to True.
2129
2230
Returns:
23-
Predicted labels and optionally predicted probabilities by a classifier model.
31+
Predicted labels and optionally predicted probabilities as one-dimensional arrays by a classifier model.
32+
33+
Raises:
34+
InvalidModelTypeException: Input model is not a classifier model.
2435
"""
2536
if isinstance(model, keras.Model):
26-
probabilities = model.predict(data)
27-
labels = probabilities.argmax(axis=-1)
37+
probabilities = model.predict(data).squeeze()
38+
labels = probabilities >= classification_threshold
2839
if include_probabilities:
29-
return labels, probabilities
40+
return labels, probabilities.astype(np.float32)
3041
else:
3142
return labels
3243
elif isinstance(model, BaseEstimator):
33-
labels = model.predict(data)
44+
if not is_classifier(model):
45+
raise InvalidModelTypeException(f"Expected a classifier model: {type(model)}.")
46+
probabilities = model.predict_proba(data)[:, 1]
47+
labels = (probabilities >= classification_threshold).astype(np.float32)
3448
if include_probabilities:
35-
probabilities = model.predict_proba(data)
36-
return labels, probabilities
49+
return labels, probabilities.astype(np.float32)
3750
else:
3851
return labels
52+
else:
53+
raise InvalidModelTypeException(f"Model type not recognized: {type(model)}.")
3954

4055

4156
@beartype
@@ -44,7 +59,7 @@ def predict_regressor(
4459
model: Union[BaseEstimator, keras.Model],
4560
) -> np.ndarray:
4661
"""
47-
Predict with a trained model.
62+
Predict with a trained regressor model.
4863
4964
Args:
5065
data: Data used to make predictions.
@@ -53,6 +68,11 @@ def predict_regressor(
5368
5469
Returns:
5570
Regression model prediction array.
71+
72+
Raises:
73+
InvalidModelTypeException: Input model is not a regressor model.
5674
"""
75+
if is_classifier(model):
76+
raise InvalidModelTypeException(f"Expected a regressor model: {type(model)}.")
5777
result = model.predict(data)
5878
return result

0 commit comments

Comments
 (0)