Skip to content

Commit 9e522bb

Browse files
committed
improve pre-proc
1 parent 9797308 commit 9e522bb

File tree

4 files changed

+22
-12
lines changed

4 files changed

+22
-12
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@ custom_layers/__pycache__/
77
data/
88
imagenet_models/
99
logs/
10+
cars_test/
11+
cars_train/
12+
devkit/

Requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ tensorflow-gpu
33
keras
44
pillow
55
sklearn
6+
console-progressbar

clean.sh

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
rm data/ -r
2+
rm logs/ -r
3+
rm cars_test/ -r
4+
rm cars_train/ -r
5+
rm devkit/ -r

pre-process.py

+13-12
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import cv2 as cv
88
import shutil
99
import random
10+
from console_progressbar import ProgressBar
1011

1112

1213
def ensure_folder(folder):
@@ -22,9 +23,11 @@ def save_train_data(fnames, labels, bboxes):
2223
num_valid = num_samples - num_train
2324
train_indexes = random.sample(range(num_samples), num_valid)
2425

26+
pb = ProgressBar(total=100, prefix='Save train data', suffix='', decimals=3, length=50, fill='>')
27+
2528
for i in range(num_samples):
2629
fname = fnames[i]
27-
label = labels[i][0]
30+
label = labels[i]
2831
(x1, y1, x2, y2) = bboxes[i]
2932

3033
src_path = os.path.join(src_folder, fname)
@@ -36,7 +39,8 @@ def save_train_data(fnames, labels, bboxes):
3639
y1 = max(0, y1 - margin)
3740
x2 = min(x2 + 1 + margin, width)
3841
y2 = min(y2 + 1 + margin, height)
39-
print("{} -> {}".format(fname, label))
42+
# print("{} -> {}".format(fname, label))
43+
pb.print_progress_bar((i+1) * 100 / num_samples)
4044

4145
if i in train_indexes:
4246
dst_folder = 'data/train'
@@ -87,13 +91,15 @@ def process_train_data():
8791
fnames = []
8892
class_ids = []
8993
bboxes = []
94+
labels = []
9095

9196
for annotation in annotations:
9297
bbox_x1 = annotation[0][0][0][0]
9398
bbox_y1 = annotation[0][1][0][0]
9499
bbox_x2 = annotation[0][2][0][0]
95100
bbox_y2 = annotation[0][3][0][0]
96101
class_id = annotation[0][4][0][0]
102+
labels.append(str(class_id))
97103
fname = annotation[0][5][0]
98104
bboxes.append((bbox_x1, bbox_y1, bbox_x2, bbox_y2))
99105
class_ids.append(class_id)
@@ -103,11 +109,6 @@ def process_train_data():
103109
print(np.unique(class_ids))
104110
print('The number of different cars is %d' % labels_count)
105111

106-
labels = []
107-
for class_id in class_ids:
108-
class_name = class_names[class_id-1][0]
109-
labels.append(class_name)
110-
111112
save_train_data(fnames, labels, bboxes)
112113

113114

@@ -134,7 +135,7 @@ def process_test_data():
134135

135136
if __name__ == '__main__':
136137
# parameters
137-
img_width, img_height = 224, 224
138+
img_width, img_height = 227, 227
138139

139140
print('Extracting cars_train.tgz...')
140141
if not os.path.exists('cars_train'):
@@ -153,7 +154,7 @@ def process_test_data():
153154
class_names = cars_meta['class_names'] # shape=(1, 196)
154155
class_names = np.transpose(class_names)
155156
print('class_names.shape: ' + str(class_names.shape))
156-
print('Sample class_name: {}'.format(class_names[8][0][0]))
157+
print('Sample class_name: [{}]'.format(class_names[8][0][0]))
157158

158159
ensure_folder('data/train')
159160
ensure_folder('data/valid')
@@ -163,9 +164,9 @@ def process_test_data():
163164
# process_test_data()
164165

165166
# clean up
166-
shutil.rmtree('cars_train')
167-
shutil.rmtree('cars_test')
168-
shutil.rmtree('devkit')
167+
# shutil.rmtree('cars_train')
168+
# shutil.rmtree('cars_test')
169+
# shutil.rmtree('devkit')
169170

170171

171172

0 commit comments

Comments
 (0)