Skip to content

Commit 5120a2a

Browse files
authored
1.bug fixes for online inference (#648)
2.introduce online inference based on yaml.
1 parent c877866 commit 5120a2a

File tree

16 files changed

+289
-113
lines changed

16 files changed

+289
-113
lines changed

configs/det/dbnet/db_r50_icdar15.yaml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ eval:
158158

159159
predict:
160160
ckpt_load_path: tmp_det/best.ckpt
161+
output_save_dir: ./output
161162
dataset_sink_mode: False
162163
dataset:
163164
type: PredictDataset
@@ -169,24 +170,24 @@ predict:
169170
- DecodeImage:
170171
img_mode: RGB
171172
to_float32: False
172-
# - DetLabelEncode:
173-
- DetResize: # GridResize 32
174-
target_size: [ 736, 1280 ]
175-
keep_ratio: False
176-
limit_type: none
177-
divisor: 32
173+
keep_ori: True
174+
- DetResize:
175+
keep_ratio: True
176+
padding: False
177+
limit_side_len: 960
178+
limit_type: max
178179
- NormalizeImage:
179180
bgr_to_rgb: False
180181
is_hwc: True
181182
mean: imagenet
182183
std: imagenet
183184
- ToCHWImage:
184185
# the order of the dataloader list, matching the network input and the labels for evaluation
185-
output_columns: [ 'img_path', 'image', 'raw_img_shape' ] # shape in h, w order
186-
# num_keys_of_labels: 2 # num labels
186+
output_columns: ["image", "img_path", "shape_list", "image_ori"]
187+
net_input_column_index: [ 0 ] # input indices for network forward func in output_columns
187188

188189
loader:
189190
shuffle: False
190-
batch_size: 1 # TODO: due to dynamic shape of polygons (num of boxes varies), BS has to be 1
191+
batch_size: 1
191192
drop_remainder: False
192193
num_workers: 2

configs/rec/crnn/crnn_resnet34.yaml

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -162,30 +162,23 @@ predict:
162162
shuffle: False
163163
transform_pipeline:
164164
- DecodeImage:
165-
img_mode: BGR
165+
img_mode: RGB
166166
to_float32: False
167-
# - RecCTCLabelEncode:
168-
# max_text_len: *max_text_len
169-
# character_dict_path: *character_dict_path
170-
# use_space_char: *use_space_char
171-
# lower: True
172-
- RecResizeImg: # different from paddle (paddle converts image from HWC to CHW and rescale to [-1, 1] after resize.
173-
image_shape: [32, 100] # H, W
174-
infer_mode: *infer_mode
175-
character_dict_path: *character_dict_path
176-
padding: False # aspect ratio will be preserved if true.
177-
- NormalizeImage: # different from paddle (paddle wrongly normalize BGR image with RGB mean/std from ImageNet for det, and simple rescale to [-1, 1] in rec.
178-
bgr_to_rgb: True
179-
is_hwc: True
180-
mean : [127.0, 127.0, 127.0]
181-
std : [127.0, 127.0, 127.0]
167+
- RecResizeNormForInfer:
168+
target_height: 32
169+
target_width: 100
170+
keep_ratio: False
171+
padding: False
172+
norm_before_pad: False
182173
- ToCHWImage:
183174
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visaulize
184-
output_columns: [ 'img_path', 'image', 'raw_img_shape' ]
175+
output_columns: ['image', 'img_path']
176+
net_input_column_index: [0] # input indices for network forward func in output_columns
177+
# label_column_index: [1, 2] # input indices marked as label
185178

186179
loader:
187180
shuffle: False # TODO: tbc
188-
batch_size: 1
181+
batch_size: 2
189182
drop_remainder: True
190183
max_rowsize: 12
191184
num_workers: 8

mindocr/data/predict_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343
raise ValueError("No transform pipeline is specified!")
4444

4545
# prefetch the data keys, to fit GeneratorDataset
46-
_data = self.data_list[0]
46+
_data = self.data_list[0].copy()
4747
_data = run_transforms(_data, transforms=self.transforms)
4848
_available_keys = list(_data.keys())
4949
if output_columns is None:
@@ -60,7 +60,7 @@ def __init__(
6060
)
6161

6262
def __getitem__(self, index):
63-
data = self.data_list[index]
63+
data = self.data_list[index].copy()
6464

6565
# perform transformation on data
6666
data = run_transforms(data, transforms=self.transforms)

mindocr/models/cls_mv3.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ def __init__(self, config):
2727

2828

2929
@register_model
30-
def cls_mobilenet_v3_small_100_model(pretrained=False, **kwargs):
31-
pretrained_backbone = not pretrained
30+
def cls_mobilenet_v3_small_100_model(pretrained=False, pretrained_backbone=True, **kwargs):
3231
model_config = {
3332
"backbone": {
3433
'name': 'cls_mobilenet_v3_small_100',

mindocr/models/det_dbnet.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def __init__(self, config):
3333

3434

3535
@register_model
36-
def dbnet_mobilenetv3(pretrained=False, **kwargs):
37-
pretrained_backbone = 'https://download.mindspore.cn/toolkits/mindcv/mobilenet/mobilenetv3' \
36+
def dbnet_mobilenetv3(pretrained=False, pretrained_backbone=True, **kwargs):
37+
backbone_ckpt_url = 'https://download.mindspore.cn/toolkits/mindcv/mobilenet/mobilenetv3' \
3838
'/mobilenet_v3_large_050_no_scale_se_v2_expand-3c4047ac.ckpt'
3939
model_config = {
4040
"backbone": {
@@ -43,7 +43,7 @@ def dbnet_mobilenetv3(pretrained=False, **kwargs):
4343
'alpha': 0.5,
4444
'out_stages': [5, 8, 14, 20],
4545
'bottleneck_params': {'se_version': 'SqueezeExciteV2', 'always_expand': True},
46-
'pretrained': pretrained_backbone if not pretrained else False
46+
'pretrained': backbone_ckpt_url if pretrained_backbone else False
4747
},
4848
"neck": {
4949
"name": 'DBFPN',
@@ -68,8 +68,7 @@ def dbnet_mobilenetv3(pretrained=False, **kwargs):
6868

6969

7070
@register_model
71-
def dbnet_resnet18(pretrained=False, **kwargs):
72-
pretrained_backbone = not pretrained
71+
def dbnet_resnet18(pretrained=False, pretrained_backbone=True, **kwargs):
7372
model_config = {
7473
"backbone": {
7574
'name': 'det_resnet18',
@@ -98,8 +97,7 @@ def dbnet_resnet18(pretrained=False, **kwargs):
9897

9998

10099
@register_model
101-
def dbnet_resnet50(pretrained=False, **kwargs):
102-
pretrained_backbone = not pretrained
100+
def dbnet_resnet50(pretrained=False, pretrained_backbone=True, **kwargs):
103101
model_config = {
104102
"backbone": {
105103
'name': 'det_resnet50',
@@ -128,8 +126,7 @@ def dbnet_resnet50(pretrained=False, **kwargs):
128126

129127

130128
@register_model
131-
def dbnetpp_resnet50(pretrained=False, **kwargs):
132-
pretrained_backbone = not pretrained
129+
def dbnetpp_resnet50(pretrained=False, pretrained_backbone=True, **kwargs):
133130
model_config = {
134131
"backbone": {
135132
'name': 'det_resnet50',

mindocr/models/det_psenet.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ def __init__(self, config):
2929

3030

3131
@register_model
32-
def psenet_resnet152(pretrained=False, **kwargs):
33-
pretrained_backbone = not pretrained
32+
def psenet_resnet152(pretrained=False, pretrained_backbone=True, **kwargs):
3433
model_config = {
3534
"backbone": {
3635
'name': 'det_resnet152',

mindocr/models/kie_layoutxlm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,18 @@ def construct(self, x):
4040

4141

4242
@register_model
43-
def layoutxlm_ser(pretrained: bool = True, use_visual_backbone: bool = True, use_float16: bool = False, **kwargs):
43+
def layoutxlm_ser(
44+
pretrained: bool = True,
45+
pretrained_backbone=False,
46+
use_visual_backbone: bool = True,
47+
use_float16: bool = False,
48+
**kwargs
49+
):
4450
model_config = {
4551
"type": "kie",
4652
"backbone": {
4753
"name": "layoutxlm",
48-
"pretrained": pretrained, # backbone pretrained
54+
"pretrained": pretrained_backbone, # backbone pretrained
4955
"use_visual_backbone": use_visual_backbone,
5056
"use_float16": use_float16,
5157
},

tests/ut/test_mindir_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_mindir_infer(model_name):
3737
outputs_mindir = model(x)
3838

3939
# get original ckpt outputs
40-
net = build_model(model_name, pretrained=True)
40+
net = build_model(model_name, pretrained=True, pretrained_backbone=False)
4141
outputs_ckpt = net(x)
4242

4343
for i, o in enumerate(outputs_mindir):

tests/ut/test_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
@pytest.mark.parametrize("pretrained", [True, False])
3232
def test_model_by_name(model_name, pretrained):
3333
print(model_name)
34-
build_model(model_name, pretrained=pretrained)
34+
pretrained_backbone = not pretrained
35+
build_model(model_name, pretrained=pretrained, pretrained_backbone=pretrained_backbone)
3536
print("model created")
3637

3738

tools/export.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,11 @@ def export(model_name_or_config, data_shape, local_ckpt_path, save_dir, is_dynam
9191
amp_level = "O0"
9292

9393
if local_ckpt_path:
94-
net = build_model(model_cfg, pretrained=False, ckpt_load_path=local_ckpt_path, amp_level=amp_level)
94+
net = build_model(
95+
model_cfg, pretrained=False, pretrained_backbone=False, ckpt_load_path=local_ckpt_path, amp_level=amp_level
96+
)
9597
else:
96-
net = build_model(model_cfg, pretrained=True, amp_level=amp_level)
98+
net = build_model(model_cfg, pretrained=True, pretrained_backbone=False, amp_level=amp_level)
9799

98100
logger.info(f"Set the AMP level of the model to be `{amp_level}`.")
99101

0 commit comments

Comments
 (0)