7
7
import cv2 as cv
8
8
import shutil
9
9
import random
10
+ from console_progressbar import ProgressBar
10
11
11
12
12
13
def ensure_folder (folder ):
@@ -22,9 +23,11 @@ def save_train_data(fnames, labels, bboxes):
22
23
num_valid = num_samples - num_train
23
24
train_indexes = random .sample (range (num_samples ), num_valid )
24
25
26
+ pb = ProgressBar (total = 100 , prefix = 'Save train data' , suffix = '' , decimals = 3 , length = 50 , fill = '>' )
27
+
25
28
for i in range (num_samples ):
26
29
fname = fnames [i ]
27
- label = labels [i ][ 0 ]
30
+ label = labels [i ]
28
31
(x1 , y1 , x2 , y2 ) = bboxes [i ]
29
32
30
33
src_path = os .path .join (src_folder , fname )
@@ -36,7 +39,8 @@ def save_train_data(fnames, labels, bboxes):
36
39
y1 = max (0 , y1 - margin )
37
40
x2 = min (x2 + 1 + margin , width )
38
41
y2 = min (y2 + 1 + margin , height )
39
- print ("{} -> {}" .format (fname , label ))
42
+ # print("{} -> {}".format(fname, label))
43
+ pb .print_progress_bar ((i + 1 ) * 100 / num_samples )
40
44
41
45
if i in train_indexes :
42
46
dst_folder = 'data/train'
@@ -87,13 +91,15 @@ def process_train_data():
87
91
fnames = []
88
92
class_ids = []
89
93
bboxes = []
94
+ labels = []
90
95
91
96
for annotation in annotations :
92
97
bbox_x1 = annotation [0 ][0 ][0 ][0 ]
93
98
bbox_y1 = annotation [0 ][1 ][0 ][0 ]
94
99
bbox_x2 = annotation [0 ][2 ][0 ][0 ]
95
100
bbox_y2 = annotation [0 ][3 ][0 ][0 ]
96
101
class_id = annotation [0 ][4 ][0 ][0 ]
102
+ labels .append (str (class_id ))
97
103
fname = annotation [0 ][5 ][0 ]
98
104
bboxes .append ((bbox_x1 , bbox_y1 , bbox_x2 , bbox_y2 ))
99
105
class_ids .append (class_id )
@@ -103,11 +109,6 @@ def process_train_data():
103
109
print (np .unique (class_ids ))
104
110
print ('The number of different cars is %d' % labels_count )
105
111
106
- labels = []
107
- for class_id in class_ids :
108
- class_name = class_names [class_id - 1 ][0 ]
109
- labels .append (class_name )
110
-
111
112
save_train_data (fnames , labels , bboxes )
112
113
113
114
@@ -134,7 +135,7 @@ def process_test_data():
134
135
135
136
if __name__ == '__main__' :
136
137
# parameters
137
- img_width , img_height = 224 , 224
138
+ img_width , img_height = 227 , 227
138
139
139
140
print ('Extracting cars_train.tgz...' )
140
141
if not os .path .exists ('cars_train' ):
@@ -153,7 +154,7 @@ def process_test_data():
153
154
class_names = cars_meta ['class_names' ] # shape=(1, 196)
154
155
class_names = np .transpose (class_names )
155
156
print ('class_names.shape: ' + str (class_names .shape ))
156
- print ('Sample class_name: {} ' .format (class_names [8 ][0 ][0 ]))
157
+ print ('Sample class_name: [{}] ' .format (class_names [8 ][0 ][0 ]))
157
158
158
159
ensure_folder ('data/train' )
159
160
ensure_folder ('data/valid' )
@@ -163,9 +164,9 @@ def process_test_data():
163
164
# process_test_data()
164
165
165
166
# clean up
166
- shutil .rmtree ('cars_train' )
167
- shutil .rmtree ('cars_test' )
168
- shutil .rmtree ('devkit' )
167
+ # shutil.rmtree('cars_train')
168
+ # shutil.rmtree('cars_test')
169
+ # shutil.rmtree('devkit')
169
170
170
171
171
172
0 commit comments