Skip to content

Commit 40f7c1f

Browse files
committed
fix(evaluation): Fix invalid scoring for classifier metrics
1 parent 5c8d32e commit 40f7c1f

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

eis_toolkit/evaluation/scoring.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,10 @@ def score_predictions(
5353
def _score_predictions(
5454
y_true: Union[np.ndarray, pd.Series], y_pred: Union[np.ndarray, pd.Series], metric: str
5555
) -> Number:
56+
num_classes = len(np.unique(y_true))
57+
5658
# Multiclass classification
57-
if len(y_true) > 2:
59+
if num_classes > 2:
5860
average_method = "micro"
5961
# Binary classification
6062
else:

tests/evaluation/scoring_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
X, y = make_classification(n_samples=200, n_features=20, n_informative=2, n_redundant=10, random_state=42)
99
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
10-
rf_model, history = random_forest_classifier_train(X_train, y_train)
10+
rf_model, history = random_forest_classifier_train(X_train, y_train, random_state=42)
1111
y_pred = predict_classifier(X_test, rf_model, include_probabilities=False)
1212

1313

0 commit comments

Comments
 (0)