14
14
save_model ,
15
15
split_data ,
16
16
)
17
+ from eis_toolkit .prediction .model_utils import test_model as model_test
17
18
18
19
TEST_DIR = Path (__file__ ).parent .parent
19
20
@@ -50,7 +51,7 @@ def test_train_and_evaluate_with_split():
50
51
)
51
52
52
53
assert isinstance (model , RandomForestClassifier )
53
- assert len (out_metrics ) == 4
54
+ np . testing . assert_equal ( len (out_metrics ), 4 )
54
55
55
56
56
57
def test_train_and_evaluate_with_kfold_cv ():
@@ -60,7 +61,7 @@ def test_train_and_evaluate_with_kfold_cv():
60
61
)
61
62
62
63
assert isinstance (model , RandomForestClassifier )
63
- assert len (out_metrics ) == 4
64
+ np . testing . assert_equal ( len (out_metrics ), 4 )
64
65
65
66
66
67
def test_train_and_evaluate_with_skfold_cv ():
@@ -70,7 +71,7 @@ def test_train_and_evaluate_with_skfold_cv():
70
71
)
71
72
72
73
assert isinstance (model , RandomForestClassifier )
73
- assert len (out_metrics ) == 4
74
+ np . testing . assert_equal ( len (out_metrics ), 4 )
74
75
75
76
76
77
def test_binary_classification ():
@@ -97,7 +98,7 @@ def test_binary_classification():
97
98
)
98
99
99
100
assert isinstance (model , RandomForestClassifier )
100
- assert len (out_metrics ) == 4
101
+ np . testing . assert_equal ( len (out_metrics ), 4 )
101
102
102
103
103
104
def test_splitting ():
@@ -109,6 +110,18 @@ def test_splitting():
109
110
np .testing .assert_equal (len (y_test ), len (Y_IRIS ) * 0.2 )
110
111
111
112
113
+ def test_test_model_sklearn ():
114
+ """Test that test model works as expected with a Sklearn model."""
115
+ X_train , X_test , y_train , y_test = split_data (X_IRIS , Y_IRIS , split_size = 0.2 )
116
+
117
+ model , _ = _train_and_validate_sklearn_model (
118
+ X_train , y_train , model = RF_MODEL , validation_method = "none" , metrics = CLF_METRICS , random_state = 42
119
+ )
120
+
121
+ out_metrics = model_test (X_test , y_test , model )
122
+ np .testing .assert_equal (out_metrics ["accuracy" ], 1.0 )
123
+
124
+
112
125
def test_predict_sklearn ():
113
126
"""Test that predict works as expected with a Sklearn model."""
114
127
X_train , X_test , y_train , y_test = split_data (X_IRIS , Y_IRIS , split_size = 0.2 )
@@ -117,8 +130,8 @@ def test_predict_sklearn():
117
130
X_train , y_train , model = RF_MODEL , validation_method = "none" , metrics = CLF_METRICS , random_state = 42
118
131
)
119
132
120
- predicted_labels = predict (model , X_test )
121
- assert len (predicted_labels ) == len (y_test )
133
+ predicted_labels = predict (X_test , model )
134
+ np . testing . assert_equal ( len (predicted_labels ), len (y_test ) )
122
135
123
136
124
137
def test_save_and_load_model ():
0 commit comments