Skip to content

Commit dd0e8ed

Browse files
authored
[Feature] Add support for mask diagonal flip in TTA (open-mmlab#5403)
* Add support for mask diagonal flip in tta * Add unit test * Fix unit test * Fix unit test
1 parent 102f379 commit dd0e8ed

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

mmdet/core/post_processing/merge_augs.py

+3
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ def merge_aug_masks(aug_masks, img_metas, rcnn_test_cfg, weights=None):
137137
mask = mask[:, :, :, ::-1]
138138
elif flip_direction == 'vertical':
139139
mask = mask[:, :, ::-1, :]
140+
elif flip_direction == 'diagonal':
141+
mask = mask[:, :, :, ::-1]
142+
mask = mask[:, :, ::-1, :]
140143
else:
141144
raise ValueError(
142145
f"Invalid flipping direction '{flip_direction}'")

mmdet/datasets/pipelines/test_time_aug.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ class MultiScaleFlipAug:
4444
scale_factor (float | list[float] | None): Scale factors for resizing.
4545
flip (bool): Whether apply flip augmentation. Default: False.
4646
flip_direction (str | list[str]): Flip augmentation directions,
47-
options are "horizontal" and "vertical". If flip_direction is list,
48-
multiple flip augmentations will be applied.
49-
It has no effect when flip == False. Default: "horizontal".
47+
options are "horizontal", "vertical" and "diagonal". If
48+
flip_direction is a list, multiple flip augmentations will be
49+
applied. It has no effect when flip == False. Default:
50+
"horizontal".
5051
"""
5152

5253
def __init__(self,

tests/test_data/test_pipelines/test_transform/test_models_aug_test.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def model_aug_test_template(cfg_file):
2020
# init test pipeline and set aug test
2121
load_cfg, multi_scale_cfg = cfg.test_pipeline
2222
multi_scale_cfg['flip'] = True
23+
multi_scale_cfg['flip_direction'] = ['horizontal', 'vertical', 'diagonal']
2324
multi_scale_cfg['img_scale'] = [(1333, 800), (800, 600), (640, 480)]
2425

2526
load = build_from_cfg(load_cfg, PIPELINES)
@@ -29,8 +30,8 @@ def model_aug_test_template(cfg_file):
2930
img_prefix=osp.join(osp.dirname(__file__), '../../../data'),
3031
img_info=dict(filename='color.jpg'))
3132
results = transform(load(results))
32-
assert len(results['img']) == 6
33-
assert len(results['img_metas']) == 6
33+
assert len(results['img']) == 12
34+
assert len(results['img_metas']) == 12
3435

3536
results['img'] = [collate([x]) for x in results['img']]
3637
results['img_metas'] = [collate([x]).data[0] for x in results['img_metas']]
@@ -56,14 +57,14 @@ def test_aug_test_size():
5657
transforms=[],
5758
img_scale=[(1333, 800), (800, 600), (640, 480)],
5859
flip=True,
59-
flip_direction=['horizontal', 'vertical'])
60+
flip_direction=['horizontal', 'vertical', 'diagonal'])
6061
multi_aug_test_module = build_from_cfg(transform, PIPELINES)
6162

6263
results = load(results)
6364
results = multi_aug_test_module(load(results))
64-
# len(["original", "horizontal", "vertical"]) *
65+
# len(["original", "horizontal", "vertical", "diagonal"]) *
6566
# len([(1333, 800), (800, 600), (640, 480)])
66-
assert len(results['img']) == 9
67+
assert len(results['img']) == 12
6768

6869

6970
def test_cascade_rcnn_aug_test():
@@ -107,6 +108,7 @@ def test_cornernet_aug_test():
107108
# init test pipeline and set aug test
108109
load_cfg, multi_scale_cfg = cfg.test_pipeline
109110
multi_scale_cfg['flip'] = True
111+
multi_scale_cfg['flip_direction'] = ['horizontal', 'vertical', 'diagonal']
110112
multi_scale_cfg['scale_factor'] = [0.5, 1.0, 2.0]
111113

112114
load = build_from_cfg(load_cfg, PIPELINES)
@@ -116,8 +118,8 @@ def test_cornernet_aug_test():
116118
img_prefix=osp.join(osp.dirname(__file__), '../../../data'),
117119
img_info=dict(filename='color.jpg'))
118120
results = transform(load(results))
119-
assert len(results['img']) == 6
120-
assert len(results['img_metas']) == 6
121+
assert len(results['img']) == 12
122+
assert len(results['img_metas']) == 12
121123

122124
results['img'] = [collate([x]) for x in results['img']]
123125
results['img_metas'] = [collate([x]).data[0] for x in results['img_metas']]

0 commit comments

Comments
 (0)