from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from beartype.typing import Optional
from sklearn.calibration import CalibrationDisplay
from sklearn.metrics import (
    DetCurveDisplay,
    PrecisionRecallDisplay,
    RocCurveDisplay,
    average_precision_score,
    brier_score_loss,
    log_loss,
    roc_auc_score,
)


def summarize_probability_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> Dict[str, float]:
    """
    Generate a report of evaluation metrics for predicted classification probabilities.

    The output includes ROC AUC, log loss, average precision, and Brier score loss.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.

    Returns:
        A dictionary containing the evaluated metrics.
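
    Example (a minimal sketch; the synthetic arrays below are illustrative only):

        rng = np.random.default_rng(0)
        y_true = rng.integers(0, 2, size=500)
        y_prob = np.clip(0.5 * y_true + rng.uniform(0.0, 0.5, size=500), 0.0, 1.0)
        summarize_probability_metrics(y_true, y_prob)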
    """
    metrics = {}

    metrics["roc_auc"] = roc_auc_score(y_true, y_prob)
    metrics["log_loss"] = log_loss(y_true, y_prob)
    metrics["average_precision"] = average_precision_score(y_true, y_prob)
    metrics["brier_score_loss"] = brier_score_loss(y_true, y_prob)

    return metrics


def plot_roc_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    plot_title: Optional[str] = "ROC curve",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot a ROC (receiver operating characteristic) curve.

    The ROC curve is a multi-threshold metric for binary classification. The ideal performance
    corner of the plot is the top-left. The area under the ROC curve (AUC) summarizes model
    performance across all classification thresholds.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        plot_title: Title for the plot. Defaults to "ROC curve".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

    Returns:
        Matplotlib axes containing the plot.
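
    Example (a minimal sketch; the synthetic arrays below are illustrative only):

        rng = np.random.default_rng(0)
        y_true = rng.integers(0, 2, size=500)
        y_prob = np.clip(0.5 * y_true + rng.uniform(0.0, 0.5, size=500), 0.0, 1.0)
        plot_roc_curve(y_true, y_prob)
        plt.show()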
    """
    display = RocCurveDisplay.from_predictions(y_true, y_prob, plot_chance_level=True, ax=ax, **kwargs)
    out_ax = display.ax_
    out_ax.set(xlabel="False positive rate", ylabel="True positive rate", title=plot_title)
    return out_ax


def plot_det_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    plot_title: Optional[str] = "DET curve",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot a DET (detection error tradeoff) curve.

    The DET curve is a multi-threshold metric for binary classification. DET curves are a
    variation of ROC curves in which the false negative rate is plotted on the y-axis instead
    of the true positive rate. The ideal performance corner of the plot is the bottom-left.
    When comparing the performance of different models, DET curves can be slightly easier to
    assess visually than ROC curves.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        plot_title: Title for the plot. Defaults to "DET curve".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

    Returns:
        Matplotlib axes containing the plot.
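
    Example (a minimal sketch; the synthetic arrays and the pre-created Axes below are
    illustrative only):

        rng = np.random.default_rng(0)
        y_true = rng.integers(0, 2, size=500)
        y_prob = np.clip(0.5 * y_true + rng.uniform(0.0, 0.5, size=500), 0.0, 1.0)
        _, existing_ax = plt.subplots()
        plot_det_curve(y_true, y_prob, ax=existing_ax)
        plt.show()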
| 98 | + """ |
| 99 | + display = DetCurveDisplay.from_predictions(y_true, y_prob, ax=ax, **kwargs) |
| 100 | + out_ax = display.ax_ |
| 101 | + out_ax.set(xlabel="False positive rate", ylabel="False negative rate", title=plot_title) |
| 102 | + return out_ax |
| 103 | + |
| 104 | + |
def plot_precision_recall_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    plot_title: Optional[str] = "Precision-Recall curve",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot a precision-recall curve.

    The precision-recall curve is a multi-threshold metric for binary classification. It shows
    the tradeoff between precision and recall across classification thresholds and can be a
    useful measure of success when the classes are imbalanced.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        plot_title: Title for the plot. Defaults to "Precision-Recall curve".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

    Returns:
        Matplotlib axes containing the plot.
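
    Example (a minimal sketch; the synthetic arrays below are illustrative only, and
    `linestyle` is forwarded to matplotlib.pyplot.plot via **kwargs):

        rng = np.random.default_rng(0)
        y_true = rng.integers(0, 2, size=500)
        y_prob = np.clip(0.5 * y_true + rng.uniform(0.0, 0.5, size=500), 0.0, 1.0)
        plot_precision_recall_curve(y_true, y_prob, linestyle="--")
        plt.show()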
| 129 | + """ |
| 130 | + display = PrecisionRecallDisplay.from_predictions(y_true, y_prob, plot_chance_level=True, ax=ax, **kwargs) |
| 131 | + out_ax = display.ax_ |
| 132 | + out_ax.set(xlabel="Recall", ylabel="Precision", title=plot_title) |
| 133 | + return out_ax |
| 134 | + |
| 135 | + |
def plot_calibration_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    n_bins: int = 5,
    plot_title: Optional[str] = "Calibration curve",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot a calibration curve (also known as a reliability diagram).

    The calibration curve shows the observed frequency of positive labels on the y-axis against
    the predicted probability on the x-axis. Generally, the closer the calibration curve is to
    the line y = x, the better calibrated the model is.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        n_bins: Number of bins used to discretize the predicted probabilities. Defaults to 5.
        plot_title: Title for the plot. Defaults to "Calibration curve".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

    Returns:
        Matplotlib axes containing the plot.
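
    Example (a minimal sketch; the synthetic arrays below are illustrative only):

        rng = np.random.default_rng(0)
        y_true = rng.integers(0, 2, size=500)
        y_prob = np.clip(0.5 * y_true + rng.uniform(0.0, 0.5, size=500), 0.0, 1.0)
        plot_calibration_curve(y_true, y_prob, n_bins=10)
        plt.show()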
| 160 | + """ |
| 161 | + display = CalibrationDisplay.from_predictions(y_true, y_prob, n_bins=n_bins, ax=ax, **kwargs) |
| 162 | + out_ax = display.ax_ |
| 163 | + out_ax.set(xlabel="Mean predicted probability", ylabel="Fraction of positives", title=plot_title) |
| 164 | + return out_ax |
| 165 | + |
| 166 | + |
def plot_predicted_probability_distribution(
    y_prob: np.ndarray,
    n_bins: int = 5,
    plot_title: Optional[str] = "Distribution of predicted probabilities",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot a histogram of the predicted probabilities.

    Args:
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        n_bins: Number of bins used for the histogram. Defaults to 5.
        plot_title: Title for the plot. Defaults to "Distribution of predicted probabilities".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to seaborn.histplot (and through it
            to matplotlib).

    Returns:
        Matplotlib axes containing the plot.
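
    Example (a minimal sketch; the synthetic array below is illustrative only):

        rng = np.random.default_rng(0)
        y_prob = rng.uniform(0.0, 1.0, size=500)
        plot_predicted_probability_distribution(y_prob, n_bins=20)
        plt.show()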
    """
    sns.set_theme(style="white")
    if ax is None:
        # Only open a fresh figure when no Axes was supplied; otherwise the
        # histogram draws into the caller's Axes and the new figure would be
        # left empty.
        plt.figure()
    out_ax = sns.histplot(y_prob, bins=n_bins, ax=ax, **kwargs)
    out_ax.set(xlabel="Predicted probability", ylabel="Count", title=plot_title)
    return out_ax