Skip to content

Commit ba0df6b

Browse files
committed
Added CLI functions for DET curve, precision-recall curve, confusion matrix plot and calibration curve
1 parent c996979 commit ba0df6b

File tree

1 file changed

+153
-0
lines changed

1 file changed

+153
-0
lines changed

eis_toolkit/cli.py

+153
Original file line numberDiff line numberDiff line change
@@ -3104,6 +3104,159 @@ def plot_roc_curve_cli(
31043104
typer.echo("ROC curve plot completed" + echo_str_end)
31053105

31063106

3107+
@app.command()
def plot_det_curve_cli(
    true_labels: INPUT_FILE_OPTION,
    probabilities: INPUT_FILE_OPTION,
    output_file: OUTPUT_FILE_OPTION,
    show_plot: bool = False,
    save_dpi: Optional[int] = None,
):
    """
    Plot DET (detection error tradeoff) curve.

    DET curve is a binary classification multi-threshold metric. DET curves are a variation of ROC curves where
    False Negative Rate is plotted on the y-axis instead of True Positive Rate. The ideal performance corner of
    the plot is bottom-left. When comparing the performance of different models, DET curves can be
    slightly easier to assess visually than ROC curves.
    """
    # Parameters:
    #   true_labels / probabilities: input files handed to read_data_for_evaluation.
    #   output_file: destination of the rendered figure (saving is skipped when None).
    #   show_plot: additionally display the figure with plt.show().
    #   save_dpi: resolution of the saved figure; None keeps the figure's own dpi.

    # Function-local imports, as elsewhere in this module (presumably to keep CLI start-up light).
    import matplotlib.pyplot as plt

    from eis_toolkit.evaluation.classification_probability_evaluation import plot_det_curve
    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation

    typer.echo("Progress: 10%")

    # Files are passed as [probabilities, true_labels], so the first returned element unpacks in that order.
    (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
    typer.echo("Progress: 25%")

    plot_det_curve(y_true=y_true, y_prob=y_prob)
    typer.echo("Progress: 75%")

    # Default suffix so the final message is always defined, even when nothing is saved.
    echo_str_end = "."
    if output_file is not None:
        # Save BEFORE showing: with a blocking backend the figure is released once the window is
        # closed, and a later savefig would write an empty canvas.
        dpi = "figure" if save_dpi is None else save_dpi
        plt.savefig(output_file, dpi=dpi)
        echo_str_end = f", output figure saved to {output_file}."

    if show_plot:
        plt.show()

    typer.echo("Progress: 100% \n")

    typer.echo("DET curve plot completed" + echo_str_end)
3145+
3146+
3147+
@app.command()
def plot_precision_recall_curve_cli(
    true_labels: INPUT_FILE_OPTION,
    probabilities: INPUT_FILE_OPTION,
    output_file: OUTPUT_FILE_OPTION,
    show_plot: bool = False,
    save_dpi: Optional[int] = None,
):
    """
    Plot precision-recall curve.

    Precision-recall curve is a binary classification multi-threshold metric. Precision-recall curve shows
    the tradeoff between precision and recall for different classification thresholds.
    It can be a useful measure of success when classes are imbalanced.
    """
    # Parameters:
    #   true_labels / probabilities: input files handed to read_data_for_evaluation.
    #   output_file: destination of the rendered figure (saving is skipped when None).
    #   show_plot: additionally display the figure with plt.show().
    #   save_dpi: resolution of the saved figure; None keeps the figure's own dpi.

    # Function-local imports, as elsewhere in this module (presumably to keep CLI start-up light).
    import matplotlib.pyplot as plt

    from eis_toolkit.evaluation.classification_probability_evaluation import plot_precision_recall_curve
    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation

    typer.echo("Progress: 10%")

    # Files are passed as [probabilities, true_labels], so the first returned element unpacks in that order.
    (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
    typer.echo("Progress: 25%")

    plot_precision_recall_curve(y_true=y_true, y_prob=y_prob)
    typer.echo("Progress: 75%")

    # Default suffix so the final message is always defined, even when nothing is saved.
    echo_str_end = "."
    if output_file is not None:
        # Save BEFORE showing: with a blocking backend the figure is released once the window is
        # closed, and a later savefig would write an empty canvas.
        dpi = "figure" if save_dpi is None else save_dpi
        plt.savefig(output_file, dpi=dpi)
        echo_str_end = f", output figure saved to {output_file}."

    if show_plot:
        plt.show()

    typer.echo("Progress: 100% \n")

    typer.echo("Precision-Recall curve plot completed" + echo_str_end)
3184+
3185+
3186+
@app.command()
def plot_calibration_curve_cli(
    true_labels: INPUT_FILE_OPTION,
    probabilities: INPUT_FILE_OPTION,
    output_file: OUTPUT_FILE_OPTION,
    n_bins: int = 5,
    show_plot: bool = False,
    save_dpi: Optional[int] = None,
):
    """
    Plot calibration curve (aka reliability diagram).

    Calibration curve has the frequency of the positive labels on the y-axis and the predicted probability on
    the x-axis. Generally, the closer the calibration curve is to line x=y, the better the model is calibrated.
    """
    # Parameters:
    #   true_labels / probabilities: input files handed to read_data_for_evaluation.
    #   output_file: destination of the rendered figure (saving is skipped when None).
    #   n_bins: forwarded unchanged to plot_calibration_curve.
    #   show_plot: additionally display the figure with plt.show().
    #   save_dpi: resolution of the saved figure; None keeps the figure's own dpi.

    # Function-local imports, as elsewhere in this module (presumably to keep CLI start-up light).
    import matplotlib.pyplot as plt

    from eis_toolkit.evaluation.classification_probability_evaluation import plot_calibration_curve
    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation

    typer.echo("Progress: 10%")

    # Files are passed as [probabilities, true_labels], so the first returned element unpacks in that order.
    (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
    typer.echo("Progress: 25%")

    plot_calibration_curve(y_true=y_true, y_prob=y_prob, n_bins=n_bins)
    typer.echo("Progress: 75%")

    # Default suffix so the final message is always defined, even when nothing is saved.
    echo_str_end = "."
    if output_file is not None:
        # Save BEFORE showing: with a blocking backend the figure is released once the window is
        # closed, and a later savefig would write an empty canvas.
        dpi = "figure" if save_dpi is None else save_dpi
        plt.savefig(output_file, dpi=dpi)
        echo_str_end = f", output figure saved to {output_file}."

    if show_plot:
        plt.show()

    typer.echo("Progress: 100% \n")

    typer.echo("Calibration curve plot completed" + echo_str_end)
3223+
3224+
3225+
@app.command()
def plot_confusion_matrix_cli(
    true_labels: INPUT_FILE_OPTION,
    predictions: INPUT_FILE_OPTION,
    output_file: OUTPUT_FILE_OPTION,
    show_plot: bool = False,
    save_dpi: Optional[int] = None,
):
    """Plot confusion matrix to visualize classification results."""
    # Parameters:
    #   true_labels / predictions: input files handed to read_data_for_evaluation.
    #   output_file: destination of the rendered figure (saving is skipped when None).
    #   show_plot: additionally display the figure with plt.show().
    #   save_dpi: resolution of the saved figure; None keeps the figure's own dpi.

    # Function-local imports, as elsewhere in this module (presumably to keep CLI start-up light).
    import matplotlib.pyplot as plt
    from sklearn.metrics import confusion_matrix

    from eis_toolkit.evaluation.plot_confusion_matrix import plot_confusion_matrix
    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation

    typer.echo("Progress: 10%")

    # Files are passed as [predictions, true_labels], so the first returned element unpacks in that order.
    (y_pred, y_true), _, _ = read_data_for_evaluation([predictions, true_labels])
    typer.echo("Progress: 25%")

    # The matrix is computed with scikit-learn and then handed to the toolkit's plotting helper.
    matrix = confusion_matrix(y_true, y_pred)
    plot_confusion_matrix(confusion_matrix=matrix)
    typer.echo("Progress: 75%")

    # Default suffix so the final message is always defined, even when nothing is saved.
    echo_str_end = "."
    if output_file is not None:
        # Save BEFORE showing: with a blocking backend the figure is released once the window is
        # closed, and a later savefig would write an empty canvas.
        dpi = "figure" if save_dpi is None else save_dpi
        plt.savefig(output_file, dpi=dpi)
        echo_str_end = f", output figure saved to {output_file}."

    if show_plot:
        plt.show()

    typer.echo("Progress: 100% \n")

    typer.echo("Confusion matrix plot completed" + echo_str_end)
3258+
3259+
31073260
@app.command()
31083261
def score_predictions_cli(
31093262
true_labels: INPUT_FILE_OPTION,

0 commit comments

Comments
 (0)