Skip to content

Commit c996979

Browse files
committed
Various CLI additions and updates:
- Added CLI functions for the evaluation functions
- Updated the evaluate-trained-model CLI function
- Disabled JSON dumping in some CLI functions due to reported errors
1 parent 712afb7 commit c996979

File tree

1 file changed

+169
-13
lines changed

1 file changed

+169
-13
lines changed

eis_toolkit/cli.py

+169-13
Original file line numberDiff line numberDiff line change
@@ -722,10 +722,14 @@ def descriptive_statistics_raster_cli(input_file: INPUT_FILE_OPTION):
722722
results_dict = descriptive_statistics_raster(raster)
723723
typer.echo("Progress: 75%")
724724

725-
json_str = json.dumps(results_dict)
726-
typer.echo("Progress: 100%")
727-
typer.echo(f"Results: {json_str}")
728-
typer.echo("Descriptive statistics (raster) completed")
725+
# json_str = json.dumps(results_dict)
726+
typer.echo("Progress: 100%\n")
727+
# typer.echo(f"Results: {json_str}")
728+
# typer.echo("Results:\n")
729+
for key, value in results_dict.items():
730+
typer.echo(f"{key}: {value}")
731+
732+
typer.echo("\nDescriptive statistics (raster) completed")
729733

730734

731735
# DESCRIPTIVE STATISTICS (VECTOR)
@@ -2318,23 +2322,29 @@ def evaluate_trained_model_cli(
23182322
output_raster: OUTPUT_FILE_OPTION,
23192323
validation_metrics: Annotated[List[str], typer.Option()],
23202324
):
2321-
"""Evaluate a trained machine learning model by predicting and scoring."""
2325+
"""Predict and evaluate a trained machine learning model by predicting and scoring."""
23222326
from sklearn.base import is_classifier
23232327

23242328
from eis_toolkit.evaluation.scoring import score_predictions
23252329
from eis_toolkit.prediction.machine_learning_general import load_model, prepare_data_for_ml, reshape_predictions
23262330
from eis_toolkit.prediction.machine_learning_predict import predict_classifier, predict_regressor
23272331

23282332
X, y, reference_profile, nodata_mask = prepare_data_for_ml(input_rasters, target_labels)
2329-
print(len(np.unique(y)))
23302333
typer.echo("Progress: 30%")
23312334

23322335
model = load_model(model_file)
23332336
if is_classifier(model):
23342337
predictions, probabilities = predict_classifier(X, model, True)
2338+
probabilities = probabilities[:, 1]
2339+
probabilities = probabilities.astype(np.float32)
2340+
probabilities_reshaped = reshape_predictions(
2341+
probabilities, reference_profile["height"], reference_profile["width"], nodata_mask
2342+
)
23352343
else:
23362344
predictions = predict_regressor(X, model)
2345+
23372346
metrics_dict = score_predictions(y, predictions, validation_metrics)
2347+
23382348
predictions_reshaped = reshape_predictions(
23392349
predictions, reference_profile["height"], reference_profile["width"], nodata_mask
23402350
)
@@ -2344,10 +2354,21 @@ def evaluate_trained_model_cli(
23442354
json_str = json.dumps(metrics_dict)
23452355

23462356
out_profile = reference_profile.copy()
2347-
out_profile.update({"count": 1, "dtype": predictions_reshaped.dtype})
2357+
out_profile.update({"count": 1, "dtype": np.float32})
23482358

2349-
with rasterio.open(output_raster, "w", **out_profile) as dst:
2350-
dst.write(predictions_reshaped, 1)
2359+
if is_classifier(model):
2360+
directory = os.path.split(output_raster)[0]
2361+
name = os.path.splitext(os.path.basename(output_raster))[0]
2362+
labels_output = os.path.join(directory, name + "_labels" + ".tif")
2363+
probabilities_output = os.path.join(directory, name + "_probabilities" + ".tif")
2364+
for output_path, output_data in zip(
2365+
[labels_output, probabilities_output], [predictions_reshaped, probabilities_reshaped]
2366+
):
2367+
with rasterio.open(output_path, "w", **out_profile) as dst:
2368+
dst.write(output_data, 1)
2369+
else:
2370+
with rasterio.open(output_raster, "w", **out_profile) as dst:
2371+
dst.write(predictions_reshaped, 1)
23512372

23522373
typer.echo("Progress: 100%")
23532374
typer.echo(f"Results: {json_str}")
@@ -2375,6 +2396,11 @@ def predict_with_trained_model_cli(
23752396
model = load_model(model_file)
23762397
if is_classifier(model):
23772398
predictions, probabilities = predict_classifier(X, model, True)
2399+
probabilities = probabilities[:, 1]
2400+
probabilities = probabilities.astype(np.float32)
2401+
probabilities_reshaped = reshape_predictions(
2402+
probabilities, reference_profile["height"], reference_profile["width"], nodata_mask
2403+
)
23782404
else:
23792405
predictions = predict_regressor(X, model)
23802406

@@ -2385,10 +2411,21 @@ def predict_with_trained_model_cli(
23852411
typer.echo("Progress: 80%")
23862412

23872413
out_profile = reference_profile.copy()
2388-
out_profile.update({"count": 1, "dtype": predictions_reshaped.dtype})
2414+
out_profile.update({"count": 1, "dtype": np.float32})
23892415

2390-
with rasterio.open(output_raster, "w", **out_profile) as dst:
2391-
dst.write(predictions_reshaped, 1)
2416+
if is_classifier(model):
2417+
directory = os.path.split(output_raster)[0]
2418+
name = os.path.splitext(os.path.basename(output_raster))[0]
2419+
labels_output = os.path.join(directory, name + "_labels" + ".tif")
2420+
probabilities_output = os.path.join(directory, name + "_probabilities" + ".tif")
2421+
for output_path, output_data in zip(
2422+
[labels_output, probabilities_output], [predictions_reshaped, probabilities_reshaped]
2423+
):
2424+
with rasterio.open(output_path, "w", **out_profile) as dst:
2425+
dst.write(output_data, 1)
2426+
else:
2427+
with rasterio.open(output_raster, "w", **out_profile) as dst:
2428+
dst.write(predictions_reshaped, 1)
23922429

23932430
typer.echo("Progress: 100%")
23942431
typer.echo("Predicting completed")
@@ -2972,7 +3009,126 @@ def winsorize_transform_cli(
29723009

29733010

29743011
# ---EVALUATION ---
2975-
# TODO
3012+
3013+
3014+
@app.command()
def summarize_probability_metrics_cli(true_labels: INPUT_FILE_OPTION, probabilities: INPUT_FILE_OPTION):
    """
    Generate a comprehensive report of various evaluation metrics for classification probabilities.

    The output includes ROC AUC, log loss, average precision and Brier score loss.
    """
    from eis_toolkit.evaluation.classification_probability_evaluation import summarize_probability_metrics
    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation

    typer.echo("Progress: 10%")

    # Inputs are passed as [probabilities, true_labels], so unpack in the same order.
    (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
    typer.echo("Progress: 25%")

    results_dict = summarize_probability_metrics(y_true=y_true, y_prob=y_prob)
    typer.echo("Progress: 75%")

    typer.echo("Progress: 100%\n")
    # Echo metrics one per line; JSON dumping was disabled due to reported serialization errors.
    for key, value in results_dict.items():
        typer.echo(f"{key}: {value}")
    typer.echo("\nGenerating probability metrics summary completed.")
3040+
3041+
3042+
@app.command()
def summarize_label_metrics_binary_cli(true_labels: INPUT_FILE_OPTION, predictions: INPUT_FILE_OPTION):
    """
    Generate a comprehensive report of various evaluation metrics for binary classification results.

    The output includes accuracy, precision, recall, F1 scores and confusion matrix elements
    (true negatives, false positives, false negatives, true positives).
    """
    from eis_toolkit.evaluation.classification_label_evaluation import summarize_label_metrics_binary
    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation

    typer.echo("Progress: 10%")

    # Inputs are passed as [predictions, true_labels], so unpack in the same order.
    (y_pred, y_true), _, _ = read_data_for_evaluation([predictions, true_labels])
    typer.echo("Progress: 25%")

    results_dict = summarize_label_metrics_binary(y_true=y_true, y_pred=y_pred)
    typer.echo("Progress: 75%")

    typer.echo("Progress: 100%\n")
    # Echo metrics one per line; JSON dumping was disabled due to reported serialization errors.
    for key, value in results_dict.items():
        typer.echo(f"{key}: {value}")
    typer.echo("\nGenerating prediction label metrics summary completed.")
3067+
3068+
3069+
@app.command()
def plot_roc_curve_cli(
    true_labels: INPUT_FILE_OPTION,
    probabilities: INPUT_FILE_OPTION,
    output_file: OUTPUT_FILE_OPTION,
    show_plot: bool = False,
    save_dpi: Optional[int] = None,
):
    """
    Plot ROC (receiver operating characteristic) curve.

    ROC curve is a binary classification multi-threshold metric. The ideal performance corner of the plot
    is top-left. AUC of the ROC curve summarizes model performance across different classification thresholds.
    """
    import matplotlib.pyplot as plt

    from eis_toolkit.evaluation.classification_probability_evaluation import plot_roc_curve
    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation

    typer.echo("Progress: 10%")

    # Inputs are passed as [probabilities, true_labels], so unpack in the same order.
    (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
    typer.echo("Progress: 25%")

    _ = plot_roc_curve(y_true=y_true, y_prob=y_prob)
    typer.echo("Progress: 75%")
    if show_plot:
        plt.show()

    # Default suffix keeps the completion message well-formed when no figure is saved;
    # previously echo_str_end was unbound (NameError) whenever output_file was None.
    echo_str_end = "."
    if output_file is not None:
        dpi = "figure" if save_dpi is None else save_dpi
        plt.savefig(output_file, dpi=dpi)
        echo_str_end = f", output figure saved to {output_file}."
    typer.echo("Progress: 100%\n")

    typer.echo("ROC curve plot completed" + echo_str_end)
3105+
3106+
3107+
@app.command()
def score_predictions_cli(
    true_labels: INPUT_FILE_OPTION,
    predictions: INPUT_FILE_OPTION,
    metrics: Annotated[List[str], typer.Option()],
):
    """Score predictions."""
    from eis_toolkit.evaluation.scoring import score_predictions
    from eis_toolkit.prediction.machine_learning_general import read_data_for_evaluation

    typer.echo("Progress: 10%")

    # Inputs are passed as [predictions, true_labels], so unpack in the same order.
    (y_pred, y_true), _, _ = read_data_for_evaluation([predictions, true_labels])
    typer.echo("Progress: 25%")

    scores = score_predictions(y_true, y_pred, metrics)
    typer.echo("Progress: 100% \n")

    # Multiple metrics come back as a dict; a single metric comes back as a bare value.
    if isinstance(scores, dict):
        for metric_name, metric_value in scores.items():
            typer.echo(f"{metric_name}: {metric_value}")
    else:
        typer.echo(scores)

    typer.echo("\nScoring predictions completed.")
29763132

29773133

29783134
# --- UTILITIES ---

0 commit comments

Comments
 (0)