Skip to content

Commit 493e6c6

Browse files
committed
Fix test metrics prints for classifier and regressor test CLI functions
1 parent f1ed697 commit 493e6c6

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

eis_toolkit/cli.py

+8-13
Original file line numberDiff line numberDiff line change
@@ -2341,8 +2341,7 @@ def classifier_test_cli(
23412341
predictions, reference_profile["height"], reference_profile["width"], nodata_mask
23422342
)
23432343

2344-
metrics_dict = score_predictions(y, predictions, test_metrics)
2345-
# json_str = json.dumps(metrics_dict)
2344+
metrics_dict = score_predictions(y, predictions, get_enum_values(test_metrics))
23462345
typer.echo("Progress: 80%")
23472346

23482347
out_profile = reference_profile.copy()
@@ -2353,14 +2352,12 @@ def classifier_test_cli(
23532352
with rasterio.open(output_raster_classified, "w", **out_profile) as dst:
23542353
dst.write(predictions_reshaped, 1)
23552354

2356-
typer.echo("Progress: 100%\n")
2357-
# typer.echo(f"Results:")
2358-
2355+
typer.echo("\n")
23592356
for key, value in metrics_dict.items():
23602357
typer.echo(f"{key}: {value}")
2361-
23622358
typer.echo("\n")
23632359

2360+
typer.echo("Progress: 100%")
23642361
typer.echo(
23652362
(
23662363
"Testing classifier model completed, writing rasters to "
@@ -2388,28 +2385,26 @@ def regressor_test_cli(
23882385

23892386
model = load_model(model_file)
23902387
predictions = predict_regressor(X, model)
2391-
metrics_dict = score_predictions(y, predictions, test_metrics)
23922388
predictions_reshaped = reshape_predictions(
23932389
predictions, reference_profile["height"], reference_profile["width"], nodata_mask
23942390
)
2395-
typer.echo("Progress: 80%")
23962391

2397-
# json_str = json.dumps(metrics_dict)
2392+
metrics_dict = score_predictions(y, predictions, get_enum_values(test_metrics))
2393+
typer.echo("Progress: 80%")
23982394

23992395
out_profile = reference_profile.copy()
24002396
out_profile.update({"count": 1, "dtype": np.float32})
24012397

24022398
with rasterio.open(output_raster, "w", **out_profile) as dst:
24032399
dst.write(predictions_reshaped, 1)
24042400

2405-
typer.echo("Progress: 100%\n")
2406-
# typer.echo("Results: ")
2407-
2401+
typer.echo("\n")
24082402
for key, value in metrics_dict.items():
24092403
typer.echo(f"{key}: {value}")
2410-
24112404
typer.echo("\n")
24122405

2406+
typer.echo("Progress: 100%\n")
2407+
24132408
typer.echo(f"Testing regressor model completed, writing raster to {output_raster}.")
24142409

24152410

0 commit comments

Comments
 (0)