Commit 2390407

Merge pull request #411 from GispoCoding/406-reduce-number-of-decimals-in-numeric-outputs
Reduce number of decimals in outputs
2 parents 813362c + f7a6269 commit 2390407

7 files changed, +49 -28 lines changed

eis_toolkit/cli.py

+6 -5

@@ -2375,7 +2375,7 @@ def classifier_test_cli(
         predictions, reference_profile["height"], reference_profile["width"], nodata_mask
     )

-    metrics_dict = score_predictions(y, predictions, get_enum_values(test_metrics))
+    metrics_dict = score_predictions(y, predictions, get_enum_values(test_metrics), decimals=3)
     typer.echo("Progress: 80%")

     out_profile = reference_profile.copy()
@@ -2421,7 +2421,7 @@ def regressor_test_cli(
         predictions, reference_profile["height"], reference_profile["width"], nodata_mask
     )

-    metrics_dict = score_predictions(y, predictions, get_enum_values(test_metrics))
+    metrics_dict = score_predictions(y, predictions, get_enum_values(test_metrics), decimals=3)
     typer.echo("Progress: 80%")

     out_profile = reference_profile.copy()
@@ -3109,7 +3109,7 @@ def summarize_probability_metrics_cli(true_labels: INPUT_FILE_OPTION, probabilit
     (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
     typer.echo("Progress: 25%")

-    results_dict = summarize_probability_metrics(y_true=y_true, y_prob=y_prob)
+    results_dict = summarize_probability_metrics(y_true=y_true, y_prob=y_prob, decimals=3)

     typer.echo("Progress: 75%")

@@ -3135,7 +3135,7 @@ def summarize_label_metrics_binary_cli(true_labels: INPUT_FILE_OPTION, predictio
     (y_pred, y_true), _, _ = read_data_for_evaluation([predictions, true_labels])
     typer.echo("Progress: 25%")

-    results_dict = summarize_label_metrics_binary(y_true=y_true, y_pred=y_pred)
+    results_dict = summarize_label_metrics_binary(y_true=y_true, y_pred=y_pred, decimals=3)
     typer.echo("Progress: 75%")

     typer.echo("Progress: 100% \n")
@@ -3340,6 +3340,7 @@ def score_predictions_cli(
     true_labels: INPUT_FILE_OPTION,
     predictions: INPUT_FILE_OPTION,
     metrics: Annotated[List[str], typer.Option()],
+    decimals: Optional[int] = None,
 ):
     """Score predictions."""
     from eis_toolkit.evaluation.scoring import score_predictions
@@ -3350,7 +3351,7 @@ def score_predictions_cli(
     (y_pred, y_true), _, _ = read_data_for_evaluation([predictions, true_labels])
     typer.echo("Progress: 25%")

-    outputs = score_predictions(y_true, y_pred, metrics)
+    outputs = score_predictions(y_true, y_pred, metrics, decimals)
     typer.echo("Progress: 100% \n")

     typer.echo(f"Results: {str(outputs)}")

eis_toolkit/evaluation/classification_label_evaluation.py

+13 -6

@@ -1,11 +1,15 @@
 from numbers import Number
-from typing import Dict

 import numpy as np
+from beartype.typing import Dict, Optional
 from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support


-def summarize_label_metrics_binary(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, Number]:
+def summarize_label_metrics_binary(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    decimals: Optional[int] = None,
+) -> Dict[str, Number]:
     """
     Generate a comprehensive report of various evaluation metrics for binary classification results.

@@ -15,18 +19,21 @@ def summarize_label_metrics_binary(y_true: np.ndarray, y_pred: np.ndarray) -> Di
     Args:
         y_true: True labels.
         y_pred: Predicted labels. The array should come from a binary classifier.
+        decimals: Number of decimals used in rounding the scores. If None, scores are not rounded.
+            Defaults to None.

     Returns:
         A dictionary containing the evaluated metrics.
     """
     metrics = {}

-    metrics["Accuracy"] = accuracy_score(y_true, y_pred)
+    accuracy = accuracy_score(y_true, y_pred)
+    metrics["Accuracy"] = round(accuracy, decimals) if decimals is not None else accuracy

     precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
-    metrics["Precision"] = precision
-    metrics["Recall"] = recall
-    metrics["F1_score"] = f1
+    metrics["Precision"] = round(precision, decimals) if decimals is not None else precision
+    metrics["Recall"] = round(recall, decimals) if decimals is not None else recall
+    metrics["F1_score"] = round(f1, decimals) if decimals is not None else f1

     tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
     metrics["True_negatives"] = tn

eis_toolkit/evaluation/classification_probability_evaluation.py

+11 -4

@@ -1,9 +1,7 @@
-from typing import Dict
-
 import matplotlib.pyplot as plt
 import numpy as np
 import seaborn as sns
-from beartype.typing import Optional
+from beartype.typing import Dict, Optional
 from sklearn.calibration import CalibrationDisplay
 from sklearn.metrics import (
     DetCurveDisplay,
@@ -16,7 +14,11 @@
 )


-def summarize_probability_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> Dict[str, float]:
+def summarize_probability_metrics(
+    y_true: np.ndarray,
+    y_prob: np.ndarray,
+    decimals: Optional[int] = None,
+) -> Dict[str, float]:
     """
     Generate a comprehensive report of various evaluation metrics for classification probabilities.

@@ -26,6 +28,8 @@ def summarize_probability_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> Dic
         y_true: True labels.
         y_prob: Predicted probabilities for the positive class. The array should come from
             a binary classifier.
+        decimals: Number of decimals used in rounding the scores. If None, scores are not rounded.
+            Defaults to None.

     Returns:
         A dictionary containing the evaluated metrics.
@@ -37,6 +41,9 @@ def summarize_probability_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> Dic
     metrics["average_precision"] = average_precision_score(y_true, y_prob)
     metrics["brier_score_loss"] = brier_score_loss(y_true, y_prob)

+    for key, value in metrics.items():
+        metrics[key] = round(value, decimals) if decimals is not None else value
+
     return metrics


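Here the rounding is applied in one pass over the finished dictionary, so every probability metric is rounded consistently and nothing changes when `decimals` is left at `None`. An illustrative call with toy data; metric keys beyond the two visible in this hunk are not assumed:

    # Toy example; only the behaviour of the decimals argument is illustrated.
    import numpy as np

    from eis_toolkit.evaluation.classification_probability_evaluation import summarize_probability_metrics

    y_true = np.array([0, 0, 1, 1, 0, 1])
    y_prob = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.7])

    rounded = summarize_probability_metrics(y_true, y_prob, decimals=3)  # every value rounded to 3 decimals
    unrounded = summarize_probability_metrics(y_true, y_prob)            # decimals=None, values left as computed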
eis_toolkit/evaluation/scoring.py

+9 -3

@@ -1,4 +1,5 @@
 from numbers import Number
+from typing import Optional

 import numpy as np
 import pandas as pd
@@ -19,7 +20,10 @@

 @beartype
 def score_predictions(
-    y_true: Union[np.ndarray, pd.Series], y_pred: Union[np.ndarray, pd.Series], metrics: Union[str, Sequence[str]]
+    y_true: Union[np.ndarray, pd.Series],
+    y_pred: Union[np.ndarray, pd.Series],
+    metrics: Union[str, Sequence[str]],
+    decimals: Optional[int] = None,
 ) -> Union[Number, Dict[str, Number]]:
     """
     Score model predictions with given metrics.
@@ -34,18 +38,20 @@ def score_predictions(
         y_pred: Predicted labels.
         metrics: The metrics to use for scoring the model. Select only metrics applicable
             for the model type.
+        decimals: Number of decimals used in rounding the scores. If None, scores are not rounded.
+            Defaults to None.

     Returns:
         Metric scores as a dictionary if multiple metrics, otherwise just the metric value.
     """
     if isinstance(metrics, str):
         score = _score_predictions(y_true, y_pred, metrics)
-        return score
+        return round(score, decimals) if decimals is not None else score
     else:
         out_metrics = {}
         for metric in metrics:
             score = _score_predictions(y_true, y_pred, metric)
-            out_metrics[metric] = score
+            out_metrics[metric] = round(score, decimals) if decimals is not None else score
         return out_metrics

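`score_predictions` keeps its two return shapes: a bare number for a single metric name and a dictionary for a sequence of names, with rounding applied in either case only when `decimals` is given. A short sketch on made-up arrays, using metric names that appear in the regression tests below:

    # Sketch of both return shapes; "mae" and "rmse" are taken from the tests below.
    import numpy as np

    from eis_toolkit.evaluation.scoring import score_predictions

    y_true = np.array([1.0, 2.0, 3.0, 4.0])
    y_pred = np.array([1.1, 1.9, 3.2, 3.7])

    single = score_predictions(y_true, y_pred, "mae", decimals=3)             # one rounded number
    several = score_predictions(y_true, y_pred, ["mae", "rmse"], decimals=3)  # dict of rounded numbers
    raw = score_predictions(y_true, y_pred, "mae")                            # decimals=None, not rounded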
eis_toolkit/prediction/machine_learning_general.py

+2 -2

@@ -350,7 +350,7 @@ def _train_and_validate_sklearn_model(

         out_metrics = {}
         for metric in metrics:
-            score = score_predictions(y_valid, y_pred, metric)
+            score = score_predictions(y_valid, y_pred, metric, decimals=3)
             out_metrics[metric] = score

     # Validation approach 3: Cross-validation
@@ -369,7 +369,7 @@ def _train_and_validate_sklearn_model(
             y_pred = model.predict(X[valid_index])

             for metric in metrics:
-                score = score_predictions(y[valid_index], y_pred, metric)
+                score = score_predictions(y[valid_index], y_pred, metric, decimals=3)
                 all_scores = out_metrics[metric][f"{metric}_all"]
                 all_scores.append(score)

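In the validation paths the precision is fixed at three decimals rather than exposed to the caller, so every per-fold score is already rounded before it is collected. A generic sketch of the arithmetic consequence, not the toolkit's internal code: the mean of rounded fold scores can drift from the unrounded mean by at most 0.0005.

    # Generic illustration with invented fold scores.
    fold_scores = [0.91234, 0.88765, 0.90111]
    rounded = [round(s, 3) for s in fold_scores]  # [0.912, 0.888, 0.901]

    print(sum(rounded) / len(rounded))            # differs from the raw mean by less than 0.0005
    print(sum(fold_scores) / len(fold_scores))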
tests/prediction/gradient_boosting_test.py

+4 -4

@@ -40,10 +40,10 @@ def test_gradient_boosting_regressor():
     np.testing.assert_equal(len(predicted_labels), len(Y_IRIS))

     np.testing.assert_equal(count_false, 150)
-    np.testing.assert_almost_equal(out_metrics["mae"], 0.03101, decimal=4)
-    np.testing.assert_almost_equal(out_metrics["mse"], 0.00434, decimal=4)
-    np.testing.assert_almost_equal(out_metrics["rmse"], 0.06593, decimal=4)
-    np.testing.assert_almost_equal(out_metrics["r2"], 0.99377, decimal=4)
+    np.testing.assert_equal(out_metrics["mae"], 0.031)
+    np.testing.assert_equal(out_metrics["mse"], 0.004)
+    np.testing.assert_equal(out_metrics["rmse"], 0.066)
+    np.testing.assert_equal(out_metrics["r2"], 0.994)


 def test_invalid_learning_rate():
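Because the validation metrics now come back rounded to three decimals, the regression tests can compare with `assert_equal` against three-decimal literals instead of `assert_almost_equal`. The old approximate expectations round to exactly the new values:

    # The previous approximate expectations, rounded to 3 decimals, equal the new literals.
    assert round(0.03101, 3) == 0.031   # mae
    assert round(0.00434, 3) == 0.004   # mse
    assert round(0.06593, 3) == 0.066   # rmse
    assert round(0.99377, 3) == 0.994   # r2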
tests/prediction/random_forest_test.py

+4 -4

@@ -37,10 +37,10 @@ def test_random_forest_regressor():
     np.testing.assert_equal(len(predicted_labels), len(Y_IRIS))

     np.testing.assert_equal(count_false, 35)
-    np.testing.assert_almost_equal(out_metrics["mae"], 0.01366, decimal=4)
-    np.testing.assert_almost_equal(out_metrics["mse"], 0.00138, decimal=4)
-    np.testing.assert_almost_equal(out_metrics["rmse"], 0.03719, decimal=4)
-    np.testing.assert_almost_equal(out_metrics["r2"], 0.99802, decimal=4)
+    np.testing.assert_equal(out_metrics["mae"], 0.014)
+    np.testing.assert_equal(out_metrics["mse"], 0.001)
+    np.testing.assert_equal(out_metrics["rmse"], 0.037)
+    np.testing.assert_equal(out_metrics["r2"], 0.998)


 def test_random_forest_invalid_n_estimators():
