@@ -2318,20 +2318,23 @@ def evaluate_trained_model_cli(
2318
2318
output_raster : OUTPUT_FILE_OPTION ,
2319
2319
validation_metrics : Annotated [List [str ], typer .Option ()],
2320
2320
):
2321
- """Train and optionally validate a Gradient boosting regressor model using Sklearn."""
2322
- from eis_toolkit .prediction .machine_learning_general import (
2323
- evaluate_model ,
2324
- load_model ,
2325
- prepare_data_for_ml ,
2326
- reshape_predictions ,
2327
- )
2321
+ """Evaluate a trained machine learning model by predicting and scoring."""
2322
+ from sklearn .base import is_classifier
2328
2323
2329
- X , y , reference_profile , nodata_mask = prepare_data_for_ml (input_rasters , target_labels )
2324
+ from eis_toolkit .evaluation .scoring import score_predictions
2325
+ from eis_toolkit .prediction .machine_learning_general import load_model , prepare_data_for_ml , reshape_predictions
2326
+ from eis_toolkit .prediction .machine_learning_predict import predict_classifier , predict_regressor
2330
2327
2328
+ X , y , reference_profile , nodata_mask = prepare_data_for_ml (input_rasters , target_labels )
2329
+ print (len (np .unique (y )))
2331
2330
typer .echo ("Progress: 30%" )
2332
2331
2333
2332
model = load_model (model_file )
2334
- predictions , metrics_dict = evaluate_model (X , y , model , validation_metrics )
2333
+ if is_classifier (model ):
2334
+ predictions , probabilities = predict_classifier (X , model , True )
2335
+ else :
2336
+ predictions = predict_regressor (X , model )
2337
+ metrics_dict = score_predictions (y , predictions , validation_metrics )
2335
2338
predictions_reshaped = reshape_predictions (
2336
2339
predictions , reference_profile ["height" ], reference_profile ["width" ], nodata_mask
2337
2340
)
@@ -2359,20 +2362,22 @@ def predict_with_trained_model_cli(
2359
2362
model_file : INPUT_FILE_OPTION ,
2360
2363
output_raster : OUTPUT_FILE_OPTION ,
2361
2364
):
2362
- """Train and optionally validate a Gradient boosting regressor model using Sklearn."""
2363
- from eis_toolkit .prediction .machine_learning_general import (
2364
- load_model ,
2365
- predict ,
2366
- prepare_data_for_ml ,
2367
- reshape_predictions ,
2368
- )
2365
+ """Predict with a trained machine learning model."""
2366
+ from sklearn .base import is_classifier
2367
+
2368
+ from eis_toolkit .prediction .machine_learning_general import load_model , prepare_data_for_ml , reshape_predictions
2369
+ from eis_toolkit .prediction .machine_learning_predict import predict_classifier , predict_regressor
2369
2370
2370
2371
X , _ , reference_profile , nodata_mask = prepare_data_for_ml (input_rasters )
2371
2372
2372
2373
typer .echo ("Progress: 30%" )
2373
2374
2374
2375
model = load_model (model_file )
2375
- predictions = predict (X , model )
2376
+ if is_classifier (model ):
2377
+ predictions , probabilities = predict_classifier (X , model , True )
2378
+ else :
2379
+ predictions = predict_regressor (X , model )
2380
+
2376
2381
predictions_reshaped = reshape_predictions (
2377
2382
predictions , reference_profile ["height" ], reference_profile ["width" ], nodata_mask
2378
2383
)
0 commit comments