Skip to content

Commit f1ed697

Browse files
committed
Rename validation_metrics to test_metrics, fix metric prints
1 parent 6da73e4 commit f1ed697

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

eis_toolkit/cli.py

+20-13
Original file line numberDiff line numberDiff line change
@@ -2322,9 +2322,7 @@ def classifier_test_cli(
23222322
output_raster_probability: OUTPUT_FILE_OPTION,
23232323
output_raster_classified: OUTPUT_FILE_OPTION,
23242324
classification_threshold: float = 0.5,
2325-
validation_metrics: Annotated[List[ClassifierMetrics], typer.Option(case_sensitive=False)] = [
2326-
ClassifierMetrics.accuracy
2327-
],
2325+
test_metrics: Annotated[List[ClassifierMetrics], typer.Option(case_sensitive=False)] = [ClassifierMetrics.accuracy],
23282326
):
23292327
"""Test trained machine learning classifier model by predicting and scoring."""
23302328
from eis_toolkit.evaluation.scoring import score_predictions
@@ -2343,8 +2341,8 @@ def classifier_test_cli(
23432341
predictions, reference_profile["height"], reference_profile["width"], nodata_mask
23442342
)
23452343

2346-
metrics_dict = score_predictions(y, predictions, validation_metrics)
2347-
json_str = json.dumps(metrics_dict)
2344+
metrics_dict = score_predictions(y, predictions, test_metrics)
2345+
# json_str = json.dumps(metrics_dict)
23482346
typer.echo("Progress: 80%")
23492347

23502348
out_profile = reference_profile.copy()
@@ -2355,8 +2353,13 @@ def classifier_test_cli(
23552353
with rasterio.open(output_raster_classified, "w", **out_profile) as dst:
23562354
dst.write(predictions_reshaped, 1)
23572355

2358-
typer.echo("Progress: 100%")
2359-
typer.echo(f"Results: {json_str}")
2356+
typer.echo("Progress: 100%\n")
2357+
# typer.echo(f"Results:")
2358+
2359+
for key, value in metrics_dict.items():
2360+
typer.echo(f"{key}: {value}")
2361+
2362+
typer.echo("\n")
23602363

23612364
typer.echo(
23622365
(
@@ -2373,7 +2376,7 @@ def regressor_test_cli(
23732376
target_labels: INPUT_FILE_OPTION,
23742377
model_file: INPUT_FILE_OPTION,
23752378
output_raster: OUTPUT_FILE_OPTION,
2376-
validation_metrics: Annotated[List[RegressorMetrics], typer.Option(case_sensitive=False)] = [RegressorMetrics.mse],
2379+
test_metrics: Annotated[List[RegressorMetrics], typer.Option(case_sensitive=False)] = [RegressorMetrics.mse],
23772380
):
23782381
"""Test trained machine learning regressor model by predicting and scoring."""
23792382
from eis_toolkit.evaluation.scoring import score_predictions
@@ -2385,23 +2388,27 @@ def regressor_test_cli(
23852388

23862389
model = load_model(model_file)
23872390
predictions = predict_regressor(X, model)
2388-
metrics_dict = score_predictions(y, predictions, validation_metrics)
2391+
metrics_dict = score_predictions(y, predictions, test_metrics)
23892392
predictions_reshaped = reshape_predictions(
23902393
predictions, reference_profile["height"], reference_profile["width"], nodata_mask
23912394
)
2392-
23932395
typer.echo("Progress: 80%")
23942396

2395-
json_str = json.dumps(metrics_dict)
2397+
# json_str = json.dumps(metrics_dict)
23962398

23972399
out_profile = reference_profile.copy()
23982400
out_profile.update({"count": 1, "dtype": np.float32})
23992401

24002402
with rasterio.open(output_raster, "w", **out_profile) as dst:
24012403
dst.write(predictions_reshaped, 1)
24022404

2403-
typer.echo("Progress: 100%")
2404-
typer.echo(f"Results: {json_str}")
2405+
typer.echo("Progress: 100%\n")
2406+
# typer.echo("Results: ")
2407+
2408+
for key, value in metrics_dict.items():
2409+
typer.echo(f"{key}: {value}")
2410+
2411+
typer.echo("\n")
24052412

24062413
typer.echo(f"Testing regressor model completed, writing raster to {output_raster}.")
24072414

0 commit comments

Comments
 (0)