
Commit 3c6c2c0

add mask2former algo (#115)
add mask2former algo, support panoptic pipeline, add segment predictor
1 parent 0f09b45 commit 3c6c2c0


42 files changed (+6090 −46 lines)
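Besides the new training config shown below, the commit description mentions a panoptic pipeline and a segmentation predictor. As a point of reference, a minimal sketch of how such a predictor would typically be invoked in EasyCV follows; the class name, import path, and constructor arguments are assumptions for illustration and are not confirmed by this diff.

# Sketch only, not part of this commit: the import path, class name and
# arguments below are assumptions for illustration.
from easycv.predictors.segmentation import Mask2formerPredictor  # assumed import path

# Hypothetical usage: load trained Mask2Former weights and run instance
# segmentation on one image; results would typically contain per-image
# masks, labels and scores.
predictor = Mask2formerPredictor(model_path='work_dirs/mask2former/epoch_50.pth')
results = predictor(['data/demo.jpg'])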

.github/workflows/citest.yaml (+1 −1)

@@ -61,7 +61,7 @@ jobs:
         # do not uncomments, casue faild in Online UT, install requirements by yourself on UT machine
         # pip install -r requirements.txt
         #run test
-        export CUDA_VISIBLE_DEVICES=6
+        export CUDA_VISIBLE_DEVICES=7
         source ~/workspace/anaconda2/etc/profile.d/conda.sh
         conda activate evtorch_torch1.8.0
         PYTHONPATH=. python tests/run.py
New file (file path hidden in the large-commit view)

@@ -0,0 +1,209 @@

_base_ = ['configs/base.py']

CLASSES = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
    'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
    'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush'
]

model = dict(
    type='Mask2Former',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3, 4),
        frozen_stages=-1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True),
    head=dict(
        type='Mask2FormerHead',
        pixel_decoder=dict(
            input_stride=[4, 8, 16, 32],
            input_channel=[256, 512, 1024, 2048],
            transformer_dropout=0.0,
            transformer_nheads=8,
            transformer_dim_feedforward=1024,
            transformer_enc_layers=6,
            conv_dim=256,
            mask_dim=256,
            norm='GN',
            transformer_in_features=[1, 2, 3],
            common_stride=4,
        ),
        transformer_decoder=dict(
            in_channels=256,
            num_classes=80,
            hidden_dim=256,
            num_queries=100,
            nheads=8,
            dim_feedforward=2048,
            dec_layers=9,
            pre_norm=False,
            mask_dim=256,
            enforce_input_project=False,
        ),
        num_things_classes=80,
        num_stuff_classes=0,
    ),
    train_cfg=dict(
        class_weight=2.0,
        mask_weight=5.0,
        dice_weight=5.0,
        deep_supervision=True,
        dec_layers=10,
        num_points=12554,
        no_object_weight=0.1,
        oversample_ratio=3.0,
        importance_sample_ratio=0.75,
    ),
    test_cfg=dict(
        instance_on=True,
        max_per_image=100,
    ),
    pretrained=True,
)

# dataset settings
data_root = 'database/coco/'
image_size = (1024, 1024)
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
pad_cfg = dict(img=(128, 128, 128), masks=0, seg=255)
train_pipeline = [
    dict(type='MMRandomFlip', flip_ratio=0.5),
    dict(
        type='MMResize',
        img_scale=image_size,
        ratio_range=(0.1, 2.0),
        multiscale_mode='range',
        keep_ratio=True),
    dict(
        type='MMRandomCrop',
        crop_size=image_size,
        crop_type='absolute',
        recompute_bbox=True,
        allow_negative_crop=True),
    dict(
        type='MMFilterAnnotations', min_gt_bbox_wh=(1e-5, 1e-5), by_mask=True),
    dict(type='MMPad', size=image_size, pad_val=pad_cfg),
    dict(type='MMNormalize', **img_norm_cfg),
    dict(type='DefaultFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks'],
        meta_keys=('filename', 'ori_filename', 'ori_shape', 'ori_img_shape',
                   'img_shape', 'pad_shape', 'scale_factor', 'flip',
                   'flip_direction', 'img_norm_cfg')),
]

test_pipeline = [
    # dict(type='LoadImageFromFile'),
    dict(
        type='MMMultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='MMResize', keep_ratio=True),
            dict(type='MMRandomFlip'),
            dict(type='MMPad', size_divisor=32, pad_val=pad_cfg),
            dict(type='MMNormalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=('filename', 'ori_filename', 'ori_shape',
                           'ori_img_shape', 'img_shape', 'pad_shape',
                           'scale_factor', 'flip', 'flip_direction',
                           'img_norm_cfg')),
        ])
]

train_dataset = dict(
    type='DetDataset',
    data_source=dict(
        type='DetSourceCoco',
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        # seg_prefix=data_root + 'panoptic_train2017/',
        pipeline=[
            dict(type='LoadImageFromFile', to_float32=True),
            dict(type='LoadAnnotations', with_bbox=True, with_mask=True)
        ],
        classes=CLASSES,
        filter_empty_gt=True,
        iscrowd=False,
    ),
    pipeline=train_pipeline)

val_dataset = dict(
    type='DetDataset',
    imgs_per_gpu=1,
    data_source=dict(
        type='DetSourceCoco',
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=[
            dict(type='LoadImageFromFile', to_float32=True),
            dict(type='LoadAnnotations', with_bbox=True, with_mask=True)
        ],
        classes=CLASSES,
        test_mode=True,
        iscrowd=True,
    ),
    pipeline=test_pipeline)

data = dict(
    imgs_per_gpu=2, workers_per_gpu=2, train=train_dataset, val=val_dataset)

embed_multi = dict(lr_mult=1.0, decay_mult=0.0)
# optimizer
optimizer = dict(
    type='AdamW',
    lr=0.0001,
    weight_decay=0.05,
    eps=1e-8,
    betas=(0.9, 0.999),
    paramwise_options={
        'backbone': dict(lr_mult=0.1),
        'query_embed': dict(weight_decay=0.),
        'query_feat': dict(weight_decay=0.),
        'level_embed': dict(weight_decay=0.),
        'norm': dict(weight_decay=0.),
    })
optimizer_config = dict(grad_clip=dict(max_norm=0.01, norm_type=2))
total_epochs = 50

# learning policy
lr_config = dict(
    policy='step',
    gamma=0.1,
    by_epoch=False,
    step=[327778, 355092],
    warmup='linear',
    warmup_by_epoch=False,
    warmup_ratio=1.0,  # no warmup
    warmup_iters=10)

checkpoint_config = dict(interval=1)

eval_config = dict(initial=False, interval=1, gpu_collect=False)
eval_pipelines = [
    dict(
        mode='test',
        dist_eval=True,
        evaluators=[
            dict(type='CocoDetectionEvaluator', classes=CLASSES),
            dict(type='CocoMaskEvaluator', classes=CLASSES)
        ],
    )
]
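For reference, a quick way to sanity-check that the new config parses as intended before training. This sketch assumes the file lives somewhere under configs/ (the path below is a placeholder; the real location is not visible in this view) and uses EasyCV's mmcv-style config loader, which should be verified against the repo.

# Sketch only: the config path is a placeholder and the loader name follows
# EasyCV's utils but should be verified against the repo.
from easycv.utils.config_tools import mmcv_config_fromfile

cfg = mmcv_config_fromfile('configs/detection/mask2former/mask2former_r50.py')  # placeholder path
print(cfg.model.type)      # expected: 'Mask2Former'
print(cfg.total_epochs)    # expected: 50
print(len(cfg.CLASSES))    # expected: 80 COCO 'thing' classes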
