@@ -18,13 +18,12 @@ def predict_classifier(
18
18
"""
19
19
Predict with a trained classifier model.
20
20
21
- Only works for binary classification currently.
22
-
23
21
Args:
24
22
data: Data used to make predictions.
25
23
model: Trained classifier or regressor. Can be any machine learning model trained with
26
24
EIS Toolkit (Sklearn and Keras models).
27
- classification_threshold: Threshold for classifying based on probabilities. Defaults to 0.5.
25
+ classification_threshold: Threshold for classifying based on probabilities. Only used for
26
+ binary classification. Defaults to 0.5.
28
27
include_probabilities: If the probability array should be returned too. Defaults to True.
29
28
30
29
Returns:
@@ -34,19 +33,27 @@ def predict_classifier(
34
33
InvalidModelTypeException: Input model is not a classifier model.
35
34
"""
36
35
if isinstance (model , keras .Model ):
37
- probabilities = model .predict (data ).squeeze ()
38
- labels = probabilities >= classification_threshold
36
+ probabilities = model .predict (data ).astype (np .float32 )
37
+ if probabilities .shape [1 ] == 1 : # Binary classification
38
+ probabilities = probabilities .squeeze ()
39
+ labels = (probabilities >= classification_threshold ).astype (np .float32 )
40
+ else : # Multiclass classification
41
+ labels = probabilities .argmax (axis = - 1 ).astype (np .float32 )
39
42
if include_probabilities :
40
- return labels , probabilities . astype ( np . float32 )
43
+ return labels , probabilities
41
44
else :
42
45
return labels
43
46
elif isinstance (model , BaseEstimator ):
44
47
if not is_classifier (model ):
45
48
raise InvalidModelTypeException (f"Expected a classifier model: { type (model )} ." )
46
- probabilities = model .predict_proba (data )[:, 1 ]
47
- labels = (probabilities >= classification_threshold ).astype (np .float32 )
49
+ probabilities = model .predict_proba (data ).astype (np .float32 )
50
+ if probabilities .shape [1 ] == 2 : # Binary classification
51
+ probabilities = probabilities [:, 1 ]
52
+ labels = (probabilities >= classification_threshold ).astype (np .float32 )
53
+ else : # Multiclass classification
54
+ labels = probabilities .argmax (axis = - 1 ).astype (np .float32 )
48
55
if include_probabilities :
49
- return labels , probabilities . astype ( np . float32 )
56
+ return labels , probabilities
50
57
else :
51
58
return labels
52
59
else :
0 commit comments