21
21
import paddle .nn .functional as F
22
22
23
23
24
- def get_reverse_list (ori_shape , transforms ):
25
- """
26
- get reverse list of transform.
27
-
28
- Args:
29
- ori_shape (list): Origin shape of image.
30
- transforms (list): List of transform.
31
-
32
- Returns:
33
- list: List of tuple, there are two format:
34
- ('resize', (h, w)) The image shape before resize,
35
- ('padding', (h, w)) The image shape before padding.
36
- """
37
- reverse_list = []
38
- h , w = ori_shape [0 ], ori_shape [1 ]
39
- for op in transforms :
40
- if op .__class__ .__name__ in ['Resize' ]:
41
- reverse_list .append (('resize' , (h , w )))
42
- h , w = op .target_size [0 ], op .target_size [1 ]
43
- if op .__class__ .__name__ in ['ResizeByLong' ]:
44
- reverse_list .append (('resize' , (h , w )))
45
- long_edge = max (h , w )
46
- short_edge = min (h , w )
47
- short_edge = int (round (short_edge * op .long_size / long_edge ))
48
- long_edge = op .long_size
49
- if h > w :
50
- h = long_edge
51
- w = short_edge
52
- else :
53
- w = long_edge
54
- h = short_edge
55
- if op .__class__ .__name__ in ['ResizeByShort' ]:
56
- reverse_list .append (('resize' , (h , w )))
57
- long_edge = max (h , w )
58
- short_edge = min (h , w )
59
- long_edge = int (round (long_edge * op .short_size / short_edge ))
60
- short_edge = op .short_size
61
- if h > w :
62
- h = long_edge
63
- w = short_edge
64
- else :
65
- w = long_edge
66
- h = short_edge
67
- if op .__class__ .__name__ in ['Padding' ]:
68
- reverse_list .append (('padding' , (h , w )))
69
- w , h = op .target_size [0 ], op .target_size [1 ]
70
- if op .__class__ .__name__ in ['PaddingByAspectRatio' ]:
71
- reverse_list .append (('padding' , (h , w )))
72
- ratio = w / h
73
- if ratio == op .aspect_ratio :
74
- pass
75
- elif ratio > op .aspect_ratio :
76
- h = int (w / op .aspect_ratio )
77
- else :
78
- w = int (h * op .aspect_ratio )
79
- if op .__class__ .__name__ in ['LimitLong' ]:
80
- long_edge = max (h , w )
81
- short_edge = min (h , w )
82
- if ((op .max_long is not None ) and (long_edge > op .max_long )):
83
- reverse_list .append (('resize' , (h , w )))
84
- long_edge = op .max_long
85
- short_edge = int (round (short_edge * op .max_long / long_edge ))
86
- elif ((op .min_long is not None ) and (long_edge < op .min_long )):
87
- reverse_list .append (('resize' , (h , w )))
88
- long_edge = op .min_long
89
- short_edge = int (round (short_edge * op .min_long / long_edge ))
90
- if h > w :
91
- h = long_edge
92
- w = short_edge
93
- else :
94
- w = long_edge
95
- h = short_edge
96
- return reverse_list
97
-
98
-
99
- def reverse_transform (pred , ori_shape , transforms , mode = 'nearest' ):
24
+ def reverse_transform (pred , trans_info , mode = 'nearest' ):
100
25
"""recover pred to origin shape"""
101
- reverse_list = get_reverse_list (ori_shape , transforms )
102
26
intTypeList = [paddle .int8 , paddle .int16 , paddle .int32 , paddle .int64 ]
103
27
dtype = pred .dtype
104
- for item in reverse_list [::- 1 ]:
105
- if item [0 ] == 'resize' :
28
+ for item in trans_info [::- 1 ]:
29
+ if isinstance (item [0 ], list ):
30
+ trans_mode = item [0 ][0 ]
31
+ else :
32
+ trans_mode = item [0 ]
33
+ if trans_mode == 'resize' :
106
34
h , w = item [1 ][0 ], item [1 ][1 ]
107
35
if paddle .get_device () == 'cpu' and dtype in intTypeList :
108
36
pred = paddle .cast (pred , 'float32' )
109
37
pred = F .interpolate (pred , (h , w ), mode = mode )
110
38
pred = paddle .cast (pred , dtype )
111
39
else :
112
40
pred = F .interpolate (pred , (h , w ), mode = mode )
113
- elif item [ 0 ] == 'padding' :
41
+ elif trans_mode == 'padding' :
114
42
h , w = item [1 ][0 ], item [1 ][1 ]
115
43
pred = pred [:, :, 0 :h , 0 :w ]
116
44
else :
@@ -205,8 +133,7 @@ def slide_inference(model, im, crop_size, stride):
205
133
206
134
def inference (model ,
207
135
im ,
208
- ori_shape = None ,
209
- transforms = None ,
136
+ trans_info = None ,
210
137
is_slide = False ,
211
138
stride = None ,
212
139
crop_size = None ):
@@ -216,8 +143,7 @@ def inference(model,
216
143
Args:
217
144
model (paddle.nn.Layer): model to get logits of image.
218
145
im (Tensor): the input image.
219
- ori_shape (list): Origin shape of image.
220
- transforms (list): Transforms for image.
146
+ trans_info (list): Image shape informating changed process. Default: None.
221
147
is_slide (bool): Whether to infer by sliding window. Default: False.
222
148
crop_size (tuple|list). The size of sliding window, (w, h). It should be probided if is_slide is True.
223
149
stride (tuple|list). The size of stride, (w, h). It should be probided if is_slide is True.
@@ -239,8 +165,8 @@ def inference(model,
239
165
logit = slide_inference (model , im , crop_size = crop_size , stride = stride )
240
166
if hasattr (model , 'data_format' ) and model .data_format == 'NHWC' :
241
167
logit = logit .transpose ((0 , 3 , 1 , 2 ))
242
- if ori_shape is not None :
243
- logit = reverse_transform (logit , ori_shape , transforms , mode = 'bilinear' )
168
+ if trans_info is not None :
169
+ logit = reverse_transform (logit , trans_info , mode = 'bilinear' )
244
170
pred = paddle .argmax (logit , axis = 1 , keepdim = True , dtype = 'int32' )
245
171
return pred , logit
246
172
else :
@@ -249,8 +175,7 @@ def inference(model,
249
175
250
176
def aug_inference (model ,
251
177
im ,
252
- ori_shape ,
253
- transforms ,
178
+ trans_info ,
254
179
scales = 1.0 ,
255
180
flip_horizontal = False ,
256
181
flip_vertical = False ,
@@ -263,8 +188,7 @@ def aug_inference(model,
263
188
Args:
264
189
model (paddle.nn.Layer): model to get logits of image.
265
190
im (Tensor): the input image.
266
- ori_shape (list): Origin shape of image.
267
- transforms (list): Transforms for image.
191
+ trans_info (list): Transforms for image.
268
192
scales (float|tuple|list): Scales for resize. Default: 1.
269
193
flip_horizontal (bool): Whether to flip horizontally. Default: False.
270
194
flip_vertical (bool): Whether to flip vertically. Default: False.
@@ -302,8 +226,7 @@ def aug_inference(model,
302
226
logit = F .softmax (logit , axis = 1 )
303
227
final_logit = final_logit + logit
304
228
305
- final_logit = reverse_transform (
306
- final_logit , ori_shape , transforms , mode = 'bilinear' )
229
+ final_logit = reverse_transform (final_logit , trans_info , mode = 'bilinear' )
307
230
pred = paddle .argmax (final_logit , axis = 1 , keepdim = True , dtype = 'int32' )
308
231
309
232
return pred , final_logit
0 commit comments