
Commit 00aa7c2

Support_mask2former_semantic (#199)

* add mask2former semantic branch

1 parent 8a484e8 commit 00aa7c2

13 files changed: +573 −48 lines changed

configs/ocr/README.md (+2 −2)
@@ -2,7 +2,7 @@
## PP-OCRv3
We convert [PaddleOCRv3](https://github.com/PaddlePaddle/PaddleOCR) models to pytorch style, and provide an end2end interface to recognize text in images by simply loading the exported models.
### detection
-We test on on icdar2015 dataset.
+We test on the icdar2015 dataset.
|Algorithm|backbone|configs|precision|recall|Hmean|Download|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|DB|MobileNetv3|[det_model_en.py](configs/ocr/detection/det_model_en.py)|0.7803|0.7250|0.7516|[log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/ocr/det/fintune_icdar2015_mobilev3/20220902_140307.log.json) - [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/ocr/det/fintune_icdar2015_mobilev3/epoch_70.pth)|
@@ -59,7 +59,7 @@ out = predictor([img_path])
print(out)
```
![rec_input](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/ocr/export_model/test_image/japan_rec.jpg)<br/>
-![rec_putput](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/ocr/export_model/test_image/japan_predict.jpg)
+![rec_output](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/ocr/export_model/test_image/japan_predict.jpg)
##### end2end
```
import cv2
configs/segmentation/mask2former/mask2former_r50_8xb2_e127_semantic.py (+250 −0)

@@ -0,0 +1,250 @@
_base_ = ['configs/base.py']

CLASSES = [
    'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road, route',
    'bed', 'window ', 'grass', 'cabinet', 'sidewalk, pavement', 'person',
    'earth, ground', 'door', 'table', 'mountain, mount', 'plant', 'curtain',
    'chair', 'car', 'water', 'painting, picture', 'sofa', 'shelf', 'house',
    'sea', 'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk',
    'rock, stone', 'wardrobe, closet, press', 'lamp', 'tub', 'rail', 'cushion',
    'base, pedestal, stand', 'box', 'column, pillar', 'signboard, sign',
    'chest of drawers, chest, bureau, dresser', 'counter', 'sand', 'sink',
    'skyscraper', 'fireplace', 'refrigerator, icebox',
    'grandstand, covered stand', 'path', 'stairs', 'runway',
    'case, display case, showcase, vitrine',
    'pool table, billiard table, snooker table', 'pillow',
    'screen door, screen', 'stairway, staircase', 'river', 'bridge, span',
    'bookcase', 'blind, screen', 'coffee table',
    'toilet, can, commode, crapper, pot, potty, stool, throne', 'flower',
    'book', 'hill', 'bench', 'countertop', 'stove', 'palm, palm tree',
    'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
    'arcade machine', 'hovel, hut, hutch, shack, shanty', 'bus', 'towel',
    'light', 'truck', 'tower', 'chandelier', 'awning, sunshade, sunblind',
    'street lamp', 'booth', 'tv', 'plane', 'dirt track', 'clothes', 'pole',
    'land, ground, soil',
    'bannister, banister, balustrade, balusters, handrail',
    'escalator, moving staircase, moving stairway',
    'ottoman, pouf, pouffe, puff, hassock', 'bottle',
    'buffet, counter, sideboard',
    'poster, posting, placard, notice, bill, card', 'stage', 'van', 'ship',
    'fountain',
    'conveyer belt, conveyor belt, conveyer, conveyor, transporter', 'canopy',
    'washer, automatic washer, washing machine', 'plaything, toy', 'pool',
    'stool', 'barrel, cask', 'basket, handbasket', 'falls', 'tent', 'bag',
    'minibike, motorbike', 'cradle', 'oven', 'ball', 'food, solid food',
    'step, stair', 'tank, storage tank', 'trade name', 'microwave', 'pot',
    'animal', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket, cover',
    'sculpture', 'hood, exhaust hood', 'sconce', 'vase', 'traffic light',
    'tray', 'trash can', 'fan', 'pier', 'crt screen', 'plate', 'monitor',
    'bulletin board', 'shower', 'radiator', 'glass, drinking glass', 'clock',
    'flag'
]

PALETTE = [(120, 120, 120), (180, 120, 120), (6, 230, 230), (80, 50, 50),
           (4, 200, 3), (120, 120, 80), (140, 140, 140), (204, 5, 255),
           (230, 230, 230), (4, 250, 7), (224, 5, 255), (235, 255, 7),
           (150, 5, 61), (120, 120, 70), (8, 255, 51), (255, 6, 82),
           (143, 255, 140), (204, 255, 4), (255, 51, 7), (204, 70, 3),
           (0, 102, 200), (61, 230, 250), (255, 6, 51), (11, 102, 255),
           (255, 7, 71), (255, 9, 224), (9, 7, 230), (220, 220, 220),
           (255, 9, 92), (112, 9, 255), (8, 255, 214), (7, 255, 224),
           (255, 184, 6), (10, 255, 71), (255, 41, 10), (7, 255, 255),
           (224, 255, 8), (102, 8, 255), (255, 61, 6), (255, 194, 7),
           (255, 122, 8), (0, 255, 20), (255, 8, 41), (255, 5, 153),
           (6, 51, 255), (235, 12, 255), (160, 150, 20), (0, 163, 255),
           (140, 140, 140), (250, 10, 15), (20, 255, 0), (31, 255, 0),
           (255, 31, 0), (255, 224, 0), (153, 255, 0), (0, 0, 255),
           (255, 71, 0), (0, 235, 255), (0, 173, 255), (31, 0, 255),
           (11, 200, 200), (255, 82, 0), (0, 255, 245), (0, 61, 255),
           (0, 255, 112), (0, 255, 133), (255, 0, 0), (255, 163, 0),
           (255, 102, 0), (194, 255, 0), (0, 143, 255), (51, 255, 0),
           (0, 82, 255), (0, 255, 41), (0, 255, 173), (10, 0, 255),
           (173, 255, 0), (0, 255, 153), (255, 92, 0), (255, 0, 255),
           (255, 0, 245), (255, 0, 102), (255, 173, 0), (255, 0, 20),
           (255, 184, 184), (0, 31, 255), (0, 255, 61), (0, 71, 255),
           (255, 0, 204), (0, 255, 194), (0, 255, 82), (0, 10, 255),
           (0, 112, 255), (51, 0, 255), (0, 194, 255), (0, 122, 255),
           (0, 255, 163), (255, 153, 0), (0, 255, 10), (255, 112, 0),
           (143, 255, 0), (82, 0, 255), (163, 255, 0), (255, 235, 0),
           (8, 184, 170), (133, 0, 255), (0, 255, 92), (184, 0, 255),
           (255, 0, 31), (0, 184, 255), (0, 214, 255), (255, 0, 112),
           (92, 255, 0), (0, 224, 255), (112, 224, 255), (70, 184, 160),
           (163, 0, 255), (153, 0, 255), (71, 255, 0), (255, 0, 163),
           (255, 204, 0), (255, 0, 143), (0, 255, 235), (133, 255, 0),
           (255, 0, 235), (245, 0, 255), (255, 0, 122), (255, 245, 0),
           (10, 190, 212), (214, 255, 0), (0, 204, 255), (20, 0, 255),
           (255, 255, 0), (0, 153, 255), (0, 41, 255), (0, 255, 204),
           (41, 0, 255), (41, 255, 0), (173, 0, 255), (0, 245, 255),
           (71, 0, 255), (122, 0, 255), (0, 255, 184), (0, 92, 255),
           (184, 255, 0), (0, 133, 255), (255, 214, 0), (25, 194, 194),
           (102, 255, 0), (92, 0, 255)]

model = dict(
    type='Mask2Former',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3, 4),
        frozen_stages=-1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True),
    head=dict(
        type='Mask2FormerHead',
        pixel_decoder=dict(
            input_stride=[4, 8, 16, 32],
            input_channel=[256, 512, 1024, 2048],
            transformer_dropout=0.0,
            transformer_nheads=8,
            transformer_dim_feedforward=1024,
            transformer_enc_layers=6,
            conv_dim=256,
            mask_dim=256,
            norm='GN',
            transformer_in_features=[1, 2, 3],
            common_stride=4,
        ),
        transformer_decoder=dict(
            in_channels=256,
            num_classes=150,
            hidden_dim=256,
            num_queries=100,
            nheads=8,
            dim_feedforward=2048,
            dec_layers=9,
            pre_norm=False,
            mask_dim=256,
            enforce_input_project=False,
        ),
        num_things_classes=150,
        num_stuff_classes=0,
    ),
    train_cfg=dict(
        class_weight=2.0,
        mask_weight=5.0,
        dice_weight=5.0,
        deep_supervision=True,
        dec_layers=10,
        num_points=12554,
        no_object_weight=0.1,
        oversample_ratio=3.0,
        importance_sample_ratio=0.75,
    ),
    test_cfg=dict(
        semantic_on=True,
        max_per_image=100,
    ),
    pretrained=True,
)
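
# Note: with num_things_classes=150 and num_stuff_classes=0, all 150 ADE20K
# categories are treated as mask-classification targets, and
# test_cfg.semantic_on=True merges the predicted mask/class pairs into a
# single per-pixel semantic map.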

data_root = 'ADEChallengeData2016/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
    dict(type='MMResize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='SegRandomCrop', crop_size=crop_size),
    dict(type='MMRandomFlip', flip_ratio=0.5),
    dict(type='MMPhotoMetricDistortion'),
    dict(type='MMPad', size=crop_size),
    dict(type='MMNormalize', **img_norm_cfg),
    dict(type='DefaultFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'gt_semantic_seg'],
        meta_keys=('filename', 'ori_filename', 'ori_shape', 'img_shape',
                   'pad_shape', 'scale_factor', 'flip', 'flip_direction',
                   'img_norm_cfg')),
]

test_pipeline = [
    dict(
        type='MMMultiScaleFlipAug',
        img_scale=(2048, 512),
        flip=False,
        transforms=[
            dict(type='MMResize', keep_ratio=True),
            dict(type='MMRandomFlip'),
            dict(type='MMNormalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=('filename', 'ori_filename', 'ori_shape',
                           'img_shape', 'pad_shape', 'scale_factor', 'flip',
                           'flip_direction', 'img_norm_cfg')),
        ])
]

train_dataset = dict(
    type='SegDataset',
    data_source=dict(
        type='SegSourceRaw',
        cache_on_the_fly=True,
        img_root=data_root + 'images/training',
        label_root=data_root + 'annotations/training',
        reduce_zero_label=True,
        classes=CLASSES,
    ),
    pipeline=train_pipeline)

val_dataset = dict(
    type='SegDataset',
    imgs_per_gpu=1,
    data_source=dict(
        type='SegSourceRaw',
        cache_on_the_fly=True,
        img_root=data_root + 'images/validation',
        label_root=data_root + 'annotations/validation',
        reduce_zero_label=True,
        classes=CLASSES,
    ),
    pipeline=test_pipeline)

data = dict(
    imgs_per_gpu=2, workers_per_gpu=2, train=train_dataset, val=val_dataset)
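
# imgs_per_gpu=2 on the 8 GPUs implied by the "8xb2" config name gives an
# effective train batch size of 16; validation runs with imgs_per_gpu=1.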

embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
# optimizer
optimizer = dict(
    type='AdamW',
    lr=0.0001,
    weight_decay=0.05,
    eps=1e-8,
    betas=(0.9, 0.999),
    paramwise_options={
        'backbone': dict(lr_mult=0.1),
        'query_embed': dict(weight_decay=0.),
        'query_feat': dict(weight_decay=0.),
        'level_embed': dict(weight_decay=0.),
        'norm': dict(weight_decay=0.),
    })
# gradient clipping does not seem to influence the result
# optimizer_config = dict(grad_clip=dict(max_norm=0.01, norm_type=2))
total_epochs = 127

lr_config = dict(
    policy='Poly',
    min_lr=0,
    warmup='linear',
    warmup_iters=1,
    warmup_ratio=1e-4,
    warmup_by_epoch=False,
    by_epoch=False,
    power=0.9)
checkpoint_config = dict(interval=1)

eval_config = dict(initial=False, interval=1, gpu_collect=False)

eval_pipelines = [
    dict(
        mode='test',
        evaluators=[
            dict(
                type='SegmentationEvaluator',
                classes=CLASSES,
                ignore_index=255,
                metric_names=['mIoU'])
        ],
    )
]
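
For reference, the `Poly` policy configured above anneals the learning rate each iteration as `lr = (base_lr - min_lr) * (1 - iter / max_iter) ** power + min_lr`, after a one-iteration linear warmup that starts near `warmup_ratio * base_lr`. A minimal standalone sketch follows (plain Python, mirroring mmcv-style warmup, not EasyCV's actual hook; `max_iters` is an illustrative placeholder, since the real count depends on dataset size and batch size):

```python
# Standalone sketch of the configured Poly schedule with linear warmup.
# max_iters is illustrative; EasyCV/mmcv derive it from the data loader.
def poly_lr(it, base_lr=1e-4, max_iters=80000, power=0.9, min_lr=0.0,
            warmup_iters=1, warmup_ratio=1e-4):
    if it < warmup_iters:
        # linear ramp from warmup_ratio * base_lr up to base_lr
        k = (1 - it / warmup_iters) * (1 - warmup_ratio)
        return base_lr * (1 - k)
    return (base_lr - min_lr) * (1 - it / max_iters)**power + min_lr

for it in (0, 1, 40000, 79999):
    print(it, f'{poly_lr(it):.3e}')
```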

docs/source/model_zoo_seg.md (+11 −6)
@@ -18,16 +18,21 @@ Pretrained on **Pascal VOC 2012 + Aug**.
## Mask2former

### Instance Segmentation on COCO

-| Algorithm | Config | Train memory<br/>(GB) | box mAP | Mask mAP | Download |
-| --- | --- | --- | --- | --- | --- |
-| mask2former_r50 | [mask2former_r50_8xb2_e50_instance](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/mask2former/mask2former_r50_8xb2_e50_instance.py) | 18.8 | 46.09 | 43.26 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/mask2former_r50_instance/epoch_50.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/mask2former_r50_instance/20220620_113639.log.json) |
+| Algorithm | Config | Params<br/>(backbone/total) | Train memory<br/>(GB) | Inference time (A100)<br/>(ms/img) | box mAP | Mask mAP | Download |
+| --- | --- | --- | --- | --- | --- | --- | --- |
+| mask2former_r50 | [mask2former_r50_8xb2_e50_instance](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/mask2former/mask2former_r50_8xb2_e50_instance.py) | 23.5M/44M | 18.8 | 214ms | 46.09 | 43.26 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/mask2former_r50_instance/epoch_50.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/mask2former_r50_instance/20220620_113639.log.json) |

### Panoptic Segmentation on COCO

-| Algorithm | Config | Train memory<br/>(GB) | PQ | box mAP | Mask mAP | Download |
-| --- | --- | --- | --- | --- | --- | --- |
-| mask2former_r50 | [mask2former_r50_8xb2_e50_panopatic](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/mask2former/mask2former_r50_8xb2_e50_panopatic.py) | 18.8 | 51.64 | 44.81 | 41.88 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/mask2former_r50_panoptic/epoch_50.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/mask2former_r50_panoptic/20220629_170721.log.json) |
+| Algorithm | Config | Params<br/>(backbone/total) | Train memory<br/>(GB) | Inference time (A100)<br/>(ms/img) | PQ | box mAP | Mask mAP | Download |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| mask2former_r50 | [mask2former_r50_8xb2_e50_panopatic](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/mask2former/mask2former_r50_8xb2_e50_panopatic.py) | 23.5M/44M | 18.8 | 241ms | 51.64 | 44.81 | 41.88 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/mask2former_r50_panoptic/epoch_50.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/mask2former_r50_panoptic/20220629_170721.log.json) |

+### Semantic Segmentation on ADE20K
+
+| Algorithm | Config | Params<br/>(backbone/total) | Train memory<br/>(GB) | Inference time (A100)<br/>(ms/img) | mIoU | Download |
+| --- | --- | --- | --- | --- | --- | --- |
+| mask2former_r50 | [mask2former_r50_8xb2_e127_semantic](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/mask2former/mask2former_r50_8xb2_e127_semantic.py) | 23.5M/44M | 5.6 | 504ms | 47.03 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/mask2former_r50_semantic/epoch_116.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/mask2former_r50_semantic/20220929_145919.log.json) |
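
A usage sketch for the released semantic checkpoint follows. Hedged: `SegmentationPredictor` and its constructor arguments are assumptions based on EasyCV's predictor pattern, not something this commit shows; check `easycv.predictors` for the actual interface in your version.

```python
# Hypothetical usage sketch -- the predictor class name and arguments are
# assumed, not taken from this commit; verify against easycv.predictors.
from easycv.predictors.segmentation import SegmentationPredictor

predictor = SegmentationPredictor(
    model_path='epoch_116.pth',  # semantic checkpoint linked in the table above
    config_file='configs/segmentation/mask2former/'
                'mask2former_r50_8xb2_e127_semantic.py')
outputs = predictor(['demo.jpg'])  # per-pixel predictions over the 150 classes
print(outputs)
```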

## SegFormer
