Skip to content

Commit 74aefe5

Browse files
committed
before pytest
1 parent 70f02cc commit 74aefe5

File tree

3 files changed

+16
-15
lines changed

3 files changed

+16
-15
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Byte-compiled / optimized / DLL files
2+
.idea/*
23
__pycache__/
34
*.py[cod]
45
*$py.class
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from typing import Sequence
2+
13
import numpy as np
2-
import pandas
34
import pandas as pd
45
import sklearn.neural_network
56
from sklearn.inspection import permutation_importance
@@ -8,41 +9,40 @@
89

910

1011
def evaluate_feature_importance(
11-
clf: sklearn.neural_network or sklearn.linear_model,
12+
classifier: sklearn.base.BaseEstimator,
1213
x_test: np.ndarray,
1314
y_test: np.ndarray,
14-
feature_names: list[str],
15+
feature_names: Sequence[str],
1516
number_of_repetition: int = 50,
1617
random_state: int = 0,
17-
) -> (pandas.DataFrame, dict):
18+
) -> (pd.DataFrame, dict):
1819
"""
1920
Evaluate the feature importance of a sklearn classifier or linear model.
2021
2122
Parameters:
22-
clf (Any sklearn nn model or lm model): Trained classifier.
23-
x_test (np.ndarray): Testing feature data (X data need to be normalized / standardized).
24-
y_test (np.ndarray): Testing target data.
25-
feature_names (list): Names of the feature columns.
26-
number_of_repetition (int): Number of iteration used when calculate feature importance (default 50).
27-
random_state (int): random state for repeatability of results (Default 0).
23+
classifier: Trained classifier.
24+
x_test: Testing feature data (X data need to be normalized / standardized).
25+
y_test: Testing target data.
26+
feature_names: Names of the feature columns.
27+
number_of_repetition: Number of iteration used when calculate feature importance (default 50).
28+
random_state: random state for repeatability of results (Default 0).
2829
Return:
29-
feature_importance (pd.Dataframe): A dataframe composed by features name and Importance value
30-
result (dict[object]): The resulted object with importance mean, importance std, and overall importance
31-
Raise:
30+
A dataframe composed by features name and Importance value
31+
The resulted object with importance mean, importance std, and overall importance
32+
Raises:
3233
InvalidDatasetException: When the dataset is None.
3334
"""
3435

3536
if x_test is None or y_test is None:
3637
raise InvalidDatasetException
3738

3839
result = permutation_importance(
39-
clf, x_test, y_test.ravel(), n_repeats=number_of_repetition, random_state=random_state
40+
classifier, x_test, y_test.ravel(), n_repeats=number_of_repetition, random_state=random_state
4041
)
4142

4243
feature_importance = pd.DataFrame({"Feature": feature_names, "Importance": result.importances_mean})
4344

4445
feature_importance["Importance"] = feature_importance["Importance"] * 100
4546
feature_importance = feature_importance.sort_values(by="Importance", ascending=False)
46-
# feature_importance['Importance'] = feature_importance['Importance'].apply(lambda x: '{:.6f}%'.format(x))
4747

4848
return feature_importance, result

0 commit comments

Comments
 (0)