
Commit efedd52 (parent 5c8d32e)

feat(evaluation): Add decimals parameter to score_predictions. Make tools calling score_predictions default to 3 decimals
3 files changed: +15 −8 lines

eis_toolkit/cli.py (+4 −3)

@@ -2375,7 +2375,7 @@ def classifier_test_cli(
         predictions, reference_profile["height"], reference_profile["width"], nodata_mask
     )

-    metrics_dict = score_predictions(y, predictions, get_enum_values(test_metrics))
+    metrics_dict = score_predictions(y, predictions, get_enum_values(test_metrics), decimals=3)
     typer.echo("Progress: 80%")

     out_profile = reference_profile.copy()
@@ -2421,7 +2421,7 @@ def regressor_test_cli(
         predictions, reference_profile["height"], reference_profile["width"], nodata_mask
     )

-    metrics_dict = score_predictions(y, predictions, get_enum_values(test_metrics))
+    metrics_dict = score_predictions(y, predictions, get_enum_values(test_metrics), decimals=3)
     typer.echo("Progress: 80%")

     out_profile = reference_profile.copy()
@@ -3340,6 +3340,7 @@ def score_predictions_cli(
     true_labels: INPUT_FILE_OPTION,
     predictions: INPUT_FILE_OPTION,
     metrics: Annotated[List[str], typer.Option()],
+    decimals: Optional[int] = None,
 ):
     """Score predictions."""
     from eis_toolkit.evaluation.scoring import score_predictions
@@ -3350,7 +3351,7 @@ def score_predictions_cli(
     (y_pred, y_true), _, _ = read_data_for_evaluation([predictions, true_labels])
     typer.echo("Progress: 25%")

-    outputs = score_predictions(y_true, y_pred, metrics)
+    outputs = score_predictions(y_true, y_pred, metrics, decimals)
     typer.echo("Progress: 100% \n")

     typer.echo(f"Results: {str(outputs)}")

eis_toolkit/evaluation/scoring.py (+9 −3)

@@ -1,4 +1,5 @@
 from numbers import Number
+from typing import Optional

 import numpy as np
 import pandas as pd
@@ -19,7 +20,10 @@

 @beartype
 def score_predictions(
-    y_true: Union[np.ndarray, pd.Series], y_pred: Union[np.ndarray, pd.Series], metrics: Union[str, Sequence[str]]
+    y_true: Union[np.ndarray, pd.Series],
+    y_pred: Union[np.ndarray, pd.Series],
+    metrics: Union[str, Sequence[str]],
+    decimals: Optional[int] = None,
 ) -> Union[Number, Dict[str, Number]]:
     """
     Score model predictions with given metrics.
@@ -34,18 +38,20 @@ def score_predictions(
         y_pred: Predicted labels.
         metrics: The metrics to use for scoring the model. Select only metrics applicable
             for the model type.
+        decimals: Number of decimals used in rounding the scores. If None, scores are not rounded.
+            Defaults to None.

     Returns:
         Metric scores as a dictionary if multiple metrics, otherwise just the metric value.
     """
     if isinstance(metrics, str):
         score = _score_predictions(y_true, y_pred, metrics)
-        return score
+        return round(score, decimals) if decimals is not None else score
     else:
         out_metrics = {}
         for metric in metrics:
             score = _score_predictions(y_true, y_pred, metric)
-            out_metrics[metric] = score
+            out_metrics[metric] = round(score, decimals) if decimals is not None else score
         return out_metrics


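
The change keeps the existing return contract (a bare number for a single metric string, a dict for a sequence of metrics) and only adds optional rounding. A small usage sketch follows; the data values are invented and the metric names are assumed to be among the strings _score_predictions accepts.

import numpy as np

from eis_toolkit.evaluation.scoring import score_predictions

y_true = np.array([0, 1, 1, 0, 1, 1])
y_pred = np.array([0, 1, 0, 0, 1, 1])

# Single metric string -> a bare number, rounded because decimals is given.
acc = score_predictions(y_true, y_pred, "accuracy", decimals=3)

# Sequence of metrics -> dict keyed by metric name; decimals=None (the default)
# leaves the scores unrounded, matching the behaviour before this commit.
scores = score_predictions(y_true, y_pred, ["accuracy", "recall"])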

eis_toolkit/prediction/machine_learning_general.py (+2 −2)

@@ -350,7 +350,7 @@ def _train_and_validate_sklearn_model(

         out_metrics = {}
         for metric in metrics:
-            score = score_predictions(y_valid, y_pred, metric)
+            score = score_predictions(y_valid, y_pred, metric, decimals=3)
             out_metrics[metric] = score

     # Validation approach 3: Cross-validation
@@ -369,7 +369,7 @@ def _train_and_validate_sklearn_model(
         y_pred = model.predict(X[valid_index])

         for metric in metrics:
-            score = score_predictions(y[valid_index], y_pred, metric)
+            score = score_predictions(y[valid_index], y_pred, metric, decimals=3)
             all_scores = out_metrics[metric][f"{metric}_all"]
             all_scores.append(score)
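
Here the rounding is hard-coded to 3 decimals inside the internal validation loops, so the per-fold scores appended to all_scores are already rounded before any later aggregation. A toy illustration with invented fold scores:

import numpy as np

fold_scores_raw = [0.83214, 0.79518, 0.81077]  # hypothetical _score_predictions outputs
fold_scores_rounded = [round(s, 3) for s in fold_scores_raw]  # what is now collected

# Aggregates computed later (e.g. a mean across folds) therefore operate on the
# rounded values; each fold's score changes by at most 0.0005.
print(np.mean(fold_scores_raw), np.mean(fold_scores_rounded))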
