
Commit 6da73e4 (parent: b913703)

Fix one print issue, enable multiclass predicting for classifiers

File tree

3 files changed: +22 -13 lines

eis_toolkit/cli.py (+4 -2)

@@ -2359,8 +2359,10 @@ def classifier_test_cli(
     typer.echo(f"Results: {json_str}")
 
     typer.echo(
-        f"Testing classifier model completed, writing rasters to \
-        {output_raster_probability} and {output_raster_classified}."
+        (
+            "Testing classifier model completed, writing rasters to "
+            f"{output_raster_probability} and {output_raster_classified}."
+        )
     )
 
 

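The cli.py change fixes a subtle Python string issue: a backslash line continuation inside an f-string keeps the next line's indentation as part of the literal, so the echoed message contained a long run of spaces before the output paths. Splitting the message into adjacent literals that Python concatenates implicitly avoids this. A minimal standalone sketch (the paths below are placeholders, not the CLI's actual arguments):

# Backslash continuation inside the literal: the indentation of the second
# line becomes part of the string.
prob_path = "probabilities.tif"  # placeholder value
classified_path = "classified.tif"  # placeholder value

message_old = f"writing rasters to \
            {prob_path} and {classified_path}."
print(message_old)  # "writing rasters to             probabilities.tif and classified.tif."

# Implicit concatenation of adjacent string literals inside parentheses keeps
# only the text actually written in the literals.
message_new = (
    "writing rasters to "
    f"{prob_path} and {classified_path}."
)
print(message_new)  # "writing rasters to probabilities.tif and classified.tif."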
eis_toolkit/prediction/machine_learning_predict.py (+16 -9)

@@ -18,13 +18,12 @@ def predict_classifier(
     """
     Predict with a trained classifier model.
 
-    Only works for binary classification currently.
-
     Args:
         data: Data used to make predictions.
         model: Trained classifier or regressor. Can be any machine learning model trained with
             EIS Toolkit (Sklearn and Keras models).
-        classification_threshold: Threshold for classifying based on probabilities. Defaults to 0.5.
+        classification_threshold: Threshold for classifying based on probabilities. Only used for
+            binary classification. Defaults to 0.5.
         include_probabilities: If the probability array should be returned too. Defaults to True.
 
     Returns:
@@ -34,19 +33,27 @@ def predict_classifier(
         InvalidModelTypeException: Input model is not a classifier model.
     """
     if isinstance(model, keras.Model):
-        probabilities = model.predict(data).squeeze()
-        labels = probabilities >= classification_threshold
+        probabilities = model.predict(data).astype(np.float32)
+        if probabilities.shape[1] == 1:  # Binary classification
+            probabilities = probabilities.squeeze()
+            labels = (probabilities >= classification_threshold).astype(np.float32)
+        else:  # Multiclass classification
+            labels = probabilities.argmax(axis=-1).astype(np.float32)
         if include_probabilities:
-            return labels, probabilities.astype(np.float32)
+            return labels, probabilities
         else:
             return labels
     elif isinstance(model, BaseEstimator):
         if not is_classifier(model):
             raise InvalidModelTypeException(f"Expected a classifier model: {type(model)}.")
-        probabilities = model.predict_proba(data)[:, 1]
-        labels = (probabilities >= classification_threshold).astype(np.float32)
+        probabilities = model.predict_proba(data).astype(np.float32)
+        if probabilities.shape[1] == 2:  # Binary classification
+            probabilities = probabilities[:, 1]
+            labels = (probabilities >= classification_threshold).astype(np.float32)
+        else:  # Multiclass classification
+            labels = probabilities.argmax(axis=-1).astype(np.float32)
         if include_probabilities:
-            return labels, probabilities.astype(np.float32)
+            return labels, probabilities
         else:
             return labels
     else:

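With this change the sklearn branch inspects the column count of predict_proba: two columns keep the old thresholding on the positive-class probability, while more columns switch to an argmax over classes, with labels returned as float32. A hedged sketch of the multiclass path (synthetic data and a scratch RandomForestClassifier, not the toolkit's test fixtures; the import path follows the file shown above):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

from eis_toolkit.prediction.machine_learning_predict import predict_classifier

# Three synthetic classes, so predict_proba has shape (n_samples, 3).
X, y = make_classification(
    n_samples=200, n_features=5, n_informative=4, n_redundant=0, n_classes=3, random_state=42
)
model = RandomForestClassifier(random_state=42).fit(X, y)

# The multiclass branch ignores classification_threshold and takes the argmax over classes.
labels, probabilities = predict_classifier(X, model, include_probabilities=True)
assert probabilities.shape == (200, 3)
assert set(np.unique(labels)) <= {0.0, 1.0, 2.0}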
tests/prediction/machine_learning_general_test.py (+2 -2)

@@ -118,7 +118,7 @@ def test_evaluate_model_sklearn():
         X_train, y_train, model=RF_MODEL, validation_method="none", metrics=CLF_METRICS, random_state=42
     )
 
-    predictions = predict_classifier(X_test, model, include_probabilities=False)
+    predictions = predict_classifier(X_test, model, classification_threshold=0.5, include_probabilities=False)
     accuracy = score_predictions(y_test, predictions, "accuracy")
     np.testing.assert_equal(accuracy, 1.0)
 
@@ -131,7 +131,7 @@ def test_predict_classifier_sklearn():
         X_train, y_train, model=RF_MODEL, validation_method="none", metrics=CLF_METRICS, random_state=42
     )
 
-    predicted_labels, predicted_probabilities = predict_classifier(X_test, model, True)
+    predicted_labels, predicted_probabilities = predict_classifier(X_test, model, include_probabilities=True)
     np.testing.assert_equal(len(predicted_labels), len(y_test))
     np.testing.assert_equal(len(predicted_probabilities), len(y_test))

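The updated tests still exercise the binary sklearn path; the keras.Model branch works analogously, keyed on the width of the network's output. A rough sketch under that assumption (untrained toy models, purely illustrative, not part of the toolkit's test suite):

import numpy as np
from tensorflow import keras

from eis_toolkit.prediction.machine_learning_predict import predict_classifier

X = np.random.default_rng(42).normal(size=(10, 5)).astype(np.float32)

# A 3-unit softmax head yields probabilities of shape (10, 3), so labels come from argmax.
multiclass_model = keras.Sequential([keras.Input(shape=(5,)), keras.layers.Dense(3, activation="softmax")])
multi_labels, multi_probs = predict_classifier(X, multiclass_model, include_probabilities=True)

# A single sigmoid unit keeps the old behaviour: squeeze to shape (10,) and threshold at 0.5.
binary_model = keras.Sequential([keras.Input(shape=(5,)), keras.layers.Dense(1, activation="sigmoid")])
binary_labels, binary_probs = predict_classifier(X, binary_model, classification_threshold=0.5)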