Skip to content

Commit 14d250f

Browse files
[Enhance]: use isinstance method to get loading pipeline (open-mmlab#4619)
* use isinstance method to get loading pipeline * Fix isinstance error * Add unit test * Fix lint * Fix lint Co-authored-by: hhaAndroid <[email protected]>
1 parent e1599e7 commit 14d250f

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

mmdet/datasets/utils.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from mmcv.cnn import VGG
55
from mmcv.runner.hooks import HOOKS, Hook
66

7+
from mmdet.datasets.builder import PIPELINES
8+
from mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile
79
from mmdet.models.dense_heads import GARPNHead, RPNHead
810
from mmdet.models.roi_heads.mask_heads import FusedSemanticHead
911

@@ -98,7 +100,10 @@ def get_loading_pipeline(pipeline):
98100
"""
99101
loading_pipeline_cfg = []
100102
for cfg in pipeline:
101-
if cfg['type'].startswith('Load'):
103+
obj_cls = PIPELINES.get(cfg['type'])
104+
# TODO:use more elegant way to distinguish loading modules
105+
if obj_cls is not None and obj_cls in (LoadImageFromFile,
106+
LoadAnnotations):
102107
loading_pipeline_cfg.append(cfg)
103108
assert len(loading_pipeline_cfg) == 2, \
104109
'The data pipeline in your config file must include ' \

tests/test_data/test_utils.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from mmdet.datasets import replace_ImageToTensor
3+
from mmdet.datasets import get_loading_pipeline, replace_ImageToTensor
44

55

66
def test_replace_ImageToTensor():
@@ -59,3 +59,21 @@ def test_replace_ImageToTensor():
5959
]
6060
with pytest.warns(UserWarning):
6161
assert expected_pipelines == replace_ImageToTensor(pipelines)
62+
63+
64+
def test_get_loading_pipeline():
65+
pipelines = [
66+
dict(type='LoadImageFromFile'),
67+
dict(type='LoadAnnotations', with_bbox=True),
68+
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
69+
dict(type='RandomFlip', flip_ratio=0.5),
70+
dict(type='Pad', size_divisor=32),
71+
dict(type='DefaultFormatBundle'),
72+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
73+
]
74+
expected_pipelines = [
75+
dict(type='LoadImageFromFile'),
76+
dict(type='LoadAnnotations', with_bbox=True)
77+
]
78+
assert expected_pipelines == \
79+
get_loading_pipeline(pipelines)

0 commit comments

Comments
 (0)