Skip to content

Commit 6e4c2a7

Browse files
committed
Fix error in tools.
1 parent 724eb20 commit 6e4c2a7

File tree

1 file changed

+52
-30
lines changed

1 file changed

+52
-30
lines changed

tools/convert_voc2010.py

+52-30
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
"""
1615
File: convert_voc2010.py
1716
This file is based on https://www.cs.stanford.edu/~roozbeh/pascal-context/ to generate PASCAL-Context Dataset.
@@ -24,33 +23,31 @@
2423
|--ImageSets
2524
|
2625
|--SegmentationClass
27-
|
26+
|
2827
|--JPEGImages
2928
|
3029
|--SegmentationObject
3130
|
3231
|--trainval_merged.json
3332
"""
3433

34+
import argparse
3535
import os
3636

37-
import argparse
3837
import tqdm
3938
import numpy as np
4039
from detail import Detail
4140
from PIL import Image
41+
from paddleseg.utils.download import _download_file
42+
43+
JSON_URL = 'https://codalabuser.blob.core.windows.net/public/trainval_merged.json'
4244

4345

4446
def parse_args():
4547
parser = argparse.ArgumentParser(
46-
description=
47-
'Generate PASCAL-Context dataset'
48-
)
48+
description='Generate PASCAL-Context dataset')
4949
parser.add_argument(
50-
'--voc_path',
51-
dest='voc_path',
52-
help='pascal voc path',
53-
type=str)
50+
'--voc_path', dest='voc_path', help='pascal voc path', type=str)
5451
parser.add_argument(
5552
'--annotation_path',
5653
dest='annotation_path',
@@ -66,17 +63,23 @@ def __init__(self, voc_path, annotation_path):
6663
self.annotation_path = annotation_path
6764
self.label_dir = os.path.join(self.voc_path, 'Context')
6865
self._image_dir = os.path.join(self.voc_path, 'JPEGImages')
69-
self.annFile = os.path.join(self.annotation_path, 'trainval_merged.json')
70-
66+
self.annFile = os.path.join(self.annotation_path,
67+
'trainval_merged.json')
68+
7169
if not os.path.exists(self.annFile):
72-
_download_file(url=JSON_URL, savepath=self.annotation_path, print_progress=True)
73-
74-
self._mapping = np.sort(np.array([
75-
0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22,
76-
23, 397, 25, 284, 158, 159, 416, 33, 162, 420, 454, 295, 296,
77-
427, 44, 45, 46, 308, 59, 440, 445, 31, 232, 65, 354, 424,
78-
68, 326, 72, 458, 34, 207, 80, 355, 85, 347, 220, 349, 360,
79-
98, 187, 104, 105, 366, 189, 368, 113, 115]))
70+
_download_file(
71+
url=JSON_URL,
72+
savepath=self.annotation_path,
73+
print_progress=True)
74+
75+
self._mapping = np.sort(
76+
np.array([
77+
0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25,
78+
284, 158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45,
79+
46, 308, 59, 440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458,
80+
34, 207, 80, 355, 85, 347, 220, 349, 360, 98, 187, 104, 105,
81+
366, 189, 368, 113, 115
82+
]))
8083
self._key = np.array(range(len(self._mapping))).astype('uint8') - 1
8184

8285
self.train_detail = Detail(self.annFile, self._image_dir, 'train')
@@ -94,12 +97,20 @@ def _class_to_index(self, mask, _mapping, _key):
9497
assert (values[i] in _mapping)
9598
index = np.digitize(mask.ravel(), _mapping, right=True)
9699
return _key[index].reshape(mask.shape)
97-
100+
98101
def save_mask(self, img_id, mode):
99102
if mode == 'train':
100-
mask = Image.fromarray(self._class_to_index(self.train_detail.getMask(img_id), _mapping=self._mapping, _key=self._key))
103+
mask = Image.fromarray(
104+
self._class_to_index(
105+
self.train_detail.getMask(img_id),
106+
_mapping=self._mapping,
107+
_key=self._key))
101108
elif mode == 'val':
102-
mask = Image.fromarray(self._class_to_index(self.val_detail.getMask(img_id), _mapping=self._mapping, _key=self._key))
109+
mask = Image.fromarray(
110+
self._class_to_index(
111+
self.val_detail.getMask(img_id),
112+
_mapping=self._mapping,
113+
_key=self._key))
103114
filename = img_id['file_name']
104115
basename, _ = os.path.splitext(filename)
105116
if filename.endswith(".jpg"):
@@ -109,27 +120,38 @@ def save_mask(self, img_id, mode):
109120

110121
def generate_label(self):
111122

112-
with open(os.path.join(self.voc_path, 'ImageSets/Segmentation/train_context.txt'), 'w') as f:
123+
with open(
124+
os.path.join(self.voc_path,
125+
'ImageSets/Segmentation/train_context.txt'),
126+
'w') as f:
113127
for img_id in tqdm.tqdm(self.train_ids, desc='train'):
114128
basename = self.save_mask(img_id, 'train')
115129
f.writelines(''.join([basename, '\n']))
116130

117-
with open(os.path.join(self.voc_path, 'ImageSets/Segmentation/val_context.txt'), 'w') as f:
131+
with open(
132+
os.path.join(self.voc_path,
133+
'ImageSets/Segmentation/val_context.txt'),
134+
'w') as f:
118135
for img_id in tqdm.tqdm(self.val_ids, desc='val'):
119136
basename = self.save_mask(img_id, 'val')
120137
f.writelines(''.join([basename, '\n']))
121-
122-
with open(os.path.join(self.voc_path, 'ImageSets/Segmentation/trainval_context.txt'), 'w') as f:
138+
139+
with open(
140+
os.path.join(self.voc_path,
141+
'ImageSets/Segmentation/trainval_context.txt'),
142+
'w') as f:
123143
for img in tqdm.tqdm(os.listdir(self.label_dir), desc='trainval'):
124144
if img.endswith('.png'):
125145
basename = img.split('.', 1)[0]
126146
f.writelines(''.join([basename, '\n']))
127-
128-
147+
148+
129149
def main():
130150
args = parse_args()
131-
generator = PascalContextGenerator(voc_path=args.voc_path, annotation_path=args.annotation_path)
151+
generator = PascalContextGenerator(
152+
voc_path=args.voc_path, annotation_path=args.annotation_path)
132153
generator.generate_label()
133154

155+
134156
if __name__ == '__main__':
135157
main()

0 commit comments

Comments
 (0)