Skip to content

Commit 98b56ff

Browse files
committed
make changes according to review
1 parent 774b078 commit 98b56ff

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

eis_toolkit/prediction/machine_learning_predict.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ def predict_classifier(
2929
return labels, probabilities
3030
else:
3131
return labels
32-
elif include_probabilities:
33-
probabilities = model.predict_proba(data)
32+
elif isinstance(model, BaseEstimator):
3433
labels = model.predict(data)
35-
return labels, probabilities
36-
else:
37-
labels = model.predict(data)
38-
return labels
34+
if include_probabilities:
35+
probabilities = model.predict_proba(data)
36+
return labels, probabilities
37+
else:
38+
return labels
3939

4040

4141
@beartype

eis_toolkit/validation/classification_probability_evaluation.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def plot_det_curve(
9898
"""
9999
display = DetCurveDisplay.from_predictions(y_true, y_prob, ax=ax, **kwargs)
100100
out_ax = display.ax_
101-
out_ax.set(xlabel="False positive rate", ylabel="True positive rate", title=plot_title)
101+
out_ax.set(xlabel="False positive rate", ylabel="False negative rate", title=plot_title)
102102
return out_ax
103103

104104

@@ -129,7 +129,7 @@ def plot_precision_recall_curve(
129129
"""
130130
display = PrecisionRecallDisplay.from_predictions(y_true, y_prob, plot_chance_level=True, ax=ax, **kwargs)
131131
out_ax = display.ax_
132-
out_ax.set(xlabel="False positive rate", ylabel="True positive rate", title=plot_title)
132+
out_ax.set(xlabel="Recall", ylabel="Precision", title=plot_title)
133133
return out_ax
134134

135135

eis_toolkit/validation/plot_prediction_area_curves.py

+2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def plot_prediction_area_curves(
6565
Plots prediction area plot that can be used to evaluate mineral prospectivity maps and evidential layers. See e.g.,
6666
Yousefi and Carranza (2015).
6767
68+
The inputs needed for this tool can be obtained with calculate_base_metrics() tool.
69+
6870
Args:
6971
true_positive_rate_values: True positive rate values.
7072
proportion_of_area_values: Proportion of area values.

0 commit comments

Comments
 (0)