34
34
35
35
@register
36
36
class BBoxPostProcess (nn .Layer ):
37
- __shared__ = ['num_classes' ]
37
+ __shared__ = ['num_classes' , 'export_onnx' ]
38
38
__inject__ = ['decode' , 'nms' ]
39
39
40
- def __init__ (self , num_classes = 80 , decode = None , nms = None ):
40
+ def __init__ (self , num_classes = 80 , decode = None , nms = None ,
41
+ export_onnx = False ):
41
42
super (BBoxPostProcess , self ).__init__ ()
42
43
self .num_classes = num_classes
43
44
self .decode = decode
44
45
self .nms = nms
46
+ self .export_onnx = export_onnx
45
47
46
48
def forward (self , head_out , rois , im_shape , scale_factor ):
47
49
"""
@@ -52,6 +54,7 @@ def forward(self, head_out, rois, im_shape, scale_factor):
52
54
rois (tuple): roi and rois_num of rpn_head output.
53
55
im_shape (Tensor): The shape of the input image.
54
56
scale_factor (Tensor): The scale factor of the input image.
57
+ export_onnx (bool): whether export model to onnx
55
58
Returns:
56
59
bbox_pred (Tensor): The output prediction with shape [N, 6], including
57
60
labels, scores and bboxes. The size of bboxes are corresponding
@@ -62,9 +65,20 @@ def forward(self, head_out, rois, im_shape, scale_factor):
62
65
if self .nms is not None :
63
66
bboxes , score = self .decode (head_out , rois , im_shape , scale_factor )
64
67
bbox_pred , bbox_num , _ = self .nms (bboxes , score , self .num_classes )
68
+
65
69
else :
66
70
bbox_pred , bbox_num = self .decode (head_out , rois , im_shape ,
67
71
scale_factor )
72
+
73
+ if self .export_onnx :
74
+ # add fake box after postprocess when exporting onnx
75
+ fake_bboxes = paddle .to_tensor (
76
+ np .array (
77
+ [[0. , 0.0 , 0.0 , 0.0 , 1.0 , 1.0 ]], dtype = 'float32' ))
78
+
79
+ bbox_pred = paddle .concat ([bbox_pred , fake_bboxes ])
80
+ bbox_num = bbox_num + 1
81
+
68
82
return bbox_pred , bbox_num
69
83
70
84
def get_pred (self , bboxes , bbox_num , im_shape , scale_factor ):
@@ -86,45 +100,55 @@ def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
86
100
pred_result (Tensor): The final prediction results with shape [N, 6]
87
101
including labels, scores and bboxes.
88
102
"""
89
-
90
- bboxes_list = []
91
- bbox_num_list = []
92
- id_start = 0
93
- fake_bboxes = paddle .to_tensor (
94
- np .array (
95
- [[ - 1 , 0.0 , 0.0 , 0.0 , 0 .0 , 0 .0 ]], dtype = 'float32' ))
96
- fake_bbox_num = paddle .to_tensor (np .array ([1 ], dtype = 'int32' ))
97
-
98
- # add fake bbox when output is empty for each batch
99
- for i in range (bbox_num .shape [0 ]):
100
- if bbox_num [i ] == 0 :
101
- bboxes_i = fake_bboxes
102
- bbox_num_i = fake_bbox_num
103
- else :
104
- bboxes_i = bboxes [id_start :id_start + bbox_num [i ], :]
105
- bbox_num_i = bbox_num [i ]
106
- id_start += bbox_num [i ]
107
- bboxes_list .append (bboxes_i )
108
- bbox_num_list .append (bbox_num_i )
109
- bboxes = paddle .concat (bboxes_list )
110
- bbox_num = paddle .concat (bbox_num_list )
103
+ if not self . export_onnx :
104
+ bboxes_list = []
105
+ bbox_num_list = []
106
+ id_start = 0
107
+ fake_bboxes = paddle .to_tensor (
108
+ np .array (
109
+ [[ 0. , 0.0 , 0.0 , 0.0 , 1 .0 , 1 .0 ]], dtype = 'float32' ))
110
+ fake_bbox_num = paddle .to_tensor (np .array ([1 ], dtype = 'int32' ))
111
+
112
+ # add fake bbox when output is empty for each batch
113
+ for i in range (bbox_num .shape [0 ]):
114
+ if bbox_num [i ] == 0 :
115
+ bboxes_i = fake_bboxes
116
+ bbox_num_i = fake_bbox_num
117
+ else :
118
+ bboxes_i = bboxes [id_start :id_start + bbox_num [i ], :]
119
+ bbox_num_i = bbox_num [i ]
120
+ id_start += bbox_num [i ]
121
+ bboxes_list .append (bboxes_i )
122
+ bbox_num_list .append (bbox_num_i )
123
+ bboxes = paddle .concat (bboxes_list )
124
+ bbox_num = paddle .concat (bbox_num_list )
111
125
112
126
origin_shape = paddle .floor (im_shape / scale_factor + 0.5 )
113
127
114
- origin_shape_list = []
115
- scale_factor_list = []
116
- # scale_factor: scale_y, scale_x
117
- for i in range (bbox_num .shape [0 ]):
118
- expand_shape = paddle .expand (origin_shape [i :i + 1 , :],
119
- [bbox_num [i ], 2 ])
120
- scale_y , scale_x = scale_factor [i ][0 ], scale_factor [i ][1 ]
121
- scale = paddle .concat ([scale_x , scale_y , scale_x , scale_y ])
122
- expand_scale = paddle .expand (scale , [bbox_num [i ], 4 ])
123
- origin_shape_list .append (expand_shape )
124
- scale_factor_list .append (expand_scale )
128
+ if not self .export_onnx :
129
+ origin_shape_list = []
130
+ scale_factor_list = []
131
+ # scale_factor: scale_y, scale_x
132
+ for i in range (bbox_num .shape [0 ]):
133
+ expand_shape = paddle .expand (origin_shape [i :i + 1 , :],
134
+ [bbox_num [i ], 2 ])
135
+ scale_y , scale_x = scale_factor [i ][0 ], scale_factor [i ][1 ]
136
+ scale = paddle .concat ([scale_x , scale_y , scale_x , scale_y ])
137
+ expand_scale = paddle .expand (scale , [bbox_num [i ], 4 ])
138
+ origin_shape_list .append (expand_shape )
139
+ scale_factor_list .append (expand_scale )
140
+
141
+ self .origin_shape_list = paddle .concat (origin_shape_list )
142
+ scale_factor_list = paddle .concat (scale_factor_list )
125
143
126
- self .origin_shape_list = paddle .concat (origin_shape_list )
127
- scale_factor_list = paddle .concat (scale_factor_list )
144
+ else :
145
+ # simplify the computation for bs=1 when exporting onnx
146
+ scale_y , scale_x = scale_factor [0 ][0 ], scale_factor [0 ][1 ]
147
+ scale = paddle .concat (
148
+ [scale_x , scale_y , scale_x , scale_y ]).unsqueeze (0 )
149
+ self .origin_shape_list = paddle .expand (origin_shape ,
150
+ [bbox_num [0 ], 2 ])
151
+ scale_factor_list = paddle .expand (scale , [bbox_num [0 ], 4 ])
128
152
129
153
# bboxes: [N, 6], label, score, bbox
130
154
pred_label = bboxes [:, 0 :1 ]
@@ -170,19 +194,20 @@ def paste_mask(self, masks, boxes, im_h, im_w):
170
194
"""
171
195
Paste the mask prediction to the original image.
172
196
"""
173
-
197
+ x0_int , y0_int = 0 , 0
198
+ x1_int , y1_int = im_w , im_h
174
199
x0 , y0 , x1 , y1 = paddle .split (boxes , 4 , axis = 1 )
175
- masks = paddle . unsqueeze ( masks , [ 0 , 1 ])
176
- img_y = paddle .arange (0 , im_h , dtype = 'float32' ) + 0.5
177
- img_x = paddle .arange (0 , im_w , dtype = 'float32' ) + 0.5
200
+ N = masks . shape [ 0 ]
201
+ img_y = paddle .arange (y0_int , y1_int ) + 0.5
202
+ img_x = paddle .arange (x0_int , x1_int ) + 0.5
178
203
img_y = (img_y - y0 ) / (y1 - y0 ) * 2 - 1
179
204
img_x = (img_x - x0 ) / (x1 - x0 ) * 2 - 1
180
- img_x = paddle .unsqueeze (img_x , [1 ])
181
- img_y = paddle .unsqueeze (img_y , [2 ])
182
- N = boxes .shape [0 ]
205
+ # img_x, img_y have shapes (N, w), (N, h)
183
206
184
- gx = paddle .expand (img_x , [N , img_y .shape [1 ], img_x .shape [2 ]])
185
- gy = paddle .expand (img_y , [N , img_y .shape [1 ], img_x .shape [2 ]])
207
+ gx = img_x [:, None , :].expand (
208
+ [N , paddle .shape (img_y )[1 ], paddle .shape (img_x )[1 ]])
209
+ gy = img_y [:, :, None ].expand (
210
+ [N , paddle .shape (img_y )[1 ], paddle .shape (img_x )[1 ]])
186
211
grid = paddle .stack ([gx , gy ], axis = 3 )
187
212
img_masks = F .grid_sample (masks , grid , align_corners = False )
188
213
return img_masks [:, 0 ]
@@ -208,19 +233,13 @@ def __call__(self, mask_out, bboxes, bbox_num, origin_shape):
208
233
# TODO: support bs > 1 and mask output dtype is bool
209
234
pred_result = paddle .zeros (
210
235
[num_mask , origin_shape [0 ][0 ], origin_shape [0 ][1 ]], dtype = 'int32' )
211
- if bbox_num == 1 and bboxes [0 ][0 ] == - 1 :
212
- return pred_result
213
-
214
- # TODO: optimize chunk paste
215
- pred_result = []
216
- for i in range (bboxes .shape [0 ]):
217
- im_h , im_w = origin_shape [i ][0 ], origin_shape [i ][1 ]
218
- pred_mask = self .paste_mask (mask_out [i ], bboxes [i :i + 1 , 2 :], im_h ,
219
- im_w )
220
- pred_mask = pred_mask >= self .binary_thresh
221
- pred_mask = paddle .cast (pred_mask , 'int32' )
222
- pred_result .append (pred_mask )
223
- pred_result = paddle .concat (pred_result )
236
+
237
+ im_h , im_w = origin_shape [0 ][0 ], origin_shape [0 ][1 ]
238
+ pred_mask = self .paste_mask (mask_out [:, None , :, :], bboxes [:, 2 :],
239
+ im_h , im_w )
240
+ pred_mask = pred_mask >= self .binary_thresh
241
+ pred_result = paddle .cast (pred_mask , 'int32' )
242
+
224
243
return pred_result
225
244
226
245
0 commit comments