Skip to content

Commit e47459f

Browse files
committed
fix(CLI): Fix model scoring in CLI functions
1 parent 7193a8c commit e47459f

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

eis_toolkit/cli.py

+22-17
Original file line numberDiff line numberDiff line change
@@ -2318,20 +2318,23 @@ def evaluate_trained_model_cli(
23182318
output_raster: OUTPUT_FILE_OPTION,
23192319
validation_metrics: Annotated[List[str], typer.Option()],
23202320
):
2321-
"""Train and optionally validate a Gradient boosting regressor model using Sklearn."""
2322-
from eis_toolkit.prediction.machine_learning_general import (
2323-
evaluate_model,
2324-
load_model,
2325-
prepare_data_for_ml,
2326-
reshape_predictions,
2327-
)
2321+
"""Evaluate a trained machine learning model by predicting and scoring."""
2322+
from sklearn.base import is_classifier
23282323

2329-
X, y, reference_profile, nodata_mask = prepare_data_for_ml(input_rasters, target_labels)
2324+
from eis_toolkit.evaluation.scoring import score_predictions
2325+
from eis_toolkit.prediction.machine_learning_general import load_model, prepare_data_for_ml, reshape_predictions
2326+
from eis_toolkit.prediction.machine_learning_predict import predict_classifier, predict_regressor
23302327

2328+
X, y, reference_profile, nodata_mask = prepare_data_for_ml(input_rasters, target_labels)
2329+
print(len(np.unique(y)))
23312330
typer.echo("Progress: 30%")
23322331

23332332
model = load_model(model_file)
2334-
predictions, metrics_dict = evaluate_model(X, y, model, validation_metrics)
2333+
if is_classifier(model):
2334+
predictions, probabilities = predict_classifier(X, model, True)
2335+
else:
2336+
predictions = predict_regressor(X, model)
2337+
metrics_dict = score_predictions(y, predictions, validation_metrics)
23352338
predictions_reshaped = reshape_predictions(
23362339
predictions, reference_profile["height"], reference_profile["width"], nodata_mask
23372340
)
@@ -2359,20 +2362,22 @@ def predict_with_trained_model_cli(
23592362
model_file: INPUT_FILE_OPTION,
23602363
output_raster: OUTPUT_FILE_OPTION,
23612364
):
2362-
"""Train and optionally validate a Gradient boosting regressor model using Sklearn."""
2363-
from eis_toolkit.prediction.machine_learning_general import (
2364-
load_model,
2365-
predict,
2366-
prepare_data_for_ml,
2367-
reshape_predictions,
2368-
)
2365+
"""Predict with a trained machine learning model."""
2366+
from sklearn.base import is_classifier
2367+
2368+
from eis_toolkit.prediction.machine_learning_general import load_model, prepare_data_for_ml, reshape_predictions
2369+
from eis_toolkit.prediction.machine_learning_predict import predict_classifier, predict_regressor
23692370

23702371
X, _, reference_profile, nodata_mask = prepare_data_for_ml(input_rasters)
23712372

23722373
typer.echo("Progress: 30%")
23732374

23742375
model = load_model(model_file)
2375-
predictions = predict(X, model)
2376+
if is_classifier(model):
2377+
predictions, probabilities = predict_classifier(X, model, True)
2378+
else:
2379+
predictions = predict_regressor(X, model)
2380+
23762381
predictions_reshaped = reshape_predictions(
23772382
predictions, reference_profile["height"], reference_profile["width"], nodata_mask
23782383
)

0 commit comments

Comments
 (0)