Commit d3c6808

Merge pull request #368 from GispoCoding/347-refactor-validation-category

347 refactor validation category, now called evaluation

2 parents 88abd03 + c673683

40 files changed: +1082 -654 lines
docs/evaluation/calculate_base_metrics.md (new file, +3)

```markdown
# Calculate base metrics

::: eis_toolkit.evaluation.calculate_base_metrics
```
docs/evaluation/classification_label_evaluation.md (new file, +3)

```markdown
# Classification label evaluation

::: eis_toolkit.evaluation.classification_label_evaluation
```
docs/evaluation/classification_probability_evaluation.md (new file, +3)

```markdown
# Classification probability evaluation

::: eis_toolkit.evaluation.classification_probability_evaluation
```
docs/evaluation/plot_confusion_matrix.md (new file, +3)

```markdown
# Plot confusion matrix

::: eis_toolkit.evaluation.plot_confusion_matrix
```
docs/evaluation/plot_nn_model_performance.md (+1 -1)

```diff
@@ -1,3 +1,3 @@
 # Plot neural network training performance (accuracy and loss)
 
-::: eis_toolkit.validation.plot_nn_model_performance
+::: eis_toolkit.evaluation.plot_nn_model_performance
```
docs/evaluation/plot_prediction_area_curves.md (new file, +3)

```markdown
# Plot prediction-area (P-A) curves

::: eis_toolkit.evaluation.plot_prediction_area_curves
```

docs/evaluation/plot_rate_curve.md (new file, +3)

```markdown
# Plot rate curve

::: eis_toolkit.evaluation.plot_rate_curve
```

Deleted files (-3 lines each):

- docs/validation/calculate_auc.md
- docs/validation/calculate_base_metrics.md
- docs/validation/get_pa_intersection.md
- docs/validation/plot_confusion_matrix.md
- docs/validation/plot_prediction_area_curves.md
- docs/validation/plot_rate_curve.md

eis_toolkit/cli.py (+1 -1)

```diff
@@ -2963,7 +2963,7 @@ def winsorize_transform_cli(
     typer.echo(f"Winsorize transform completed, writing raster to {output_raster}.")
 
 
-# ---VALIDATION ---
+# ---EVALUATION ---
 # TODO
 
```
One file renamed without changes.
eis_toolkit/evaluation/classification_label_evaluation.py (new file, +37)

```python
from numbers import Number
from typing import Dict

import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support


def summarize_label_metrics_binary(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, Number]:
    """
    Generate a comprehensive report of various evaluation metrics for binary classification results.

    The output includes accuracy, precision, recall, F1 scores and confusion matrix elements
    (true negatives, false positives, false negatives, true positives).

    Args:
        y_true: True labels.
        y_pred: Predicted labels. The array should come from a binary classifier.

    Returns:
        A dictionary containing the evaluated metrics.
    """
    metrics = {}

    metrics["Accuracy"] = accuracy_score(y_true, y_pred)

    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
    metrics["Precision"] = precision
    metrics["Recall"] = recall
    metrics["F1_score"] = f1

    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    metrics["True_negatives"] = tn
    metrics["False_positives"] = fp
    metrics["False_negatives"] = fn
    metrics["True_positives"] = tp

    return metrics
```
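A minimal usage sketch of the new summary function (the label arrays below are made up for illustration):

```python
import numpy as np

from eis_toolkit.evaluation.classification_label_evaluation import summarize_label_metrics_binary

# Hypothetical ground-truth and predicted binary labels.
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_pred = np.array([1, 0, 0, 1, 0, 1, 1, 0])

metrics = summarize_label_metrics_binary(y_true, y_pred)
print(metrics)  # Accuracy 0.75, Precision/Recall/F1 0.75, TP = 3, TN = 3, FP = 1, FN = 1
```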
eis_toolkit/evaluation/classification_probability_evaluation.py (new file, +192)

```python
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from beartype.typing import Optional
from sklearn.calibration import CalibrationDisplay
from sklearn.metrics import (
    DetCurveDisplay,
    PrecisionRecallDisplay,
    RocCurveDisplay,
    average_precision_score,
    brier_score_loss,
    log_loss,
    roc_auc_score,
)


def summarize_probability_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> Dict[str, float]:
    """
    Generate a comprehensive report of various evaluation metrics for classification probabilities.

    The output includes ROC AUC, log loss, average precision and Brier score loss.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.

    Returns:
        A dictionary containing the evaluated metrics.
    """
    metrics = {}

    metrics["roc_auc"] = roc_auc_score(y_true, y_prob)
    metrics["log_loss"] = log_loss(y_true, y_prob)
    metrics["average_precision"] = average_precision_score(y_true, y_prob)
    metrics["brier_score_loss"] = brier_score_loss(y_true, y_prob)

    return metrics
```
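As a quick sketch, the probability metrics could be summarized like so (the arrays are invented):

```python
import numpy as np

from eis_toolkit.evaluation.classification_probability_evaluation import summarize_probability_metrics

# Hypothetical labels and predicted positive-class probabilities.
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_prob = np.array([0.9, 0.2, 0.6, 0.8, 0.3, 0.55, 0.7, 0.1])

for name, value in summarize_probability_metrics(y_true, y_prob).items():
    print(f"{name}: {value:.3f}")
```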
plot_roc_curve, from the same file:

```python
def plot_roc_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    plot_title: Optional[str] = "ROC curve",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot ROC (receiver operating characteristic) curve.

    ROC curve is a binary classification multi-threshold metric. The ideal performance corner of the plot
    is top-left. AUC of the ROC curve summarizes model performance across different classification thresholds.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        plot_title: Title for the plot. Defaults to "ROC curve".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

    Returns:
        Matplotlib axes containing the plot.
    """
    display = RocCurveDisplay.from_predictions(y_true, y_prob, plot_chance_level=True, ax=ax, **kwargs)
    out_ax = display.ax_
    out_ax.set(xlabel="False positive rate", ylabel="True positive rate", title=plot_title)
    return out_ax
```
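A usage sketch with synthetic scores (random data for illustration only):

```python
import matplotlib.pyplot as plt
import numpy as np

from eis_toolkit.evaluation.classification_probability_evaluation import plot_roc_curve

rng = np.random.default_rng(42)
y_true = rng.integers(0, 2, size=200)
# Scores that loosely track the labels, so the curve sits above chance level.
y_prob = 0.25 * y_true + rng.uniform(0.0, 0.75, size=200)

plot_roc_curve(y_true, y_prob)
plt.show()
```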
plot_det_curve:

```python
def plot_det_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    plot_title: Optional[str] = "DET curve",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot DET (detection error tradeoff) curve.

    DET curve is a binary classification multi-threshold metric. DET curves are a variation of ROC curves where
    False Negative Rate is plotted on the y-axis instead of True Positive Rate. The ideal performance corner of
    the plot is bottom-left. When comparing the performance of different models, DET curves can be
    slightly easier to assess visually than ROC curves.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        plot_title: Title for the plot. Defaults to "DET curve".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

    Returns:
        Matplotlib axes containing the plot.
    """
    display = DetCurveDisplay.from_predictions(y_true, y_prob, ax=ax, **kwargs)
    out_ax = display.ax_
    out_ax.set(xlabel="False positive rate", ylabel="False negative rate", title=plot_title)
    return out_ax
```
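Since both helpers accept an existing Axes, ROC and DET curves for the same predictions can be drawn side by side, e.g.:

```python
import matplotlib.pyplot as plt
import numpy as np

from eis_toolkit.evaluation.classification_probability_evaluation import plot_det_curve, plot_roc_curve

rng = np.random.default_rng(42)
y_true = rng.integers(0, 2, size=200)
y_prob = 0.25 * y_true + rng.uniform(0.0, 0.75, size=200)  # synthetic scores

# Compare the two views of the same classifier on pre-created axes.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
plot_roc_curve(y_true, y_prob, ax=ax1)
plot_det_curve(y_true, y_prob, ax=ax2)
plt.show()
```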
plot_precision_recall_curve:

```python
def plot_precision_recall_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    plot_title: Optional[str] = "Precision-Recall curve",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot precision-recall curve.

    Precision-recall curve is a binary classification multi-threshold metric. Precision-recall curve shows
    the tradeoff between precision and recall for different classification thresholds.
    It can be a useful measure of success when classes are imbalanced.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        plot_title: Title for the plot. Defaults to "Precision-Recall curve".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

    Returns:
        Matplotlib axes containing the plot.
    """
    display = PrecisionRecallDisplay.from_predictions(y_true, y_prob, plot_chance_level=True, ax=ax, **kwargs)
    out_ax = display.ax_
    out_ax.set(xlabel="Recall", ylabel="Precision", title=plot_title)
    return out_ax
```
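A sketch on imbalanced synthetic data, where the precision-recall view is most informative (roughly 10 % positives):

```python
import matplotlib.pyplot as plt
import numpy as np

from eis_toolkit.evaluation.classification_probability_evaluation import plot_precision_recall_curve

rng = np.random.default_rng(0)
y_true = (rng.uniform(size=500) < 0.1).astype(int)  # ~10 % positive class
y_prob = np.clip(0.4 * y_true + rng.uniform(0.0, 0.6, size=500), 0.0, 1.0)

plot_precision_recall_curve(y_true, y_prob)
plt.show()
```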
plot_calibration_curve:

```python
def plot_calibration_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    n_bins: int = 5,
    plot_title: Optional[str] = "Calibration curve",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot calibration curve (aka reliability diagram).

    Calibration curve has the frequency of the positive labels on the y-axis and the predicted probability on
    the x-axis. Generally, the closer the calibration curve is to the line x=y, the better the model is calibrated.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        n_bins: Number of bins used to discretize the predicted probabilities. Defaults to 5.
        plot_title: Title for the plot. Defaults to "Calibration curve".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

    Returns:
        Matplotlib axes containing the plot.
    """
    display = CalibrationDisplay.from_predictions(y_true, y_prob, n_bins=n_bins, ax=ax, **kwargs)
    out_ax = display.ax_
    out_ax.set(xlabel="Mean predicted probability", ylabel="Fraction of positives", title=plot_title)
    return out_ax
```
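For a sanity check of the plot, synthetic predictions that are perfectly calibrated by construction should hug the diagonal:

```python
import matplotlib.pyplot as plt
import numpy as np

from eis_toolkit.evaluation.classification_probability_evaluation import plot_calibration_curve

rng = np.random.default_rng(0)
y_prob = rng.uniform(size=1000)
# Labels drawn with P(y=1) equal to the predicted probability -> well calibrated.
y_true = (rng.uniform(size=1000) < y_prob).astype(int)

plot_calibration_curve(y_true, y_prob, n_bins=10)
plt.show()
```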
plot_predicted_probability_distribution:

```python
def plot_predicted_probability_distribution(
    y_prob: np.ndarray,
    n_bins: int = 5,
    plot_title: Optional[str] = "Distribution of predicted probabilities",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot a histogram of the predicted probabilities.

    Args:
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        n_bins: Number of bins used for the histogram. Defaults to 5.
        plot_title: Title for the plot. Defaults to "Distribution of predicted probabilities".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to sns.histplot and matplotlib.

    Returns:
        Matplotlib axes containing the plot.
    """
    sns.set_theme(style="white")
    plt.figure()
    out_ax = sns.histplot(y_prob, bins=n_bins, ax=ax, **kwargs)
    out_ax.set(xlabel="Predicted probability", ylabel="Count", title=plot_title)
    return out_ax
```
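A sketch with a bimodal score distribution, as a fairly confident classifier might produce (synthetic data):

```python
import matplotlib.pyplot as plt
import numpy as np

from eis_toolkit.evaluation.classification_probability_evaluation import plot_predicted_probability_distribution

rng = np.random.default_rng(0)
# Mostly low probabilities plus a smaller cluster of high ones.
y_prob = np.concatenate([rng.beta(2, 8, size=300), rng.beta(8, 2, size=100)])

plot_predicted_probability_distribution(y_prob, n_bins=20)
plt.show()
```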

eis_toolkit/validation/plot_confusion_matrix.py → eis_toolkit/evaluation/plot_confusion_matrix.py (+11 -4)

```diff
@@ -10,7 +10,11 @@
 
 @beartype
 def plot_confusion_matrix(
-    confusion_matrix: np.ndarray, cmap: Optional[Union[str, Colormap, Sequence]] = None
+    confusion_matrix: np.ndarray,
+    cmap: Optional[Union[str, Colormap, Sequence]] = None,
+    plot_title: str = "Confusion matrix",
+    ax: Optional[plt.Axes] = None,
+    **kwargs,
 ) -> plt.Axes:
     """Plot confusion matrix to visualize classification results.
 
@@ -19,6 +23,9 @@ def plot_confusion_matrix(
             (upper-left corner) to have True negatives.
         cmap: Colormap name, matplotlib colormap objects or list of colors for coloring the plot.
             Optional parameter.
+        plot_title: Title for the plot. Defaults to "Confusion matrix".
+        ax: An existing Axes in which to draw the plot. Defaults to None.
+        **kwargs: Additional keyword arguments passed to sns.heatmap.
 
     Returns:
         Matplotlib axes containing the plot.
@@ -40,7 +47,7 @@ def plot_confusion_matrix(
     else:
         labels = np.asarray([f"{v1}\n{v2}" for v1, v2 in zip(counts, percentages)]).reshape(shape)
 
-    ax = sns.heatmap(confusion_matrix, annot=labels, fmt="", cmap=cmap)
-    ax.set(xlabel="Predicted label", ylabel="True label")
+    out_ax = sns.heatmap(confusion_matrix, annot=labels, fmt="", cmap=cmap, ax=ax, **kwargs)
+    out_ax.set(xlabel="Predicted label", ylabel="True label", title=plot_title)
 
-    return ax
+    return out_ax
```
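The new plot_title and ax parameters can be exercised like this (labels are made up; the matrix comes from sklearn, whose layout matches the true-negatives-upper-left convention the docstring expects):

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

from eis_toolkit.evaluation.plot_confusion_matrix import plot_confusion_matrix

# Hypothetical ground-truth and predicted binary labels.
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_pred = np.array([1, 0, 0, 1, 0, 1, 1, 0])

# Draw into a pre-created Axes with a custom title.
fig, ax = plt.subplots(figsize=(4, 4))
plot_confusion_matrix(confusion_matrix(y_true, y_pred), plot_title="Test-set confusion matrix", ax=ax)
plt.show()
```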
