
Commit cca25dd

Merge pull request #387 from GispoCoding/386-add-evaluation-cli-tools
386 add evaluation CLI tools
2 parents 3c27177 + ba0df6b commit cca25dd

File tree

2 files changed: +404 -15 lines

eis_toolkit/cli.py: +322 -13
@@ -722,10 +722,14 @@ def descriptive_statistics_raster_cli(input_file: INPUT_FILE_OPTION):
     results_dict = descriptive_statistics_raster(raster)
     typer.echo("Progress: 75%")

-    json_str = json.dumps(results_dict)
-    typer.echo("Progress: 100%")
-    typer.echo(f"Results: {json_str}")
-    typer.echo("Descriptive statistics (raster) completed")
+    # json_str = json.dumps(results_dict)
+    typer.echo("Progress: 100%\n")
+    # typer.echo(f"Results: {json_str}")
+    # typer.echo("Results:\n")
+    for key, value in results_dict.items():
+        typer.echo(f"{key}: {value}")
+
+    typer.echo("\nDescriptive statistics (raster) completed")


 # DESCRIPTIVE STATISTICS (VECTOR)
@@ -2318,23 +2322,29 @@ def evaluate_trained_model_cli(
     output_raster: OUTPUT_FILE_OPTION,
     validation_metrics: Annotated[List[str], typer.Option()],
 ):
-    """Evaluate a trained machine learning model by predicting and scoring."""
+    """Predict and evaluate a trained machine learning model by scoring its predictions."""
     from sklearn.base import is_classifier

     from eis_toolkit.evaluation.scoring import score_predictions
     from eis_toolkit.prediction.machine_learning_general import load_model, prepare_data_for_ml, reshape_predictions
     from eis_toolkit.prediction.machine_learning_predict import predict_classifier, predict_regressor

     X, y, reference_profile, nodata_mask = prepare_data_for_ml(input_rasters, target_labels)
-    print(len(np.unique(y)))
     typer.echo("Progress: 30%")

     model = load_model(model_file)
     if is_classifier(model):
         predictions, probabilities = predict_classifier(X, model, True)
+        probabilities = probabilities[:, 1]
+        probabilities = probabilities.astype(np.float32)
+        probabilities_reshaped = reshape_predictions(
+            probabilities, reference_profile["height"], reference_profile["width"], nodata_mask
+        )
     else:
         predictions = predict_regressor(X, model)
+
     metrics_dict = score_predictions(y, predictions, validation_metrics)
+
     predictions_reshaped = reshape_predictions(
         predictions, reference_profile["height"], reference_profile["width"], nodata_mask
     )
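A note on the new `probabilities[:, 1]` lines: scikit-learn classifiers return per-class probabilities from `predict_proba` as an (n_samples, n_classes) array ordered by `clf.classes_`, so for a binary model column 1 holds the positive-class probability. A minimal sketch, assuming `predict_classifier` wraps `predict_proba` this way (the toy data is illustrative only):

import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])

clf = LogisticRegression().fit(X, y)
probabilities = clf.predict_proba(X)  # shape (4, 2); columns follow clf.classes_
positive = probabilities[:, 1].astype(np.float32)  # keep only P(class == 1), cast to match the float32 raster
print(clf.classes_, positive)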
@@ -2344,10 +2354,21 @@ def evaluate_trained_model_cli(
     json_str = json.dumps(metrics_dict)

     out_profile = reference_profile.copy()
-    out_profile.update({"count": 1, "dtype": predictions_reshaped.dtype})
+    out_profile.update({"count": 1, "dtype": np.float32})

-    with rasterio.open(output_raster, "w", **out_profile) as dst:
-        dst.write(predictions_reshaped, 1)
+    if is_classifier(model):
+        directory = os.path.split(output_raster)[0]
+        name = os.path.splitext(os.path.basename(output_raster))[0]
+        labels_output = os.path.join(directory, name + "_labels" + ".tif")
+        probabilities_output = os.path.join(directory, name + "_probabilities" + ".tif")
+        for output_path, output_data in zip(
+            [labels_output, probabilities_output], [predictions_reshaped, probabilities_reshaped]
+        ):
+            with rasterio.open(output_path, "w", **out_profile) as dst:
+                dst.write(output_data, 1)
+    else:
+        with rasterio.open(output_raster, "w", **out_profile) as dst:
+            dst.write(predictions_reshaped, 1)

     typer.echo("Progress: 100%")
     typer.echo(f"Results: {json_str}")
@@ -2375,6 +2396,11 @@ def predict_with_trained_model_cli(
     model = load_model(model_file)
     if is_classifier(model):
         predictions, probabilities = predict_classifier(X, model, True)
+        probabilities = probabilities[:, 1]
+        probabilities = probabilities.astype(np.float32)
+        probabilities_reshaped = reshape_predictions(
+            probabilities, reference_profile["height"], reference_profile["width"], nodata_mask
+        )
     else:
         predictions = predict_regressor(X, model)

@@ -2385,10 +2411,21 @@ def predict_with_trained_model_cli(
     typer.echo("Progress: 80%")

     out_profile = reference_profile.copy()
-    out_profile.update({"count": 1, "dtype": predictions_reshaped.dtype})
+    out_profile.update({"count": 1, "dtype": np.float32})

-    with rasterio.open(output_raster, "w", **out_profile) as dst:
-        dst.write(predictions_reshaped, 1)
+    if is_classifier(model):
+        directory = os.path.split(output_raster)[0]
+        name = os.path.splitext(os.path.basename(output_raster))[0]
+        labels_output = os.path.join(directory, name + "_labels" + ".tif")
+        probabilities_output = os.path.join(directory, name + "_probabilities" + ".tif")
+        for output_path, output_data in zip(
+            [labels_output, probabilities_output], [predictions_reshaped, probabilities_reshaped]
+        ):
+            with rasterio.open(output_path, "w", **out_profile) as dst:
+                dst.write(output_data, 1)
+    else:
+        with rasterio.open(output_raster, "w", **out_profile) as dst:
+            dst.write(predictions_reshaped, 1)

     typer.echo("Progress: 100%")
     typer.echo("Predicting completed")
@@ -2972,7 +3009,279 @@ def winsorize_transform_cli(


 # ---EVALUATION ---
-# TODO
+
+
+@app.command()
+def summarize_probability_metrics_cli(true_labels: INPUT_FILE_OPTION, probabilities: INPUT_FILE_OPTION):
+    """
+    Generate a comprehensive report of various evaluation metrics for classification probabilities.
+
+    The output includes ROC AUC, log loss, average precision and Brier score loss.
+    """
+    from eis_toolkit.evaluation.classification_probability_evaluation import summarize_probability_metrics
+    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation
+
+    typer.echo("Progress: 10%")
+
+    (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
+    typer.echo("Progress: 25%")
+
+    results_dict = summarize_probability_metrics(y_true=y_true, y_prob=y_prob)
+
+    typer.echo("Progress: 75%")
+
+    # json_str = json.dumps(results_dict)
+    typer.echo("Progress: 100% \n")
+    # typer.echo("Results:\n")
+    for key, value in results_dict.items():
+        typer.echo(f"{key}: {value}")
+    # typer.echo(f"Results: {json_str}")
+    typer.echo("\nGenerating probability metrics summary completed.")
+
+
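For reference, the four metrics named in the docstring can all be computed directly with scikit-learn. A hedged sketch (the toolkit's summarize_probability_metrics implementation and its exact result keys are not shown in this diff; toy arrays only):

import numpy as np
from sklearn.metrics import average_precision_score, brier_score_loss, log_loss, roc_auc_score

y_true = np.array([0, 0, 1, 1])
y_prob = np.array([0.1, 0.4, 0.35, 0.8])  # positive-class probabilities

results_dict = {
    "roc_auc": roc_auc_score(y_true, y_prob),
    "log_loss": log_loss(y_true, y_prob),
    "average_precision": average_precision_score(y_true, y_prob),
    "brier_score_loss": brier_score_loss(y_true, y_prob),
}
for key, value in results_dict.items():
    print(f"{key}: {value}")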
+@app.command()
+def summarize_label_metrics_binary_cli(true_labels: INPUT_FILE_OPTION, predictions: INPUT_FILE_OPTION):
+    """
+    Generate a comprehensive report of various evaluation metrics for binary classification results.
+
+    The output includes accuracy, precision, recall, F1 scores and confusion matrix elements
+    (true negatives, false positives, false negatives, true positives).
+    """
+    from eis_toolkit.evaluation.classification_label_evaluation import summarize_label_metrics_binary
+    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation
+
+    typer.echo("Progress: 10%")
+
+    (y_pred, y_true), _, _ = read_data_for_evaluation([predictions, true_labels])
+    typer.echo("Progress: 25%")
+
+    results_dict = summarize_label_metrics_binary(y_true=y_true, y_pred=y_pred)
+    typer.echo("Progress: 75%")
+
+    # json_str = json.dumps(results_dict)
+    typer.echo("Progress: 100% \n")
+    for key, value in results_dict.items():
+        typer.echo(f"{key}: {value}")
+    # typer.echo(f"Results: {json_str}")
+    typer.echo("\nGenerating prediction label metrics summary completed.")
+
+
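The label metrics and confusion matrix elements named in the docstring map onto scikit-learn like this; a hedged sketch (the toolkit's summarize_label_metrics_binary and its exact key names are not shown in this diff):

import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, precision_score, recall_score

y_true = np.array([0, 1, 0, 1, 1, 0])
y_pred = np.array([0, 1, 1, 1, 0, 0])

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()  # flatten the 2x2 matrix
results_dict = {
    "accuracy": accuracy_score(y_true, y_pred),
    "precision": precision_score(y_true, y_pred),
    "recall": recall_score(y_true, y_pred),
    "f1": f1_score(y_true, y_pred),
    "true_negatives": tn, "false_positives": fp,
    "false_negatives": fn, "true_positives": tp,
}
for key, value in results_dict.items():
    print(f"{key}: {value}")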
+@app.command()
+def plot_roc_curve_cli(
+    true_labels: INPUT_FILE_OPTION,
+    probabilities: INPUT_FILE_OPTION,
+    output_file: OUTPUT_FILE_OPTION,
+    show_plot: bool = False,
+    save_dpi: Optional[int] = None,
+):
+    """
+    Plot ROC (receiver operating characteristic) curve.
+
+    ROC curve is a binary classification multi-threshold metric. The ideal performance corner of the plot
+    is top-left. AUC of the ROC curve summarizes model performance across different classification thresholds.
+    """
+    import matplotlib.pyplot as plt
+
+    from eis_toolkit.evaluation.classification_probability_evaluation import plot_roc_curve
+    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation
+
+    typer.echo("Progress: 10%")
+
+    (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
+    typer.echo("Progress: 25%")
+
+    _ = plot_roc_curve(y_true=y_true, y_prob=y_prob)
+    typer.echo("Progress: 75%")
+    if show_plot:
+        plt.show()
+
+    if output_file is not None:
+        dpi = "figure" if save_dpi is None else save_dpi
+        plt.savefig(output_file, dpi=dpi)
+        echo_str_end = f", output figure saved to {output_file}."
+    typer.echo("Progress: 100% \n")
+
+    typer.echo("ROC curve plot completed" + echo_str_end)
+
+
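The underlying computation can be reproduced with scikit-learn's display helper; a hedged sketch (the toolkit's plot_roc_curve wrapper is not shown in this diff, toy arrays only):

import matplotlib.pyplot as plt
from sklearn.metrics import RocCurveDisplay

y_true = [0, 0, 1, 1]
y_prob = [0.1, 0.4, 0.35, 0.8]

RocCurveDisplay.from_predictions(y_true, y_prob)  # TPR vs FPR; ideal corner is top-left
plt.savefig("roc_curve.png", dpi="figure")        # dpi="figure" mirrors the CLI's save_dpi default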
+@app.command()
+def plot_det_curve_cli(
+    true_labels: INPUT_FILE_OPTION,
+    probabilities: INPUT_FILE_OPTION,
+    output_file: OUTPUT_FILE_OPTION,
+    show_plot: bool = False,
+    save_dpi: Optional[int] = None,
+):
+    """
+    Plot DET (detection error tradeoff) curve.
+
+    DET curve is a binary classification multi-threshold metric. DET curves are a variation of ROC curves where
+    False Negative Rate is plotted on the y-axis instead of True Positive Rate. The ideal performance corner of
+    the plot is bottom-left. When comparing the performance of different models, DET curves can be
+    slightly easier to assess visually than ROC curves.
+    """
+    import matplotlib.pyplot as plt
+
+    from eis_toolkit.evaluation.classification_probability_evaluation import plot_det_curve
+    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation
+
+    typer.echo("Progress: 10%")
+
+    (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
+    typer.echo("Progress: 25%")
+
+    _ = plot_det_curve(y_true=y_true, y_prob=y_prob)
+    typer.echo("Progress: 75%")
+    if show_plot:
+        plt.show()
+
+    if output_file is not None:
+        dpi = "figure" if save_dpi is None else save_dpi
+        plt.savefig(output_file, dpi=dpi)
+        echo_str_end = f", output figure saved to {output_file}."
+    typer.echo("Progress: 100% \n")
+
+    typer.echo("DET curve plot completed" + echo_str_end)
+
+
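Same idea with scikit-learn's DET display, using the same inputs as the ROC sketch but with False Negative Rate on the y-axis, so the ideal corner is bottom-left (a hedged sketch; the toolkit's plot_det_curve wrapper is not shown in this diff):

import matplotlib.pyplot as plt
from sklearn.metrics import DetCurveDisplay

y_true = [0, 0, 1, 1]
y_prob = [0.1, 0.4, 0.35, 0.8]

DetCurveDisplay.from_predictions(y_true, y_prob)  # FNR vs FPR on normal-deviate scales
plt.savefig("det_curve.png", dpi="figure")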
+@app.command()
+def plot_precision_recall_curve_cli(
+    true_labels: INPUT_FILE_OPTION,
+    probabilities: INPUT_FILE_OPTION,
+    output_file: OUTPUT_FILE_OPTION,
+    show_plot: bool = False,
+    save_dpi: Optional[int] = None,
+):
+    """
+    Plot precision-recall curve.
+
+    Precision-recall curve is a binary classification multi-threshold metric. Precision-recall curve shows
+    the tradeoff between precision and recall for different classification thresholds.
+    It can be a useful measure of success when classes are imbalanced.
+    """
+    import matplotlib.pyplot as plt
+
+    from eis_toolkit.evaluation.classification_probability_evaluation import plot_precision_recall_curve
+    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation
+
+    typer.echo("Progress: 10%")
+
+    (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
+    typer.echo("Progress: 25%")
+
+    _ = plot_precision_recall_curve(y_true=y_true, y_prob=y_prob)
+    typer.echo("Progress: 75%")
+    if show_plot:
+        plt.show()
+
+    if output_file is not None:
+        dpi = "figure" if save_dpi is None else save_dpi
+        plt.savefig(output_file, dpi=dpi)
+        echo_str_end = f", output figure saved to {output_file}."
+    typer.echo("Progress: 100% \n")
+
+    typer.echo("Precision-Recall curve plot completed" + echo_str_end)
+
+
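And the precision-recall tradeoff, which is informative when the positive class is rare; a hedged sketch with scikit-learn (the toolkit's plot_precision_recall_curve wrapper is not shown in this diff):

import matplotlib.pyplot as plt
from sklearn.metrics import PrecisionRecallDisplay

y_true = [0, 0, 1, 1]
y_prob = [0.1, 0.4, 0.35, 0.8]

PrecisionRecallDisplay.from_predictions(y_true, y_prob)  # precision vs recall across thresholds
plt.savefig("precision_recall_curve.png", dpi="figure")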
+@app.command()
+def plot_calibration_curve_cli(
+    true_labels: INPUT_FILE_OPTION,
+    probabilities: INPUT_FILE_OPTION,
+    output_file: OUTPUT_FILE_OPTION,
+    n_bins: int = 5,
+    show_plot: bool = False,
+    save_dpi: Optional[int] = None,
+):
+    """
+    Plot calibration curve (aka reliability diagram).
+
+    Calibration curve has the frequency of the positive labels on the y-axis and the predicted probability on
+    the x-axis. Generally, the closer the calibration curve is to the line x = y, the better the model is calibrated.
+    """
+    import matplotlib.pyplot as plt
+
+    from eis_toolkit.evaluation.classification_probability_evaluation import plot_calibration_curve
+    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation
+
+    typer.echo("Progress: 10%")
+
+    (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
+    typer.echo("Progress: 25%")
+
+    _ = plot_calibration_curve(y_true=y_true, y_prob=y_prob, n_bins=n_bins)
+    typer.echo("Progress: 75%")
+    if show_plot:
+        plt.show()
+
+    if output_file is not None:
+        dpi = "figure" if save_dpi is None else save_dpi
+        plt.savefig(output_file, dpi=dpi)
+        echo_str_end = f", output figure saved to {output_file}."
+    typer.echo("Progress: 100% \n")
+
+    typer.echo("Calibration curve plot completed" + echo_str_end)
+
+
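A reliability diagram can be built from sklearn.calibration.calibration_curve: mean predicted probability per bin on the x-axis against observed positive frequency on the y-axis, with the diagonal as the perfect-calibration reference. A hedged sketch, with n_bins=5 mirroring the CLI default (the toolkit's plot_calibration_curve wrapper is not shown in this diff):

import matplotlib.pyplot as plt
from sklearn.calibration import calibration_curve

y_true = [0, 0, 0, 1, 1, 1, 1, 0, 1, 0]
y_prob = [0.1, 0.2, 0.3, 0.6, 0.7, 0.8, 0.9, 0.4, 0.5, 0.3]

frac_pos, mean_pred = calibration_curve(y_true, y_prob, n_bins=5)
plt.plot(mean_pred, frac_pos, marker="o")
plt.plot([0, 1], [0, 1], linestyle="--")  # perfect calibration reference line x = y
plt.xlabel("Mean predicted probability")
plt.ylabel("Fraction of positives")
plt.savefig("calibration_curve.png", dpi="figure")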
+@app.command()
+def plot_confusion_matrix_cli(
+    true_labels: INPUT_FILE_OPTION,
+    predictions: INPUT_FILE_OPTION,
+    output_file: OUTPUT_FILE_OPTION,
+    show_plot: bool = False,
+    save_dpi: Optional[int] = None,
+):
+    """Plot confusion matrix to visualize classification results."""
+    import matplotlib.pyplot as plt
+    from sklearn.metrics import confusion_matrix
+
+    from eis_toolkit.evaluation.plot_confusion_matrix import plot_confusion_matrix
+    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation
+
+    typer.echo("Progress: 10%")
+
+    (y_pred, y_true), _, _ = read_data_for_evaluation([predictions, true_labels])
+    typer.echo("Progress: 25%")
+
+    matrix = confusion_matrix(y_true, y_pred)
+    _ = plot_confusion_matrix(confusion_matrix=matrix)
+    typer.echo("Progress: 75%")
+    if show_plot:
+        plt.show()
+
+    if output_file is not None:
+        dpi = "figure" if save_dpi is None else save_dpi
+        plt.savefig(output_file, dpi=dpi)
+        echo_str_end = f", output figure saved to {output_file}."
+    typer.echo("Progress: 100% \n")
+
+    typer.echo("Confusion matrix plot completed" + echo_str_end)
+
+
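The same matrix-then-plot flow can be reproduced with scikit-learn's own display; a hedged sketch (the toolkit's plot_confusion_matrix in eis_toolkit.evaluation.plot_confusion_matrix is not shown in this diff):

import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

y_true = [0, 1, 0, 1, 1, 0]
y_pred = [0, 1, 1, 1, 0, 0]

matrix = confusion_matrix(y_true, y_pred)                # 2x2 counts: [[tn, fp], [fn, tp]]
ConfusionMatrixDisplay(confusion_matrix=matrix).plot()
plt.savefig("confusion_matrix.png", dpi="figure")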
+@app.command()
+def score_predictions_cli(
+    true_labels: INPUT_FILE_OPTION,
+    predictions: INPUT_FILE_OPTION,
+    metrics: Annotated[List[str], typer.Option()],
+):
+    """Score predictions."""
+    from eis_toolkit.evaluation.scoring import score_predictions
+    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation
+
+    typer.echo("Progress: 10%")
+
+    (y_pred, y_true), _, _ = read_data_for_evaluation([predictions, true_labels])
+    typer.echo("Progress: 25%")
+
+    outputs = score_predictions(y_true, y_pred, metrics)
+    typer.echo("Progress: 100% \n")
+
+    if isinstance(outputs, dict):
+        for key, value in outputs.items():
+            typer.echo(f"{key}: {value}")
+    else:
+        typer.echo(outputs)
+
+    typer.echo("\nScoring predictions completed.")


 # --- UTILITIES ---
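A hypothetical smoke test for one of the new commands via Typer's test runner. This assumes the Typer app in eis_toolkit/cli.py is named app (as the @app.command() decorators suggest), that Typer's default underscore-to-hyphen conversion applies to the command name, and that INPUT_FILE_OPTION maps the parameters to --true-labels and --predictions options; the file names are made up:

from typer.testing import CliRunner
from eis_toolkit.cli import app

runner = CliRunner()
result = runner.invoke(
    app,
    [
        "score-predictions-cli",
        "--true-labels", "labels.tif",
        "--predictions", "predictions.tif",
        "--metrics", "accuracy",   # List[str] options are passed by repeating the flag
        "--metrics", "f1",
    ],
)
print(result.output)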
