@@ -3104,6 +3104,159 @@ def plot_roc_curve_cli(
     typer.echo("ROC curve plot completed" + echo_str_end)
 
 
+@app.command()
+def plot_det_curve_cli(
+    true_labels: INPUT_FILE_OPTION,
+    probabilities: INPUT_FILE_OPTION,
+    output_file: OUTPUT_FILE_OPTION,
+    show_plot: bool = False,
+    save_dpi: Optional[int] = None,
+):
+    """
+    Plot DET (detection error tradeoff) curve.
+
+    DET curve is a binary classification multi-threshold metric. DET curves are a variation of ROC curves where
+    False Negative Rate is plotted on the y-axis instead of True Positive Rate. The ideal performance corner of
+    the plot is bottom-left. When comparing the performance of different models, DET curves can be
+    slightly easier to assess visually than ROC curves.
+    """
+    import matplotlib.pyplot as plt
+
+    from eis_toolkit.evaluation.classification_probability_evaluation import plot_det_curve
+    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation
+
+    typer.echo("Progress: 10%")
+
+    (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
+    typer.echo("Progress: 25%")
+
+    _ = plot_det_curve(y_true=y_true, y_prob=y_prob)
+    typer.echo("Progress: 75%")
+    if show_plot:
+        plt.show()
+
+    if output_file is not None:
+        dpi = "figure" if save_dpi is None else save_dpi
+        plt.savefig(output_file, dpi=dpi)
+        echo_str_end = f", output figure saved to {output_file}."
+    typer.echo("Progress: 100% \n")
+
+    typer.echo("DET curve plot completed" + echo_str_end)
+
+
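Aside, not part of the diff: for readers unfamiliar with DET curves, the underlying false positive and false negative rates can be computed with scikit-learn, which the toolkit already uses elsewhere in this file. A minimal sketch for illustration only; it is not necessarily how plot_det_curve is implemented.

# Illustrative only -- not the toolkit's implementation of plot_det_curve.
import numpy as np
from sklearn.metrics import det_curve

y_true = np.array([0, 0, 1, 1])           # ground-truth binary labels
y_prob = np.array([0.1, 0.4, 0.35, 0.8])  # predicted probabilities for class 1

# det_curve returns the false positive rate and false negative rate at each
# threshold; a DET plot draws FNR (y-axis) against FPR (x-axis), so the ideal
# performance corner is bottom-left.
fpr, fnr, thresholds = det_curve(y_true, y_prob)
print(fpr, fnr, thresholds)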
+@app.command()
+def plot_precision_recall_curve_cli(
+    true_labels: INPUT_FILE_OPTION,
+    probabilities: INPUT_FILE_OPTION,
+    output_file: OUTPUT_FILE_OPTION,
+    show_plot: bool = False,
+    save_dpi: Optional[int] = None,
+):
+    """
+    Plot precision-recall curve.
+
+    Precision-recall curve is a binary classification multi-threshold metric. Precision-recall curve shows
+    the tradeoff between precision and recall for different classification thresholds.
+    It can be a useful measure of success when classes are imbalanced.
+    """
+    import matplotlib.pyplot as plt
+
+    from eis_toolkit.evaluation.classification_probability_evaluation import plot_precision_recall_curve
+    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation
+
+    typer.echo("Progress: 10%")
+
+    (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
+    typer.echo("Progress: 25%")
+
+    _ = plot_precision_recall_curve(y_true=y_true, y_prob=y_prob)
+    typer.echo("Progress: 75%")
+    if show_plot:
+        plt.show()
+
+    if output_file is not None:
+        dpi = "figure" if save_dpi is None else save_dpi
+        plt.savefig(output_file, dpi=dpi)
+        echo_str_end = f", output figure saved to {output_file}."
+    typer.echo("Progress: 100% \n")
+
+    typer.echo("Precision-Recall curve plot completed" + echo_str_end)
+
+
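Aside, not part of the diff: the precision-recall pairs behind such a plot can be obtained directly from scikit-learn. A minimal sketch for illustration only; plot_precision_recall_curve in eis_toolkit may differ internally.

# Illustrative only -- not the toolkit's implementation.
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve

y_true = np.array([0, 0, 1, 1])
y_prob = np.array([0.1, 0.4, 0.35, 0.8])

# One (precision, recall) pair per threshold; informative when classes are imbalanced.
precision, recall, thresholds = precision_recall_curve(y_true, y_prob)

# Average precision summarizes the whole curve as a single score.
print(average_precision_score(y_true, y_prob))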
+@app.command()
+def plot_calibration_curve_cli(
+    true_labels: INPUT_FILE_OPTION,
+    probabilities: INPUT_FILE_OPTION,
+    output_file: OUTPUT_FILE_OPTION,
+    n_bins: int = 5,
+    show_plot: bool = False,
+    save_dpi: Optional[int] = None,
+):
+    """
+    Plot calibration curve (aka reliability diagram).
+
+    Calibration curve has the frequency of the positive labels on the y-axis and the predicted probability on
+    the x-axis. Generally, the closer the calibration curve is to the line x=y, the better the model is calibrated.
+    """
+    import matplotlib.pyplot as plt
+
+    from eis_toolkit.evaluation.classification_probability_evaluation import plot_calibration_curve
+    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation
+
+    typer.echo("Progress: 10%")
+
+    (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
+    typer.echo("Progress: 25%")
+
+    _ = plot_calibration_curve(y_true=y_true, y_prob=y_prob, n_bins=n_bins)
+    typer.echo("Progress: 75%")
+    if show_plot:
+        plt.show()
+
+    if output_file is not None:
+        dpi = "figure" if save_dpi is None else save_dpi
+        plt.savefig(output_file, dpi=dpi)
+        echo_str_end = f", output figure saved to {output_file}."
+    typer.echo("Progress: 100% \n")
+
+    typer.echo("Calibration curve plot completed" + echo_str_end)
+
+
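Aside, not part of the diff: the points of a reliability diagram come from binning the predicted probabilities and comparing each bin's mean prediction with the observed frequency of positives. A minimal sketch using scikit-learn's calibration_curve, for illustration only.

# Illustrative only -- not the toolkit's implementation of plot_calibration_curve.
import numpy as np
from sklearn.calibration import calibration_curve

y_true = np.array([0, 1, 1, 0, 1, 1, 0, 1])
y_prob = np.array([0.2, 0.9, 0.7, 0.3, 0.8, 0.6, 0.4, 0.95])

# prob_true: observed frequency of positives per bin (y-axis)
# prob_pred: mean predicted probability per bin (x-axis)
# A well-calibrated model keeps these points close to the x = y line.
prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=5)
print(prob_true, prob_pred)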
+@app.command()
+def plot_confusion_matrix_cli(
+    true_labels: INPUT_FILE_OPTION,
+    predictions: INPUT_FILE_OPTION,
+    output_file: OUTPUT_FILE_OPTION,
+    show_plot: bool = False,
+    save_dpi: Optional[int] = None,
+):
+    """Plot confusion matrix to visualize classification results."""
+    import matplotlib.pyplot as plt
+    from sklearn.metrics import confusion_matrix
+
+    from eis_toolkit.evaluation.plot_confusion_matrix import plot_confusion_matrix
+    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation
+
+    typer.echo("Progress: 10%")
+
+    (y_pred, y_true), _, _ = read_data_for_evaluation([predictions, true_labels])
+    typer.echo("Progress: 25%")
+
+    matrix = confusion_matrix(y_true, y_pred)
+    _ = plot_confusion_matrix(confusion_matrix=matrix)
+    typer.echo("Progress: 75%")
+    if show_plot:
+        plt.show()
+
+    if output_file is not None:
+        dpi = "figure" if save_dpi is None else save_dpi
+        plt.savefig(output_file, dpi=dpi)
+        echo_str_end = f", output figure saved to {output_file}."
+    typer.echo("Progress: 100% \n")
+
+    typer.echo("Confusion matrix plot completed" + echo_str_end)
+
+
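Aside, not part of the diff: for binary labels, the matrix returned by sklearn.metrics.confusion_matrix (the same call used above) unpacks into the four basic counts, which helps when reading the plotted matrix. Illustrative sketch only.

# Illustrative only -- shows how the 2x2 matrix is laid out.
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = np.array([0, 1, 0, 1, 1, 0])
y_pred = np.array([0, 1, 1, 1, 0, 0])

# Rows are true classes, columns are predicted classes.
matrix = confusion_matrix(y_true, y_pred)
tn, fp, fn, tp = matrix.ravel()
print(matrix)
print(tn, fp, fn, tp)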
 @app.command()
 def score_predictions_cli(
     true_labels: INPUT_FILE_OPTION,