1
1
# -*- coding: utf-8 -*-
2
2
3
- import itertools
4
- import numpy as np
3
+ import os
4
+
5
5
import matplotlib .pyplot as plt
6
- from sklearn . metrics import confusion_matrix
6
+ import numpy as np
7
7
from keras .preprocessing import image
8
- import os
8
+ from sklearn .metrics import confusion_matrix
9
+ from tqdm import tqdm
10
+
9
11
from utils import load_model
10
- from console_progressbar import ProgressBar
11
12
12
13
13
14
def decode_predictions (preds , top = 5 ):
@@ -29,8 +30,8 @@ def predict(img_dir, model):
29
30
30
31
y_pred = []
31
32
y_test = []
32
- pb = ProgressBar ( total = 100 , prefix = 'Predict data' , suffix = '' , decimals = 3 , length = 50 , fill = '=' )
33
- for img_path in img_files :
33
+
34
+ for img_path in tqdm ( img_files ) :
34
35
# print(img_path)
35
36
img = image .load_img (img_path , target_size = (224 , 224 ))
36
37
x = image .img_to_array (img )
@@ -39,11 +40,10 @@ def predict(img_dir, model):
39
40
pred_label = decoded [0 ][0 ][0 ]
40
41
# print(pred_label)
41
42
y_pred .append (pred_label )
42
- tokens = img_path .split (' \\ ' )
43
+ tokens = img_path .split (os . pathsep )
43
44
class_id = int (tokens [- 2 ])
44
45
# print(str(class_id))
45
46
y_test .append (class_id )
46
- pb .print_progress_bar (len (y_pred ) * 100 / len (img_files ))
47
47
48
48
return y_pred , y_test
49
49
@@ -67,7 +67,7 @@ def plot_confusion_matrix(cm, classes,
67
67
plt .imshow (cm , interpolation = 'nearest' , cmap = cmap )
68
68
plt .title (title )
69
69
plt .colorbar ()
70
- #tick_marks = np.arange(len(classes))
70
+ # tick_marks = np.arange(len(classes))
71
71
# plt.xticks(tick_marks, classes, rotation=45)
72
72
# plt.yticks(tick_marks, classes)
73
73
0 commit comments