Commit 5d5e23b (parent: 8df5f3c)
Author: Your Name
Message: sparseinst support export onnx now.

File tree: 7 files changed, +107 −68 lines

.gitignore (+2 −4)

```diff
@@ -138,15 +138,13 @@ dmypy.json
 .pyre/
 
 .idea/
-weights/
 *.pth
 *.onnx
 
 .vscode/
 output/
 datasets/
-weights/
 vendor/vendor/
 vendor/
-log2.md
-vendor/
+log2.md
+vendor/
```

export_onnx.py (+19 −10)

```diff
@@ -182,11 +182,11 @@ def change_detr_onnx(onnx_path):
     print(f"[INFO] ONNX modification finished, saved to {onnx_path + '_changed.onnx'}.")
 
 
-def load_test_image(f, h, w):
+def load_test_image(f, h, w, bs=1):
     a = cv2.imread(f)
     a = cv2.resize(a, (w, h))
-    # a_t = torch.tensor(a.astype(np.float32)).to(device).unsqueeze(0)
-    a_t = torch.tensor(a.astype(np.float32)).to(device)
+    a_t = torch.tensor(a.astype(np.float32)).to(device).unsqueeze(0).repeat(bs, 1, 1, 1)
+    # a_t = torch.tensor(a.astype(np.float32)).to(device)
     return a_t, a
 
 
@@ -238,6 +238,15 @@ def vis_res_fast(res, img, colors):
     return img
 
 
+def get_output_names_from_config_file(config_file):
+    if 'sparse_inst' in config_file:
+        return ['masks', 'scores', 'labels']
+    elif 'detr' in config_file:
+        return ['boxes', 'scores', 'labels']
+    else:
+        return ['outs']
+
+
 if __name__ == "__main__":
     mp.set_start_method("spawn", force=True)
     args = get_parser().parse_args()
@@ -255,14 +264,14 @@ def vis_res_fast(res, img, colors):
     metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
     predictor = DefaultPredictor(cfg)
 
-    h = 1056
-    w = 1920
-    # h = 640
-    # w = 640
+    # h = 1056
+    # w = 1920
+    h = 640
+    w = 640
     inp, ori_img = load_test_image(args.input, h, w)
     # TODO: remove hard coded for detr
     # inp, ori_img = load_test_image_detr(args.input, h, w)
-    print("input shape: ", inp.shape)
+    logger.info(f"input shape: {inp.shape}")
 
     model = predictor.model
     model = model.float()
@@ -273,9 +282,9 @@ def vis_res_fast(res, img, colors):
     )
     torch.onnx.export(
         model,
-        [inp],
+        inp,
         onnx_f,
-        output_names={"out"},
+        output_names=get_output_names_from_config_file(args.config_file),
         opset_version=12,
         do_constant_folding=True,
         verbose=args.verbose,
```
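
Since the export now names its outputs per config, a consumer can fetch them by name. Below is a minimal, hypothetical onnxruntime check, not part of this commit; the file names `sparse_inst.onnx` and `demo.jpg` are placeholders, and the preprocessing mirrors `load_test_image` (BGR, float32, batch-first, 640x640, no channel transpose):

```python
import cv2
import numpy as np
import onnxruntime as ort

# Load the exported graph; output names come from get_output_names_from_config_file().
sess = ort.InferenceSession("sparse_inst.onnx", providers=["CPUExecutionProvider"])
inp_name = sess.get_inputs()[0].name

# Same preprocessing as load_test_image(): resize, float32, add batch dim (NHWC).
img = cv2.resize(cv2.imread("demo.jpg"), (640, 640))
blob = img.astype(np.float32)[None, ...]

masks, scores, labels = sess.run(["masks", "scores", "labels"], {inp_name: blob})
print(masks.shape, scores.shape, labels.shape)
```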

readme.md (+51 −52)

````diff
@@ -93,7 +93,6 @@ Here are some tasks that need to be claimed:
 
 
 
-
 ## 💁‍♂️ Results
 
 | YOLOv7 Instance | Face & Detection |
@@ -103,35 +102,9 @@ Here are some tasks that need to be claimed:
 ![](https://s1.ax1x.com/2022/03/25/qN5zp6.png) | ![](https://s2.loli.net/2022/03/25/MBwq9YT7zC5Sd1A.png)
 
 
-## 🤔 Features
-
-Some highlights of YOLOv7 are:
-
-- A simple and standard training framework for any detection && instance segmentation task, based on detectron2;
-- Supports DETR and many transformer-based detection frameworks out of the box;
-- Supports an easy-to-deploy pipeline through ONNX;
-- **The only framework that supports YOLOv4 + instance segmentation** in single-stage style;
-- Plugs easily into transformer-based detectors;
-
-We strongly recommend you to send a PR if you have any further development on this project; **the only reason for open-sourcing it is to use community power to make it stronger**. Anyone is very welcome to contribute any feature!
-
-
-## 😎 Rules
-
-There are some rules you must follow if you want to train on your own dataset:
-
-- Rule No.1: Always set your own anchors for your dataset using `tools/compute_anchors.py`; this applies to any other anchor-based detection method as well (EfficientDet etc.);
-- Rule No.2: Keep faith that your loss will go down eventually; if not, dig deeper to find out why (but do not post repeated issues, as I might not know either);
-- Rule No.3: No one will tell you, but it's real: *do not change the backbone lightly; all params are coupled with your backbone, so it is not as simple as you might think*. Being a deep learning engineer **is not as easy as you think**; knowledge is like an ocean, and yours is just a tiny drop of water...
-- Rule No.4: You **must** use pretrained weights for a **transformer-based backbone**, otherwise your loss will bump;
-
-Make sure you have read the **rules** before asking me any questions.
-
-
 ## 🆕 News!
 
+- **2022.04.15**: Now we support `SparseInst` ONNX export!
 - **2022.03.25**: New instance seg supported! 40 FPS @ 37 mAP, which is fast!
 - **2021.09.16**: First transformer-based DETR model added; will explore more DETR-series models;
 - **2021.08.02**: **YOLOX** arch added; you can train YOLOX in this repo as well;
@@ -145,21 +118,32 @@ Make sure you have read the **rules** before asking me any questions.
 
 - See [docs/install.md](docs/install.md)
 
-## 😎 Train
 
-For training, quite simple, same as detectron2:
 
-```
-python train_net.py --config-file configs/coco/darknet53.yaml --num-gpus 8
-```
+## 🤔 Features
 
-If you want to train YOLOX, use the config file `configs/coco/yolox_s.yaml`. All supported archs are:
+Some highlights of YOLOv7 are:
 
-- **YOLOX**: anchor-free YOLO;
-- **YOLOv7**: traditional YOLO with some explorations, mainly focused on loss experiments;
-- **YOLOv7P**: traditional YOLO merged with decent arch from YOLOX;
-- **YOLOMask**: arch doing detection and segmentation at the same time (tbd);
-- **YOLOInsSeg**: instance segmentation based on YOLO detection (tbd);
+- A simple and standard training framework for any detection && instance segmentation task, based on detectron2;
+- Supports DETR and many transformer-based detection frameworks out of the box;
+- Supports an easy-to-deploy pipeline through ONNX;
+- **The only framework that supports YOLOv4 + instance segmentation** in single-stage style;
+- Plugs easily into transformer-based detectors;
+
+We strongly recommend you to send a PR if you have any further development on this project; **the only reason for open-sourcing it is to use community power to make it stronger**. Anyone is very welcome to contribute any feature!
+
+## 🧙‍♂️ Pretrained Models
+
+| model | backbone | input | aug | AP<sup>val</sup> | AP | FPS | weights |
+| :---- | :------ | :---: | :-: |:--------------: | :--: | :-: | :-----: |
+| [SparseInst](configs/sparse_inst_r50_base.yaml) | [R-50]() | 640 | &#x2718; | 32.8 | - | 44.3 | [model](https://drive.google.com/file/d/12RQLHD5EZKIOvlqW3avUCeYjFG1NPKDy/view?usp=sharing) |
+| [SparseInst](sparse_inst_r50vd_base.yaml) | [R-50-vd](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth) | 640 | &#x2718; | 34.1 | - | 42.6 | [model]() |
+| [SparseInst (G-IAM)](configs/sparse_inst_r50_giam.yaml) | [R-50]() | 608 | &#x2718; | 33.4 | - | 44.6 | [model](https://drive.google.com/file/d/1pXU7Dsa1L7nUiLU9ULG2F6Pl5m5NEguL/view?usp=sharing) |
+| [SparseInst (G-IAM)](configs/sparse_inst_r50_giam_aug.yaml) | [R-50]() | 608 | &#10003; | 34.2 | 34.7 | 44.6 | [model](https://drive.google.com/file/d/1MK8rO3qtA7vN9KVSBdp0VvZHCNq8-bvz/view?usp=sharing) |
+| [SparseInst (G-IAM)](configs/sparse_inst_r50_dcn_giam_aug.yaml) | [R-50-DCN]() | 608 | &#10003; | 36.4 | 36.8 | 41.6 | [model](https://drive.google.com/file/d/1qxdLRRHbIWEwRYn-NPPeCCk6fhBjc946/view?usp=sharing) |
+| [SparseInst (G-IAM)](configs/sparse_inst_r50vd_giam_aug.yaml) | [R-50-vd](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth) | 608 | &#10003; | 35.6 | 36.1 | 42.8 | [model](https://drive.google.com/file/d/1dlamg7ych_BdWpPUCuiBXbwE0SXpsfGx/view?usp=sharing) |
+| [SparseInst (G-IAM)](configs/sparse_inst_r50vd_dcn_giam_aug.yaml) | [R-50-vd-DCN](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth) | 608 | &#10003; | 37.4 | 37.9 | 40.0 | [model](https://drive.google.com/file/d/1clYPdCNrDNZLbmlAEJ7wjsrOLn1igOpT/view?usp=sharing) |
+| [SparseInst (G-IAM)](sparse_inst_r50vd_dcn_giam_aug.yaml) | [R-50-vd-DCN](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth) | 640 | &#10003; | 37.7 | 38.1 | 39.3 | [model](https://drive.google.com/file/d/1clYPdCNrDNZLbmlAEJ7wjsrOLn1igOpT/view?usp=sharing) |
 
 
 ## 🥰 Demo
@@ -182,27 +166,41 @@ python demo.py --config-file configs/coco/sparseinst/sparse_inst_r50vd_giam_aug.
 python3 demo_lazyconfig.py --config-file configs/new_baselines/panoptic_fpn_regnetx_0.4g.py --opts train.init_checkpoint=output/model_0004999.pth
 ```
 
+## 😎 Train
 
-## 🧙‍♂️ Pretrained Models
+For training, quite simple, same as detectron2:
+
+```
+python train_net.py --config-file configs/coco/darknet53.yaml --num-gpus 8
+```
+
+If you want to train YOLOX, use the config file `configs/coco/yolox_s.yaml`. All supported archs are:
+
+- **YOLOX**: anchor-free YOLO;
+- **YOLOv7**: traditional YOLO with some explorations, mainly focused on loss experiments;
+- **YOLOv7P**: traditional YOLO merged with decent arch from YOLOX;
+- **YOLOMask**: arch doing detection and segmentation at the same time (tbd);
+- **YOLOInsSeg**: instance segmentation based on YOLO detection (tbd);
+
+
+## 😎 Rules
+
+There are some rules you must follow if you want to train on your own dataset:
+
+- Rule No.1: Always set your own anchors for your dataset using `tools/compute_anchors.py`; this applies to any other anchor-based detection method as well (EfficientDet etc.);
+- Rule No.2: Keep faith that your loss will go down eventually; if not, dig deeper to find out why (but do not post repeated issues, as I might not know either);
+- Rule No.3: No one will tell you, but it's real: *do not change the backbone lightly; all params are coupled with your backbone, so it is not as simple as you might think*. Being a deep learning engineer **is not as easy as you think**; knowledge is like an ocean, and yours is just a tiny drop of water...
+- Rule No.4: You **must** use pretrained weights for a **transformer-based backbone**, otherwise your loss will bump;
+
+Make sure you have read the **rules** before asking me any questions.
 
-| model | backbone | input | aug | AP<sup>val</sup> | AP | FPS | weights |
-| :---- | :------ | :---: | :-: |:--------------: | :--: | :-: | :-----: |
-| [SparseInst](configs/sparse_inst_r50_base.yaml) | [R-50]() | 640 | &#x2718; | 32.8 | - | 44.3 | [model](https://drive.google.com/file/d/12RQLHD5EZKIOvlqW3avUCeYjFG1NPKDy/view?usp=sharing) |
-| [SparseInst](sparse_inst_r50vd_base.yaml) | [R-50-vd](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth) | 640 | &#x2718; | 34.1 | - | 42.6 | [model]() |
-| [SparseInst (G-IAM)](configs/sparse_inst_r50_giam.yaml) | [R-50]() | 608 | &#x2718; | 33.4 | - | 44.6 | [model](https://drive.google.com/file/d/1pXU7Dsa1L7nUiLU9ULG2F6Pl5m5NEguL/view?usp=sharing) |
-| [SparseInst (G-IAM)](configs/sparse_inst_r50_giam_aug.yaml) | [R-50]() | 608 | &#10003; | 34.2 | 34.7 | 44.6 | [model](https://drive.google.com/file/d/1MK8rO3qtA7vN9KVSBdp0VvZHCNq8-bvz/view?usp=sharing) |
-| [SparseInst (G-IAM)](configs/sparse_inst_r50_dcn_giam_aug.yaml) | [R-50-DCN]() | 608 | &#10003; | 36.4 | 36.8 | 41.6 | [model](https://drive.google.com/file/d/1qxdLRRHbIWEwRYn-NPPeCCk6fhBjc946/view?usp=sharing) |
-| [SparseInst (G-IAM)](configs/sparse_inst_r50vd_giam_aug.yaml) | [R-50-vd](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth) | 608 | &#10003; | 35.6 | 36.1 | 42.8 | [model](https://drive.google.com/file/d/1dlamg7ych_BdWpPUCuiBXbwE0SXpsfGx/view?usp=sharing) |
-| [SparseInst (G-IAM)](configs/sparse_inst_r50vd_dcn_giam_aug.yaml) | [R-50-vd-DCN](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth) | 608 | &#10003; | 37.4 | 37.9 | 40.0 | [model](https://drive.google.com/file/d/1clYPdCNrDNZLbmlAEJ7wjsrOLn1igOpT/view?usp=sharing) |
-| [SparseInst (G-IAM)](sparse_inst_r50vd_dcn_giam_aug.yaml) | [R-50-vd-DCN](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth) | 640 | &#10003; | 37.7 | 38.1 | 39.3 | [model](https://drive.google.com/file/d/1clYPdCNrDNZLbmlAEJ7wjsrOLn1igOpT/view?usp=sharing) |
 
 ## 🔨 Export ONNX && TensorRT && TVM
 
 1. `detr`:
 
 ```
 python export_onnx.py --config-file detr/config/file
-
 ```
 
 this work has been done; the inference script is included inside `tools`.
@@ -211,7 +209,8 @@ python3 demo_lazyconfig.py --config-file configs/new_baselines/panoptic_fpn_regn
 
 anchorDETR also supports training and exporting to ONNX.
 
-
+3. `SparseInst`:
+   SparseInst already supports exporting to ONNX!
 
 
````
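The SparseInst path in the README uses the same `export_onnx.py` entry point patched above; given the flags that script parses (`--config-file`, `--input`, `--verbose`), an export run would presumably look like `python export_onnx.py --config-file configs/coco/sparseinst/sparse_inst_r50vd_giam_aug.yaml --input images/some_image.jpg` (the image path is a placeholder), producing a graph whose outputs are named `masks`, `scores` and `labels`.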
weights/.gitignore (new file, +3)

```diff
@@ -0,0 +1,3 @@
+*
+!.gitignore
+!get_models.sh
```

weights/get_models.sh (new file, +1)

```diff
@@ -0,0 +1 @@
+gdown https://drive.google.com/file/d/1MK8rO3qtA7vN9KVSBdp0VvZHCNq8-bvz/view\?usp\=sharing --fuzzy
```
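Note that `gdown --fuzzy` extracts the file id from a full Drive share URL itself, which is why the escaped `view\?usp\=sharing` link can be passed as-is; the id here matches the `sparse_inst_r50_giam_aug` checkpoint from the pretrained-models table.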

yolov7/modeling/meta_arch/sparseinst.py (+4 −1)

```diff
@@ -259,5 +259,8 @@ def inference_onnx(self, output, batched_inputs, max_shape, image_sizes):
 
         all_scores = torch.stack(all_scores)
         all_labels = torch.stack(all_labels)
-        all_masks = torch.stack(all_masks)
+        all_masks = torch.stack(all_masks).to(torch.int64)
+        logger.info(f'all_scores: {all_scores.shape}')
+        logger.info(f'all_labels: {all_labels.shape}')
+        logger.info(f'all_masks: {all_masks.shape}')
         return all_masks, all_scores, all_labels
```
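
The masks now come back as int64 rather than bool, so consumers will usually want to convert them back before use. A small, hypothetical consumer-side sketch (not part of this commit), assuming one image's `masks` are binary 0/1 of shape `(num_inst, H, W)` with per-instance `scores`:

```python
import numpy as np

def overlay_masks(img: np.ndarray, masks: np.ndarray, scores: np.ndarray,
                  thresh: float = 0.5) -> np.ndarray:
    """Blend each confident instance mask onto a BGR image."""
    out = img.copy()
    for mask, score in zip(masks, scores):
        if score < thresh:
            continue
        m = mask.astype(bool)  # undo the int64 cast from the export path
        color = np.random.randint(0, 255, size=3)
        out[m] = (0.5 * out[m] + 0.5 * color).astype(np.uint8)
    return out
```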

yolov7/modeling/transcoders/encoder_sparseinst.py (+27 −1)

```diff
@@ -9,11 +9,33 @@
 
 from detectron2.utils.registry import Registry
 from detectron2.layers import Conv2d
+from alfred.utils.log import logger
 
 SPARSE_INST_ENCODER_REGISTRY = Registry("SPARSE_INST_ENCODER")
 SPARSE_INST_ENCODER_REGISTRY.__doc__ = "registry for SparseInst decoder"
 
 
+class MyAdaptiveAvgPool2d(nn.Module):
+    def __init__(self, sz=None):
+        super().__init__()
+        self.sz = sz
+
+    def forward(self, x):
+        inp_size = x.size()
+        kernel_width, kernel_height = inp_size[2], inp_size[3]
+        if self.sz is not None:
+            if isinstance(self.sz, int):
+                kernel_width = math.ceil(inp_size[2] / self.sz)
+                kernel_height = math.ceil(inp_size[3] / self.sz)
+            elif isinstance(self.sz, (list, tuple)):
+                assert len(self.sz) == 2
+                kernel_width = math.ceil(inp_size[2] / self.sz[0])
+                kernel_height = math.ceil(inp_size[3] / self.sz[1])
+        return F.avg_pool2d(
+            input=x, ceil_mode=False, kernel_size=(kernel_width, kernel_height)
+        )
+
+
 class PyramidPoolingModule(nn.Module):
     def __init__(self, in_channels, channels=512, sizes=(1, 2, 3, 6)):
         super().__init__()
@@ -24,7 +46,11 @@ def __init__(self, in_channels, channels=512, sizes=(1, 2, 3, 6)):
         self.bottleneck = Conv2d(in_channels + len(sizes) * channels, in_channels, 1)
 
     def _make_stage(self, features, out_features, size):
-        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
+        if torch.onnx.is_in_onnx_export():
+            logger.warning(f'Replace nn.AdaptiveAvgPool2d for onnx export, size: {size}x{size}')
+            prior = MyAdaptiveAvgPool2d((size, size))
+        else:
+            prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
         conv = Conv2d(features, out_features, 1)
         return nn.Sequential(prior, conv)
```
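One caveat worth verifying: `F.avg_pool2d` with a fixed `ceil`-derived kernel (and its default stride, which equals the kernel size) reproduces `nn.AdaptiveAvgPool2d` exactly only when the spatial dims divide evenly by the target size; for the PPM sizes `(1, 2, 3, 6)` that means feature maps whose height and width are multiples of 6. A quick check under that assumption:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 8, 36, 36)  # 36 is divisible by 1, 2, 3 and 6
for s in (1, 2, 3, 6):
    ref = nn.AdaptiveAvgPool2d(output_size=(s, s))(x)
    k = math.ceil(36 / s)  # same kernel MyAdaptiveAvgPool2d derives
    alt = F.avg_pool2d(x, kernel_size=(k, k), ceil_mode=False)
    assert alt.shape == ref.shape
    assert torch.allclose(alt, ref, atol=1e-6)
```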