@@ -4,7 +4,11 @@
 from sklearn.neural_network import MLPClassifier
 from sklearn.preprocessing import StandardScaler
 
-from eis_toolkit.exceptions import InvalidDatasetException, InvalidParameterValueException
+from eis_toolkit.exceptions import (
+    InvalidDatasetException,
+    InvalidParameterValueException,
+    NonMatchingParameterLengthsException,
+)
 from eis_toolkit.exploratory_analyses.feature_importance import evaluate_feature_importance
 
 
 feature_names = [
@@ -42,43 +46,55 @@ def test_empty_data():
     empty_data = np.array([])
     empty_labels = np.array([])
     with pytest.raises(InvalidDatasetException):
-        _, _ = evaluate_feature_importance(
-            model=classifier, x_test=empty_data, y_test=labels, feature_names=feature_names
-        )
+        _, _ = evaluate_feature_importance(model=classifier, X=empty_data, y=labels, feature_names=feature_names)
 
     with pytest.raises(InvalidDatasetException):
-        _, _ = evaluate_feature_importance(
-            model=classifier, x_test=data, y_test=empty_labels, feature_names=feature_names
-        )
+        _, _ = evaluate_feature_importance(model=classifier, X=data, y=empty_labels, feature_names=feature_names)
 
 
 def test_invalid_n_repeats():
     """Test that invalid value for 'n_repeats' raises exception."""
     with pytest.raises(InvalidParameterValueException):
-        _, _ = evaluate_feature_importance(
-            model=classifier, x_test=data, y_test=labels, feature_names=feature_names, n_repeats=0
-        )
+        _, _ = evaluate_feature_importance(model=classifier, X=data, y=labels, feature_names=feature_names, n_repeats=0)
 
 
 def test_model_output():
     """Test that function output is as expected."""
     classifier.fit(data, labels.ravel())
     feature_importance, importance_results = evaluate_feature_importance(
-        model=classifier, x_test=data, y_test=labels, feature_names=feature_names, random_state=0
+        model=classifier, X=data, y=labels, feature_names=feature_names, n_repeats=50, random_state=0
     )
 
     np.testing.assert_almost_equal(
         feature_importance.loc[feature_importance["Feature"] == "EM_ratio", "Importance"].values[0],
-        desired=12.923077,
+        desired=0.129231,
         decimal=6,
     )
     np.testing.assert_almost_equal(
         feature_importance.loc[feature_importance["Feature"] == "EM_Qd", "Importance"].values[0],
-        desired=4.461538,
+        desired=0.044615,
         decimal=6,
     )
     np.testing.assert_equal(len(feature_importance), desired=len(feature_names))
     np.testing.assert_equal(
         tuple(importance_results.keys()),
         desired=("importances_mean", "importances_std", "importances"),
     )
+
+
+def test_invalid_input_lengths():
+    """Test that non-matching X and y lengths raise an exception."""
+    labels = np.random.randint(2, size=12)
+    with pytest.raises(NonMatchingParameterLengthsException):
+        _, _ = evaluate_feature_importance(model=classifier, X=data, y=labels, feature_names=feature_names)
+
+
+def test_invalid_number_of_feature_names():
+    """Test that invalid number of feature names raises an exception."""
+    with pytest.raises(InvalidParameterValueException):
+        _, _ = evaluate_feature_importance(
+            model=classifier,
+            X=data,
+            y=labels,
+            feature_names=["a", "b", "c"],
+        )
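
Note on the diff: the expected importances shrink by exactly a factor of 100 (12.923077 → 0.129231, 4.461538 → 0.044615), which suggests the function now reports importances as fractions rather than percentages. The result keys asserted in `test_model_output` ("importances_mean", "importances_std", "importances") are the fields of the `Bunch` returned by `sklearn.inspection.permutation_importance`, so the sketch below assumes the function wraps that call. This is a hypothetical re-implementation for illustration only; the validation order, messages, and internals of the actual eis_toolkit function may differ.

```python
# Minimal sketch of the contract these tests pin down. Assumptions (not
# confirmed by the diff): the real function wraps
# sklearn.inspection.permutation_importance and returns a pandas DataFrame
# with "Feature" and "Importance" columns plus the raw result object.
import numpy as np
import pandas as pd
from sklearn.inspection import permutation_importance

from eis_toolkit.exceptions import (
    InvalidDatasetException,
    InvalidParameterValueException,
    NonMatchingParameterLengthsException,
)


def evaluate_feature_importance_sketch(model, X, y, feature_names, n_repeats=50, random_state=None):
    """Hypothetical re-implementation used only to illustrate the tested behavior."""
    # test_empty_data: empty X or y is rejected outright.
    if X.size == 0 or y.size == 0:
        raise InvalidDatasetException("X and y must be non-empty.")
    # test_invalid_input_lengths: X and y must have the same number of samples.
    if len(X) != len(y):
        raise NonMatchingParameterLengthsException("X and y must have the same length.")
    # test_invalid_n_repeats: at least one permutation round is required.
    if n_repeats < 1:
        raise InvalidParameterValueException("n_repeats must be at least 1.")
    # test_invalid_number_of_feature_names: one name per column of X.
    if len(feature_names) != X.shape[1]:
        raise InvalidParameterValueException("feature_names must match the number of columns in X.")

    # `result` is a Bunch with keys "importances_mean", "importances_std",
    # "importances", matching the tuple asserted in test_model_output.
    result = permutation_importance(model, X, y, n_repeats=n_repeats, random_state=random_state)
    feature_importance = pd.DataFrame(
        {"Feature": feature_names, "Importance": result["importances_mean"]}
    )
    return feature_importance, result
```

One design point worth noting: sklearn's own default is `n_repeats=5`, while `test_model_output` now pins `n_repeats=50` explicitly, so exposing `n_repeats` as a pass-through parameter (as the sketch does) is what makes that assertion reproducible.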