Commit f7a6269

feat(evaluation): Add decimals parameter to summarize probability and label metrics. Make corresponding CLI tools default to 3 decimals
1 parent 9c311b6 commit f7a6269

File tree

3 files changed: +26 -12 lines changed


eis_toolkit/cli.py

+2 -2
@@ -3109,7 +3109,7 @@ def summarize_probability_metrics_cli(true_labels: INPUT_FILE_OPTION, probabilit
     (y_prob, y_true), _, _ = read_data_for_evaluation([probabilities, true_labels])
     typer.echo("Progress: 25%")

-    results_dict = summarize_probability_metrics(y_true=y_true, y_prob=y_prob)
+    results_dict = summarize_probability_metrics(y_true=y_true, y_prob=y_prob, decimals=3)

     typer.echo("Progress: 75%")

@@ -3135,7 +3135,7 @@ def summarize_label_metrics_binary_cli(true_labels: INPUT_FILE_OPTION, predictio
     (y_pred, y_true), _, _ = read_data_for_evaluation([predictions, true_labels])
     typer.echo("Progress: 25%")

-    results_dict = summarize_label_metrics_binary(y_true=y_true, y_pred=y_pred)
+    results_dict = summarize_label_metrics_binary(y_true=y_true, y_pred=y_pred, decimals=3)
     typer.echo("Progress: 75%")

     typer.echo("Progress: 100% \n")

eis_toolkit/evaluation/classification_label_evaluation.py

+13 -6
@@ -1,11 +1,15 @@
 from numbers import Number
-from typing import Dict

 import numpy as np
+from beartype.typing import Dict, Optional
 from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support


-def summarize_label_metrics_binary(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, Number]:
+def summarize_label_metrics_binary(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    decimals: Optional[int] = None,
+) -> Dict[str, Number]:
     """
     Generate a comprehensive report of various evaluation metrics for binary classification results.

@@ -15,18 +19,21 @@ def summarize_label_metrics_binary(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, Number]:
     Args:
         y_true: True labels.
         y_pred: Predicted labels. The array should come from a binary classifier.
+        decimals: Number of decimals used in rounding the scores. If None, scores are not rounded.
+            Defaults to None.

     Returns:
         A dictionary containing the evaluated metrics.
     """
     metrics = {}

-    metrics["Accuracy"] = accuracy_score(y_true, y_pred)
+    accuracy = accuracy_score(y_true, y_pred)
+    metrics["Accuracy"] = round(accuracy, decimals) if decimals is not None else accuracy

     precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
-    metrics["Precision"] = precision
-    metrics["Recall"] = recall
-    metrics["F1_score"] = f1
+    metrics["Precision"] = round(precision, decimals) if decimals is not None else precision
+    metrics["Recall"] = round(recall, decimals) if decimals is not None else recall
+    metrics["F1_score"] = round(f1, decimals) if decimals is not None else f1

     tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
     metrics["True_negatives"] = tn

eis_toolkit/evaluation/classification_probability_evaluation.py

+11 -4
@@ -1,9 +1,7 @@
-from typing import Dict
-
 import matplotlib.pyplot as plt
 import numpy as np
 import seaborn as sns
-from beartype.typing import Optional
+from beartype.typing import Dict, Optional
 from sklearn.calibration import CalibrationDisplay
 from sklearn.metrics import (
     DetCurveDisplay,
@@ -16,7 +14,11 @@
 )


-def summarize_probability_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> Dict[str, float]:
+def summarize_probability_metrics(
+    y_true: np.ndarray,
+    y_prob: np.ndarray,
+    decimals: Optional[int] = None,
+) -> Dict[str, float]:
     """
     Generate a comprehensive report of various evaluation metrics for classification probabilities.

@@ -26,6 +28,8 @@ def summarize_probability_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> Dict[str, float]:
         y_true: True labels.
         y_prob: Predicted probabilities for the positive class. The array should come from
             a binary classifier.
+        decimals: Number of decimals used in rounding the scores. If None, scores are not rounded.
+            Defaults to None.

     Returns:
         A dictionary containing the evaluated metrics.
@@ -37,6 +41,9 @@ def summarize_probability_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> Dict[str, float]:
     metrics["average_precision"] = average_precision_score(y_true, y_prob)
     metrics["brier_score_loss"] = brier_score_loss(y_true, y_prob)

+    for key, value in metrics.items():
+        metrics[key] = round(value, decimals) if decimals is not None else value
+
     return metrics
