Skip to content

Commit 0a796ec

Browse files
author
tianchu.gtc
committed
更改mask2former接口,适配maas
Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/10292532
* modify postprocess
* Merge remote-tracking branch 'remotes/origin/master' into mask2former_postprocessing
  # Conflicts:
  #   tests/ut_config.py
* assert the output value
* add mask2former models to data/test/xxx/models
* fixed train
1 parent 7633707 commit 0a796ec

File tree

7 files changed

+160
-55
lines changed

7 files changed

+160
-55
lines changed

configs/segmentation/mask2former/mask2former_r50_8xb2_e50_instance.py

+20
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,26 @@
1515
'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
1616
'hair drier', 'toothbrush'
1717
]
18+
PALETTE = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
19+
(106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192),
20+
(250, 170, 30), (100, 170, 30), (220, 220, 0), (175, 116, 175),
21+
(250, 0, 30), (165, 42, 42), (255, 77, 255), (0, 226, 252),
22+
(182, 182, 255), (0, 82, 0), (120, 166, 157), (110, 76, 0),
23+
(174, 57, 255), (199, 100, 0), (72, 0, 118), (255, 179, 240),
24+
(0, 125, 92), (209, 0, 151), (188, 208, 182), (0, 220, 176),
25+
(255, 99, 164), (92, 0, 73), (133, 129, 255), (78, 180, 255),
26+
(0, 228, 0), (174, 255, 243), (45, 89, 255), (134, 134, 103),
27+
(145, 148, 174), (255, 208, 186), (197, 226, 255), (171, 134, 1),
28+
(109, 63, 54), (207, 138, 255), (151, 0, 95), (9, 80, 61),
29+
(84, 105, 51), (74, 65, 105), (166, 196, 102), (208, 195, 210),
30+
(255, 109, 65), (0, 143, 149), (179, 0, 194), (209, 99, 106),
31+
(5, 121, 0), (227, 255, 205), (147, 186, 208), (153, 69, 1),
32+
(3, 95, 161), (163, 255, 0), (119, 0, 170), (0, 182, 199),
33+
(0, 165, 120), (183, 130, 88), (95, 32, 0), (130, 114, 135),
34+
(110, 129, 133), (166, 74, 118), (219, 142, 185), (79, 210, 114),
35+
(178, 90, 62), (65, 70, 15), (127, 167, 115), (59, 105, 106),
36+
(142, 108, 45), (196, 172, 0), (95, 54, 80), (128, 76, 255),
37+
(201, 57, 1), (246, 0, 122), (191, 162, 208)]
1838

1939
model = dict(
2040
type='Mask2Former',
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:163a344e29b965cdb6c6c24e189e84a269580d63237253f359de35e944ec5421
3+
size 528712836
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:a33e6b1f5623057c6920226767c91a44a072acc27ece5ba24fdeb2a9a1bb2ba2
3+
size 528548036

easycv/models/segmentation/mask2former.py

+17-22
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,13 @@ def init_weights(self):
110110
print_log('load model from init weights')
111111
self.backbone.init_weights()
112112

113-
def forward_train(self, img, gt_labels, gt_masks, gt_semantic_seg,
114-
img_metas):
113+
def forward_train(self,
114+
img,
115+
gt_labels,
116+
gt_masks=None,
117+
gt_semantic_seg=None,
118+
img_metas=None,
119+
**kwargs):
115120
features = self.backbone(img)
116121
outputs = self.head(features)
117122
targets = self.preprocess_gt(gt_labels, gt_masks, gt_semantic_seg,
@@ -125,7 +130,12 @@ def forward_train(self, img, gt_labels, gt_masks, gt_semantic_seg,
125130
losses.pop(k)
126131
return losses
127132

128-
def forward_test(self, img, img_metas, rescale=True, encode=True):
133+
def forward_test(self,
134+
img,
135+
img_metas,
136+
rescale=True,
137+
encode=True,
138+
**kwargs):
129139
features = self.backbone(img[0])
130140
outputs = self.head(features)
131141
mask_cls_results = outputs['pred_logits']
@@ -189,23 +199,6 @@ def forward_test(self, img, img_metas, rescale=True, encode=True):
189199
outputs['pan_results'] = pan_masks
190200
return outputs
191201

192-
def forward(self,
193-
img,
194-
mode='train',
195-
gt_labels=None,
196-
gt_masks=None,
197-
gt_semantic_seg=None,
198-
img_metas=None,
199-
**kwargs):
200-
201-
if mode == 'train':
202-
return self.forward_train(img, gt_labels, gt_masks,
203-
gt_semantic_seg, img_metas)
204-
elif mode == 'test':
205-
return self.forward_test(img, img_metas)
206-
else:
207-
raise Exception('No such mode: {}'.format(mode))
208-
209202
def instance_postprocess(self, mask_cls, mask_pred):
210203
"""Instance segmentation postprocess.
211204
@@ -233,8 +226,10 @@ def instance_postprocess(self, mask_cls, mask_pred):
233226
# shape (num_queries, num_class)
234227
scores = F.softmax(mask_cls, dim=-1)[:, :-1]
235228
# shape (num_queries * num_class, )
236-
labels = torch.arange(self.num_classes, device=mask_cls.device).\
237-
unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
229+
labels = torch.arange(
230+
self.num_classes,
231+
device=mask_cls.device).unsqueeze(0).repeat(num_queries,
232+
1).flatten(0, 1)
238233
scores_per_image, top_indices = scores.flatten(0, 1).topk(
239234
max_per_image, sorted=False)
240235
labels_per_image = labels[top_indices]

easycv/predictors/segmentation.py

+31-15
Original file line numberDiff line numberDiff line change
@@ -167,29 +167,44 @@ def forward(self, inputs):
167167
"""Model forward.
168168
"""
169169
with torch.no_grad():
170-
outputs = self.model(**inputs, mode='test', encode=False)
170+
outputs = self.model.forward(**inputs, mode='test', encode=False)
171+
if self.task_mode == 'instance':
172+
outputs.pop('pan_results')
173+
elif self.task_mode == 'panoptic':
174+
outputs.pop('detection_masks')
175+
outputs.pop('detection_boxes')
176+
outputs.pop('detection_scores')
177+
outputs.pop('detection_classes')
171178
return outputs
172179

173-
def postprocess(self, inputs):
180+
def postprocess_single(self, inputs, *args, **kwargs):
174181
output = {}
175182
if self.task_mode == 'panoptic':
176-
output['pan'] = inputs['pan_results'][0]
183+
pan_results = inputs['pan_results']
184+
# keep objects ahead
185+
ids = np.unique(pan_results)[::-1]
186+
legal_indices = ids != len(self.CLASSES) # for VOID label
187+
ids = ids[legal_indices]
188+
labels = np.array([id % 1000 for id in ids], dtype=np.int64)
189+
segms = (pan_results[None] == ids[:, None, None])
190+
masks = [it.astype(np.int) for it in segms]
191+
labels_txt = np.array(self.CLASSES)[labels].tolist()
192+
193+
output['masks'] = masks
194+
output['labels'] = labels_txt
195+
output['labels_ids'] = labels
177196
elif self.task_mode == 'instance':
178-
output['segms'] = inputs['detection_masks'][0]
179-
output['bboxes'] = inputs['detection_boxes'][0]
180-
output['scores'] = inputs['detection_scores'][0]
181-
output['labels'] = inputs['detection_classes'][0]
197+
output['segms'] = inputs['detection_masks']
198+
output['bboxes'] = inputs['detection_boxes']
199+
output['scores'] = inputs['detection_scores']
200+
output['labels'] = inputs['detection_classes']
182201
else:
183202
raise ValueError(f'Not support model {self.task_mode}')
184203
return output
185204

186-
def show_panoptic(self, img, pan_mask):
187-
pan_label = np.unique(pan_mask)
188-
pan_label = pan_label[pan_label % 1000 != self.classes]
189-
masks = np.array([pan_mask == num for num in pan_label])
190-
205+
def show_panoptic(self, img, masks, labels):
191206
palette = np.asarray(self.cfg.PALETTE)
192-
palette = palette[pan_label % 1000]
207+
palette = palette[labels % 1000]
193208
panoptic_result = draw_masks(img, masks, palette)
194209
return panoptic_result
195210

@@ -199,10 +214,11 @@ def show_instance(self, img, segms, bboxes, scores, labels, score_thr=0.5):
199214
bboxes = bboxes[inds, :]
200215
segms = segms[inds, ...]
201216
labels = labels[inds]
202-
palette = np.asarray(self.cfg.PALETTE)
217+
palette = np.asarray(self.PALETTE)
203218
palette = palette[labels]
219+
204220
instance_result = draw_masks(img, segms, palette)
205-
class_name = np.array(self.class_name)
221+
class_name = np.array(self.CLASSES)
206222
instance_result = imshow_bboxes(
207223
instance_result, bboxes, class_name[labels], show=False)
208224
return instance_result

tests/predictors/test_segmentation.py

+82-18
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,19 @@
55
import tempfile
66
import unittest
77

8+
import cv2
89
import numpy as np
910
from mmcv import Config
1011
from PIL import Image
11-
from tests.ut_config import (MODEL_CONFIG_SEGFORMER,
12+
from tests.ut_config import (MODEL_CONFIG_MASK2FORMER_INS,
13+
MODEL_CONFIG_MASK2FORMER_PAN,
14+
MODEL_CONFIG_SEGFORMER,
1215
PRETRAINED_MODEL_MASK2FORMER_DIR,
1316
PRETRAINED_MODEL_SEGFORMER, TEST_IMAGES_DIR)
1417

1518
from easycv.file import io
16-
from easycv.predictors.segmentation import SegmentationPredictor
19+
from easycv.predictors.segmentation import (Mask2formerPredictor,
20+
SegmentationPredictor)
1721

1822

1923
class SegmentationPredictorTest(unittest.TestCase):
@@ -112,34 +116,94 @@ def test_dump(self):
112116
shutil.rmtree(temp_dir, ignore_errors=True)
113117

114118

115-
@unittest.skipIf(True, 'WIP')
116119
class Mask2formerPredictorTest(unittest.TestCase):
117120

118-
def test_single(self):
119-
import cv2
120-
from easycv.predictors.segmentation import Mask2formerPredictor
121-
pan_ckpt = os.path.join(PRETRAINED_MODEL_MASK2FORMER_DIR,
122-
'mask2former_pan_export.pth')
123-
instance_ckpt = os.path.join(PRETRAINED_MODEL_MASK2FORMER_DIR,
124-
'mask2former_r50_instance.pth')
125-
img_path = os.path.join(TEST_IMAGES_DIR, 'mask2former.jpg')
121+
def setUp(self):
122+
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
123+
self.img_path = './data/test/segmentation/data/000000309022.jpg'
124+
self.pan_ckpt = './data/test/segmentation/models/mask2former_pan_export.pth'
125+
self.instance_ckpt = './data/test/segmentation/models/mask2former_r50_instance.pth'
126126

127+
def test_panoptic_single(self):
127128
# panoptic
129+
segmentation_model_config = MODEL_CONFIG_MASK2FORMER_PAN
128130
predictor = Mask2formerPredictor(
129-
model_path=pan_ckpt, output_mode='panoptic')
130-
img = cv2.imread(img_path)
131-
predict_out = predictor([img])
132-
pan_img = predictor.show_panoptic(img, predict_out[0]['pan'])
131+
model_path=self.pan_ckpt,
132+
task_mode='panoptic',
133+
config_file=segmentation_model_config)
134+
img = cv2.imread(self.img_path)
135+
predict_out = predictor([self.img_path])
136+
self.assertEqual(len(predict_out), 1)
137+
self.assertEqual(len(predict_out[0]['masks']), 14)
138+
self.assertListEqual(
139+
predict_out[0]['labels_ids'].tolist(),
140+
[71, 69, 39, 39, 39, 128, 127, 122, 118, 115, 111, 104, 84, 83])
141+
142+
pan_img = predictor.show_panoptic(
143+
img,
144+
masks=predict_out[0]['masks'],
145+
labels=predict_out[0]['labels_ids'])
133146
cv2.imwrite('pan_out.jpg', pan_img)
134147

148+
def test_panoptic_batch(self):
149+
total_samples = 2
150+
segmentation_model_config = MODEL_CONFIG_MASK2FORMER_PAN
151+
predictor = Mask2formerPredictor(
152+
model_path=self.pan_ckpt,
153+
task_mode='panoptic',
154+
config_file=segmentation_model_config,
155+
batch_size=total_samples)
156+
predict_out = predictor([self.img_path] * total_samples)
157+
self.assertEqual(len(predict_out), total_samples)
158+
img = cv2.imread(self.img_path)
159+
for i in range(total_samples):
160+
save_name = 'pan_out_batch_%d.jpg' % i
161+
self.assertEqual(len(predict_out[i]['masks']), 14)
162+
self.assertListEqual(predict_out[i]['labels_ids'].tolist(), [
163+
71, 69, 39, 39, 39, 128, 127, 122, 118, 115, 111, 104, 84, 83
164+
])
165+
pan_img = predictor.show_panoptic(
166+
img,
167+
masks=predict_out[i]['masks'],
168+
labels=predict_out[i]['labels_ids'])
169+
cv2.imwrite(save_name, pan_img)
170+
171+
def test_instance_single(self):
135172
# instance
173+
segmentation_model_config = MODEL_CONFIG_MASK2FORMER_INS
136174
predictor = Mask2formerPredictor(
137-
model_path=instance_ckpt, output_mode='instance')
138-
img = cv2.imread(img_path)
139-
predict_out = predictor.predict([img], mode='instance')
175+
model_path=self.instance_ckpt,
176+
task_mode='instance',
177+
config_file=segmentation_model_config)
178+
img = cv2.imread(self.img_path)
179+
predict_out = predictor([self.img_path])
180+
self.assertEqual(len(predict_out), 1)
181+
self.assertEqual(predict_out[0]['segms'].shape, (100, 480, 640))
182+
self.assertListEqual(predict_out[0]['labels'][:10].tolist(),
183+
[41, 69, 72, 45, 68, 70, 41, 69, 69, 45])
184+
140185
instance_img = predictor.show_instance(img, **predict_out[0])
141186
cv2.imwrite('instance_out.jpg', instance_img)
142187

188+
def test_instance_batch(self):
189+
total_samples = 2
190+
segmentation_model_config = MODEL_CONFIG_MASK2FORMER_INS
191+
predictor = Mask2formerPredictor(
192+
model_path=self.instance_ckpt,
193+
task_mode='instance',
194+
config_file=segmentation_model_config,
195+
batch_size=total_samples)
196+
img = cv2.imread(self.img_path)
197+
predict_out = predictor([self.img_path] * total_samples)
198+
self.assertEqual(len(predict_out), total_samples)
199+
for i in range(total_samples):
200+
save_name = 'instance_out_batch_%d.jpg' % i
201+
self.assertEqual(predict_out[i]['segms'].shape, (100, 480, 640))
202+
self.assertListEqual(predict_out[0]['labels'][:10].tolist(),
203+
[41, 69, 72, 45, 68, 70, 41, 69, 69, 45])
204+
instance_img = predictor.show_instance(img, **(predict_out[i]))
205+
cv2.imwrite(save_name, instance_img)
206+
143207

144208
if __name__ == '__main__':
145209
unittest.main()

tests/ut_config.py

+4
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,7 @@
132132
'./configs/segmentation/segformer/segformer_b0_coco.py')
133133
SMALL_COCO_WHOLE_BODY_HAND_ROOT = 'data/test/pose/hand/small_whole_body_hand_coco'
134134
SMALL_COCO_WHOLEBODY_ROOT = 'data/test/pose/wholebody/data'
135+
MODEL_CONFIG_MASK2FORMER_PAN = (
136+
'./configs/segmentation/mask2former/mask2former_r50_8xb2_e50_panoptic.py')
137+
MODEL_CONFIG_MASK2FORMER_INS = (
138+
'./configs/segmentation/mask2former/mask2former_r50_8xb2_e50_instance.py')

0 commit comments

Comments
 (0)