 from numbers import Number
-from typing import Dict

 import numpy as np
+from beartype.typing import Dict, Optional

 from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support


-def summarize_label_metrics_binary(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, Number]:
+def summarize_label_metrics_binary(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    decimals: Optional[int] = None,
+) -> Dict[str, Number]:
     """
     Generate a comprehensive report of various evaluation metrics for binary classification results.
@@ -15,18 +19,21 @@ def summarize_label_metrics_binary(y_true: np.ndarray, y_pred: np.ndarray) -> Di
     Args:
         y_true: True labels.
         y_pred: Predicted labels. The array should come from a binary classifier.
+        decimals: Number of decimals used in rounding the scores. If None, scores are not rounded.
+            Defaults to None.

     Returns:
         A dictionary containing the evaluated metrics.
     """
     metrics = {}

-    metrics["Accuracy"] = accuracy_score(y_true, y_pred)
+    accuracy = accuracy_score(y_true, y_pred)
+    metrics["Accuracy"] = round(accuracy, decimals) if decimals is not None else accuracy

     precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
-    metrics["Precision"] = precision
-    metrics["Recall"] = recall
-    metrics["F1_score"] = f1
+    metrics["Precision"] = round(precision, decimals) if decimals is not None else precision
+    metrics["Recall"] = round(recall, decimals) if decimals is not None else recall
+    metrics["F1_score"] = round(f1, decimals) if decimals is not None else f1

     tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
     metrics["True_negatives"] = tn