|
| 1 | +from typing import Sequence |
| 2 | + |
1 | 3 | import numpy as np
|
2 |
| -import pandas |
3 | 4 | import pandas as pd
|
4 | 5 | import sklearn.neural_network
|
5 | 6 | from sklearn.inspection import permutation_importance
|
|
8 | 9 |
|
9 | 10 |
|
def evaluate_feature_importance(
    classifier: sklearn.base.BaseEstimator,
    x_test: np.ndarray,
    y_test: np.ndarray,
    feature_names: Sequence[str],
    number_of_repetition: int = 50,
    random_state: int = 0,
) -> tuple[pd.DataFrame, dict]:
    """
    Evaluate the permutation feature importance of a trained sklearn estimator.

    Parameters:
        classifier: Trained classifier.
        x_test: Testing feature data (X data needs to be normalized / standardized).
        y_test: Testing target data; it is flattened to 1-D before scoring.
        feature_names: Names of the feature columns, one per entry of the
            computed mean importances.
        number_of_repetition: Number of permutation repeats used when
            calculating the feature importance (default 50).
        random_state: Random state for repeatability of results (default 0).
    Return:
        A tuple ``(feature_importance, result)`` where ``feature_importance`` is a
        dataframe with a "Feature" column and an "Importance" column (mean
        importance multiplied by 100), sorted by descending importance, and
        ``result`` is the object returned by ``permutation_importance``.
    Raises:
        InvalidDatasetException: When ``x_test`` or ``y_test`` is None.
    """

    if x_test is None or y_test is None:
        raise InvalidDatasetException

    # np.ravel behaves like ndarray.ravel for arrays, but also accepts other
    # array-likes (e.g. lists) that have no .ravel method.
    flat_targets = np.ravel(y_test)

    result = permutation_importance(
        classifier,
        x_test,
        flat_targets,
        n_repeats=number_of_repetition,
        random_state=random_state,
    )

    # Importances are scaled by 100 so they can be read as percentages.
    feature_importance = pd.DataFrame(
        {"Feature": feature_names, "Importance": result.importances_mean * 100}
    )
    feature_importance = feature_importance.sort_values(by="Importance", ascending=False)

    return feature_importance, result
|
0 commit comments