Commit 22b3ebb

FCN

YunYang1994 authored and committed
1 parent d15e243 commit 22b3ebb

File tree

6 files changed: +163 -144 lines changed

5-Image_Segmentation/FCN/README.md

Lines changed: 6 additions & 1 deletion
@@ -10,12 +10,14 @@
 ## Train PASCAL VOC2012
 --------------------
 Download VOC PASCAL trainval and test data.
+
 ```bashrc
 $ wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
 $ wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
 $ wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
 ```
 Extract all of these tars into one directory and rename them, which should have the following basic structure.
+
 ```bashrc
 VOC # path: /home/yang/dataset/VOC
 ├── test
@@ -26,7 +28,8 @@ VOC # path: /home/yang/dataset/VOC
 └──VOC2007 (from VOCtrainval_06-Nov-2007.tar)
 └──VOC2012 (from VOCtrainval_11-May-2012.tar)
 ```
-Finally you need to make some transformation and train it.
+Finally you need to make some transformation and train it. Here is my trained weight
+
 ```bashrc
 $ python parser_voc.py --voc_path /home/yang/dataset/VOC
 $ python train.py
@@ -37,6 +40,8 @@ Epoch 2/30
 ...
 Epoch 30/30
 6000/6000 [==============================] - 3552s 592ms/step - loss: 0.0811 - accuracy: 0.9797
+
+$ python test.py
 ```

 |![image](https://user-images.githubusercontent.com/30433053/66732790-d4d56680-ee8f-11e9-9120-07b0e8aa53d4.jpg)|![image](https://user-images.githubusercontent.com/30433053/66732791-d69f2a00-ee8f-11e9-9c5d-16cc84bc7e9e.jpg)|![image](https://user-images.githubusercontent.com/30433053/66732795-da32b100-ee8f-11e9-9d85-f0ddba7a3ab1.jpg)|
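The extract-and-rename step in the README is done by hand. As a convenience only, here is a small Python sketch that unpacks the three archives; the tar file names come from the wget commands above, the target directory is the example path used throughout this README and is an assumption, and arranging the result into the train/test layout shown above remains a manual step:

```python
import tarfile

# Assumption: the three tars sit in the current directory and are unpacked
# under /home/yang/dataset (each archive contains a VOCdevkit folder).
archives = [
    "VOCtrainval_06-Nov-2007.tar",
    "VOCtrainval_11-May-2012.tar",
    "VOCtest_06-Nov-2007.tar",
]
for name in archives:
    with tarfile.open(name) as tar:
        tar.extractall(path="/home/yang/dataset")
```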

5-Image_Segmentation/FCN/config.py

Lines changed: 0 additions & 28 deletions
This file was deleted; its constants (classes, colormap, rgb_mean, rgb_std) now live in utils.py.

5-Image_Segmentation/FCN/parser_voc.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 if not os.path.exists("./data"): os.mkdir("./data")
 if not os.path.exists("./data/train_labels"): os.mkdir("./data/train_labels")
 if not os.path.exists("./data/test_labels"): os.mkdir("./data/test_labels")
+if not os.path.exists("./data/prediction"): os.mkdir("./data/prediction")

 parser = argparse.ArgumentParser()
 parser.add_argument("--voc_path", type=str, default="/home/yang/dataset/VOC")
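The added line keeps the existing `os.path.exists` / `os.mkdir` pattern for the new `./data/prediction` output folder. A shorter equivalent, a sketch rather than part of this commit, is `os.makedirs` with `exist_ok=True`, which also creates missing parent directories:

```python
import os

# Create every output folder used by parser_voc.py and test.py in one pass;
# exist_ok=True makes the call a no-op when a directory already exists.
for folder in ("./data", "./data/train_labels", "./data/test_labels", "./data/prediction"):
    os.makedirs(folder, exist_ok=True)
```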

5-Image_Segmentation/FCN/test.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+#! /usr/bin/env python
+# coding=utf-8
+#================================================================
+#   Copyright (C) 2019 * Ltd. All rights reserved.
+#
+#   Editor      : VIM
+#   File name   : test.py
+#   Author      : YunYang1994
+#   Created date: 2019-10-23 23:14:38
+#   Description :
+#
+#================================================================
+
+import numpy as np
+import tensorflow as tf
+
+from fcn8s import FCN8s
+from utils import visual_result, DataGenerator
+
+model = FCN8s(n_class=21)
+TestSet = DataGenerator("./data/test_image.txt", "./data/test_labels", 1)
+
+## load weights and test your model after training
+## if you want to test model, first you need to initialize your model
+## with "model(data)", and then load model weights
+data = np.ones(shape=[1,224,224,3], dtype=np.float)
+model(data)
+model.load_weights("FCN8s.h5")
+
+for idx, (x, y) in enumerate(TestSet):
+    result = model(x)
+    pred_label = tf.argmax(result, axis=-1)
+    result = visual_result(x[0], pred_label[0].numpy())
+    save_file = "./data/prediction/%d.jpg" %idx
+    print("=> saving prediction result into ", save_file)
+    result.save(save_file)
+    if idx == 209:
+        result.show()
+        break
+
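The dummy forward pass in test.py is there because `FCN8s` is a subclassed `tf.keras.Model`: its variables are only created the first time the model is called, and `load_weights` needs those variables to exist before it can restore values into them. A minimal sketch of that pattern, assuming only that the model takes `[batch, 224, 224, 3]` inputs:

```python
import numpy as np
from fcn8s import FCN8s  # project module shown above

model = FCN8s(n_class=21)
# Build the model's variables by running one dummy batch through it,
# then the checkpoint produced by train.py can be restored onto them.
model(np.zeros([1, 224, 224, 3], dtype=np.float32))
model.load_weights("FCN8s.h5")
```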

5-Image_Segmentation/FCN/train.py

Lines changed: 2 additions & 115 deletions
@@ -11,131 +11,18 @@
 #
 #================================================================

-import os
-import cv2
-import random
 import tensorflow as tf
-import numpy as np
 from fcn8s import FCN8s
-from PIL import Image
-from config import colormap, classes, rgb_mean, rgb_std
-
-
-def create_image_label_path_generator(images_filepath, labels_filepath):
-    image_paths = open(images_filepath).readlines()
-    all_label_txts = os.listdir(labels_filepath)
-    image_label_paths = []
-    for label_txt in all_label_txts:
-        label_name = label_txt[:-4]
-        label_path = labels_filepath + "/" + label_txt
-        for image_path in image_paths:
-            image_path = image_path.rstrip()
-            image_name = image_path.split("/")[-1][:-4]
-            if label_name == image_name:
-                image_label_paths.append((image_path, label_path))
-    while True:
-        random.shuffle(image_label_paths)
-        for i in range(len(image_label_paths)):
-            yield image_label_paths[i]
-
-
-def process_image_label(image_path, label_path):
-    # image = misc.imread(image_path)
-    image = cv2.imread(image_path)
-    image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_NEAREST)
-    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-    # data augmentation here
-    # pass
-    # image transformation here
-    image = (image / 255. - rgb_mean) / rgb_std
-
-    label = open(label_path).readlines()
-    label = [np.array(line.rstrip().split(" ")) for line in label]
-    label = np.array(label, dtype=np.int)
-    label = cv2.resize(label, (224, 224), interpolation=cv2.INTER_NEAREST)
-    label = label.astype(np.int)
-
-    return image, label
-
-
-def DataGenerator(train_image_txt, train_labels_dir, batch_size):
-    """
-    generate image and mask at the same time
-    """
-    image_label_path_generator = create_image_label_path_generator(
-        train_image_txt, train_labels_dir
-    )
-    while True:
-        images = np.zeros(shape=[batch_size, 224, 224, 3])
-        labels = np.zeros(shape=[batch_size, 224, 224], dtype=np.float)
-        for i in range(batch_size):
-            image_path, label_path = next(image_label_path_generator)
-            image, label = process_image_label(image_path, label_path)
-            images[i], labels[i] = image, label
-        yield images, labels
-
-
-def visual_result(image, label, alpha=0.7):
-    """
-    image shape -> [H, W, C]
-    label shape -> [H, W]
-    """
-    image = (image * rgb_std + rgb_mean) * 255
-    image, label = image.astype(np.int), label.astype(np.int)
-    H, W, C = image.shape
-    masks_color = np.zeros(shape=[H, W, C])
-    inv_masks_color = np.zeros(shape=[H, W, C])
-    cls = []
-    for i in range(H):
-        for j in range(W):
-            cls_idx = label[i, j]
-            masks_color[i, j] = np.array(colormap[cls_idx])
-            cls.append(cls_idx)
-            if classes[cls_idx] == "background":
-                inv_masks_color[i, j] = alpha * image[i, j]
-
-    show_image = np.zeros(shape=[224, 672, 3])
-    cls = set(cls)
-    for x in cls:
-        print("=> ", classes[x])
-    show_image[:, :224, :] = image
-    show_image[:, 224:448, :] = masks_color
-    show_image[:, 448:, :] = (1-alpha)*image + alpha*masks_color + inv_masks_color
-    show_image = Image.fromarray(np.uint8(show_image))
-    return show_image
+from utils import DataGenerator

 TrainSet = DataGenerator("./data/train_image.txt", "./data/train_labels", 2)
-TestSet = DataGenerator("./data/test_image.txt", "./data/test_labels", 1)
-
 model = FCN8s(n_class=21)
-callback = tf.keras.callbacks.ModelCheckpoint("model.h5", verbose=1, save_weights_only=True)
+callback = tf.keras.callbacks.ModelCheckpoint("FCN8s.h5", verbose=1, save_weights_only=True)
 model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4),
               callback=callback,
               loss='sparse_categorical_crossentropy',
               metrics=['accuracy'])

 ## train your FCN8s model
 model.fit_generator(TrainSet, steps_per_epoch=6000, epochs=30)
-model.save_weights("model.h5")
-
-## load weights and test your model after training
-## if you want to test model, first you need to initialize your model
-## with "model(data)", and then load model weights
-
-# data = np.ones(shape=[1,224,224,3], dtype=np.float)
-# model(data)
-# model.load_weights("model.h5")
-
-for idx, (x, y) in enumerate(TestSet):
-    result = model(x)
-    pred_label = tf.argmax(result, axis=-1)
-    result = visual_result(x[0], pred_label[0].numpy())
-    save_file = "./data/prediction/%d.jpg" %idx
-    print("=> saving prediction result into ", save_file)
-    result.save(save_file)
-    if idx == 209:
-        result.show()
-        break
-
-

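One caveat about the training script: `compile` has no notion of callbacks in tf.keras, so the `callback=callback` keyword above will not register the checkpoint; callbacks belong in the `callbacks` argument of `fit`/`fit_generator`. A hedged sketch of that wiring, keeping the same names and hyperparameters as train.py:

```python
import tensorflow as tf
from fcn8s import FCN8s
from utils import DataGenerator

TrainSet = DataGenerator("./data/train_image.txt", "./data/train_labels", 2)
model = FCN8s(n_class=21)
callback = tf.keras.callbacks.ModelCheckpoint("FCN8s.h5", verbose=1, save_weights_only=True)
model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# Checkpoints are written by handing the callback to fit_generator, not compile.
model.fit_generator(TrainSet, steps_per_epoch=6000, epochs=30, callbacks=[callback])
```

With `save_weights_only=True`, the resulting FCN8s.h5 is the file that test.py restores with `load_weights`.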

5-Image_Segmentation/FCN/utils.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+#! /usr/bin/env python
+# coding=utf-8
+#================================================================
+#   Copyright (C) 2019 * Ltd. All rights reserved.
+#
+#   Editor      : VIM
+#   File name   : utils.py
+#   Author      : YunYang1994
+#   Created date: 2019-10-12 17:47:24
+#   Description :
+#
+#================================================================
+
+import os
+import cv2
+import random
+import numpy as np
+
+from PIL import Image
+
+classes = ['background','aeroplane','bicycle','bird','boat',
+           'bottle','bus','car','cat','chair','cow','diningtable',
+           'dog','horse','motorbike','person','potted plant',
+           'sheep','sofa','train','tv/monitor']
+# RGB color for each class
+colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128],
+            [128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],
+            [64,128,0],[192,128,0],[64,0,128],[192,0,128],
+            [64,128,128],[192,128,128],[0,64,0],[128,64,0],
+            [0,192,0],[128,192,0],[0,64,128]]
+
+rgb_mean = np.array([0.485, 0.456, 0.406])
+rgb_std = np.array([0.229, 0.224, 0.225])
+
+def visual_result(image, label, alpha=0.7):
+    """
+    image shape -> [H, W, C]
+    label shape -> [H, W]
+    """
+    image = (image * rgb_std + rgb_mean) * 255
+    image, label = image.astype(np.int), label.astype(np.int)
+    H, W, C = image.shape
+    masks_color = np.zeros(shape=[H, W, C])
+    inv_masks_color = np.zeros(shape=[H, W, C])
+    cls = []
+    for i in range(H):
+        for j in range(W):
+            cls_idx = label[i, j]
+            masks_color[i, j] = np.array(colormap[cls_idx])
+            cls.append(cls_idx)
+            if classes[cls_idx] == "background":
+                inv_masks_color[i, j] = alpha * image[i, j]
+
+    show_image = np.zeros(shape=[224, 672, 3])
+    cls = set(cls)
+    for x in cls:
+        print("=> ", classes[x])
+    show_image[:, :224, :] = image
+    show_image[:, 224:448, :] = masks_color
+    show_image[:, 448:, :] = (1-alpha)*image + alpha*masks_color + inv_masks_color
+    show_image = Image.fromarray(np.uint8(show_image))
+    return show_image
+
+def create_image_label_path_generator(images_filepath, labels_filepath):
+    image_paths = open(images_filepath).readlines()
+    all_label_txts = os.listdir(labels_filepath)
+    image_label_paths = []
+    for label_txt in all_label_txts:
+        label_name = label_txt[:-4]
+        label_path = labels_filepath + "/" + label_txt
+        for image_path in image_paths:
+            image_path = image_path.rstrip()
+            image_name = image_path.split("/")[-1][:-4]
+            if label_name == image_name:
+                image_label_paths.append((image_path, label_path))
+    while True:
+        random.shuffle(image_label_paths)
+        for i in range(len(image_label_paths)):
+            yield image_label_paths[i]
+
+def process_image_label(image_path, label_path):
+    # image = misc.imread(image_path)
+    image = cv2.imread(image_path)
+    image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_NEAREST)
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    # data augmentation here
+    # pass
+    # image transformation here
+    image = (image / 255. - rgb_mean) / rgb_std
+
+    label = open(label_path).readlines()
+    label = [np.array(line.rstrip().split(" ")) for line in label]
+    label = np.array(label, dtype=np.int)
+    label = cv2.resize(label, (224, 224), interpolation=cv2.INTER_NEAREST)
+    label = label.astype(np.int)
+
+    return image, label
+
+
+def DataGenerator(train_image_txt, train_labels_dir, batch_size):
+    """
+    generate image and mask at the same time
+    """
+    image_label_path_generator = create_image_label_path_generator(
+        train_image_txt, train_labels_dir
+    )
+    while True:
+        images = np.zeros(shape=[batch_size, 224, 224, 3])
+        labels = np.zeros(shape=[batch_size, 224, 224], dtype=np.float)
+        for i in range(batch_size):
+            image_path, label_path = next(image_label_path_generator)
+            image, label = process_image_label(image_path, label_path)
+            images[i], labels[i] = image, label
+        yield images, labels
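`visual_result` colors the mask with a per-pixel Python loop, which is easy to read but slow for 224x224 labels. A vectorized sketch of the same colormap lookup, not part of the commit and reusing the `classes` and `colormap` constants defined above, does it with one fancy-indexing step:

```python
import numpy as np
from utils import classes, colormap  # constants defined in the file above

def label_to_color(label):
    """Map an [H, W] array of class indices to an [H, W, 3] RGB mask."""
    palette = np.array(colormap, dtype=np.uint8)  # shape [21, 3]
    return palette[label]                         # fancy indexing replaces the nested loops

# Example: background everywhere except a square of class "person".
label = np.zeros((224, 224), dtype=np.int64)
label[80:160, 80:160] = classes.index("person")
mask = label_to_color(label)                      # shape (224, 224, 3)
```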
