Skip to content

Commit 5ca25b3

Browse files
author
Mika Sorvoja
authored
466 add feature importance CLI and modify tool (#468)
* Change n_repeats default to 10 from 50 * Rename parameters, improve documentation * Add missing checks * Add CLI function
1 parent 6bee396 commit 5ca25b3

File tree

3 files changed

+89
-29
lines changed

3 files changed

+89
-29
lines changed

eis_toolkit/cli.py

+34
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,40 @@ def local_morans_i_cli(
871871
typer.echo(f"Local Moran's I completed, output vector saved to {output_vector}.")
872872

873873

874+
# FEATURE IMPORTANCE
875+
@app.command()
876+
def feature_importance_cli(
877+
model_file: INPUT_FILE_OPTION,
878+
input_rasters: INPUT_FILES_ARGUMENT,
879+
target_labels: INPUT_FILE_OPTION,
880+
n_repeats: int = 10,
881+
random_state: Optional[int] = None,
882+
):
883+
"""Evaluate the feature importance of a sklearn classifier or regressor."""
884+
from eis_toolkit.exploratory_analyses.feature_importance import evaluate_feature_importance
885+
from eis_toolkit.prediction.machine_learning_general import load_model, prepare_data_for_ml
886+
887+
typer.echo("Progress: 10%")
888+
889+
model = load_model(model_file)
890+
typer.echo("Progress: 20%")
891+
892+
X, y, _, _ = prepare_data_for_ml(input_rasters, target_labels)
893+
typer.echo("Progress: 30%")
894+
895+
feature_names = [raster.name for raster in input_rasters]
896+
typer.echo("Progress: 40%")
897+
898+
feature_importance, _ = evaluate_feature_importance(model, X, y, feature_names, n_repeats, random_state)
899+
typer.echo("Progress: 80%")
900+
901+
results = dict(zip(feature_importance["Feature"], feature_importance["Importance"]))
902+
json_str = json.dumps(results)
903+
typer.echo("Progress: 100%")
904+
905+
typer.echo(f"Results: {json_str}")
906+
907+
874908
# --- RASTER PROCESSING ---
875909

876910

eis_toolkit/exploratory_analyses/feature_importance.py

+26-16
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,32 @@
55
from beartype.typing import Optional, Sequence
66
from sklearn.inspection import permutation_importance
77

8-
from eis_toolkit.exceptions import InvalidDatasetException, InvalidParameterValueException
8+
from eis_toolkit.exceptions import (
9+
InvalidDatasetException,
10+
InvalidParameterValueException,
11+
NonMatchingParameterLengthsException,
12+
)
913

1014

1115
@beartype
1216
def evaluate_feature_importance(
1317
model: sklearn.base.BaseEstimator,
14-
x_test: np.ndarray,
15-
y_test: np.ndarray,
18+
X: np.ndarray,
19+
y: np.ndarray,
1620
feature_names: Sequence[str],
17-
n_repeats: int = 50,
21+
n_repeats: int = 10,
1822
random_state: Optional[int] = None,
1923
) -> tuple[pd.DataFrame, dict]:
2024
"""
21-
Evaluate the feature importance of a sklearn classifier or regressor.
25+
Evaluate the feature importance of a Sklearn classifier or regressor.
2226
2327
Args:
2428
model: A trained and fitted Sklearn model.
25-
x_test: Testing feature data (X data need to be normalized / standardized).
26-
y_test: Testing label data.
27-
feature_names: Names of the feature columns.
28-
n_repeats: Number of iteration used when calculate feature importance. Defaults to 50.
29-
random_state: random state for repeatability of results. Optional parameter.
29+
X: Feature data.
30+
y: Target labels.
31+
feature_names: Names of features in X.
32+
n_repeats: Number of iterations used when calculating feature importance. Defaults to 10.
33+
random_state: Seed for random number generation. Defaults to None.
3034
3135
Returns:
3236
A dataframe containing features and their importance.
@@ -37,18 +41,24 @@ def evaluate_feature_importance(
3741
InvalidParameterValueException: Value for 'n_repeats' is not at least one.
3842
"""
3943

40-
if x_test.size == 0:
41-
raise InvalidDatasetException("Array 'x_test' is empty.")
44+
if X.size == 0:
45+
raise InvalidDatasetException("Feature matrix X is empty.")
4246

43-
if y_test.size == 0:
44-
raise InvalidDatasetException("Array 'y_test' is empty.")
47+
if y.size == 0:
48+
raise InvalidDatasetException("Target labels y is empty.")
4549

4650
if n_repeats < 1:
4751
raise InvalidParameterValueException("Value for 'n_repeats' is less than one.")
4852

49-
result = permutation_importance(model, x_test, y_test.ravel(), n_repeats=n_repeats, random_state=random_state)
53+
if len(X) != len(y):
54+
raise NonMatchingParameterLengthsException("Feature matrix X and target labels y must have the same length.")
5055

51-
feature_importance = pd.DataFrame({"Feature": feature_names, "Importance": result.importances_mean * 100})
56+
if len(feature_names) != X.shape[1]:
57+
raise InvalidParameterValueException("Number of feature names must match the number of input features.")
58+
59+
result = permutation_importance(model, X, y.ravel(), n_repeats=n_repeats, random_state=random_state)
60+
61+
feature_importance = pd.DataFrame({"Feature": feature_names, "Importance": result.importances_mean})
5262

5363
feature_importance = feature_importance.sort_values(by="Importance", ascending=False)
5464

tests/exploratory_analyses/feature_importance_test.py

+29-13
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from sklearn.neural_network import MLPClassifier
55
from sklearn.preprocessing import StandardScaler
66

7-
from eis_toolkit.exceptions import InvalidDatasetException, InvalidParameterValueException
7+
from eis_toolkit.exceptions import (
8+
InvalidDatasetException,
9+
InvalidParameterValueException,
10+
NonMatchingParameterLengthsException,
11+
)
812
from eis_toolkit.exploratory_analyses.feature_importance import evaluate_feature_importance
913

1014
feature_names = [
@@ -42,43 +46,55 @@ def test_empty_data():
4246
empty_data = np.array([])
4347
empty_labels = np.array([])
4448
with pytest.raises(InvalidDatasetException):
45-
_, _ = evaluate_feature_importance(
46-
model=classifier, x_test=empty_data, y_test=labels, feature_names=feature_names
47-
)
49+
_, _ = evaluate_feature_importance(model=classifier, X=empty_data, y=labels, feature_names=feature_names)
4850

4951
with pytest.raises(InvalidDatasetException):
50-
_, _ = evaluate_feature_importance(
51-
model=classifier, x_test=data, y_test=empty_labels, feature_names=feature_names
52-
)
52+
_, _ = evaluate_feature_importance(model=classifier, X=data, y=empty_labels, feature_names=feature_names)
5353

5454

5555
def test_invalid_n_repeats():
5656
"""Test that invalid value for 'n_repeats' raises exception."""
5757
with pytest.raises(InvalidParameterValueException):
58-
_, _ = evaluate_feature_importance(
59-
model=classifier, x_test=data, y_test=labels, feature_names=feature_names, n_repeats=0
60-
)
58+
_, _ = evaluate_feature_importance(model=classifier, X=data, y=labels, feature_names=feature_names, n_repeats=0)
6159

6260

6361
def test_model_output():
6462
"""Test that function output is as expected."""
6563
classifier.fit(data, labels.ravel())
6664
feature_importance, importance_results = evaluate_feature_importance(
67-
model=classifier, x_test=data, y_test=labels, feature_names=feature_names, random_state=0
65+
model=classifier, X=data, y=labels, feature_names=feature_names, n_repeats=50, random_state=0
6866
)
6967

7068
np.testing.assert_almost_equal(
7169
feature_importance.loc[feature_importance["Feature"] == "EM_ratio", "Importance"].values[0],
72-
desired=12.923077,
70+
desired=0.129231,
7371
decimal=6,
7472
)
7573
np.testing.assert_almost_equal(
7674
feature_importance.loc[feature_importance["Feature"] == "EM_Qd", "Importance"].values[0],
77-
desired=4.461538,
75+
desired=0.044615,
7876
decimal=6,
7977
)
8078
np.testing.assert_equal(len(feature_importance), desired=len(feature_names))
8179
np.testing.assert_equal(
8280
tuple(importance_results.keys()),
8381
desired=("importances_mean", "importances_std", "importances"),
8482
)
83+
84+
85+
def test_invalid_input_lengths():
86+
"""Test that non matcing X and y lengths raises an exception."""
87+
labels = np.random.randint(2, size=12)
88+
with pytest.raises(NonMatchingParameterLengthsException):
89+
_, _ = evaluate_feature_importance(model=classifier, X=data, y=labels, feature_names=feature_names)
90+
91+
92+
def test_invalid_number_of_feature_names():
93+
"""Test that invalid number of feature names raises an exception."""
94+
with pytest.raises(InvalidParameterValueException):
95+
_, _ = evaluate_feature_importance(
96+
model=classifier,
97+
X=data,
98+
y=labels,
99+
feature_names=["a", "b", "c"],
100+
)

0 commit comments

Comments
 (0)