Skip to content

Commit bf01bdd

Browse files
authored
Fix bc break of fp16 (open-mmlab#3822)
* fix bc break of mmdet.core.fp16, refactor import of wrap_fp16_model * changed warning method * added docstring for Depr_Fp16OptimizerHook * docformatter * fix docstring * changed names from depr to deprecated
1 parent 9c95543 commit bf01bdd

File tree

7 files changed

+63
-16
lines changed

7 files changed

+63
-16
lines changed

.dev_scripts/batch_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
import torch
1919
from mmcv import Config, get_logger
2020
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
21-
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
21+
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
22+
wrap_fp16_model)
2223

2324
from mmdet.apis import multi_gpu_test, single_gpu_test
24-
from mmdet.core import wrap_fp16_model
2525
from mmdet.datasets import (build_dataloader, build_dataset,
2626
replace_ImageToTensor)
2727
from mmdet.models import build_detector

mmdet/core/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .anchor import * # noqa: F401, F403
22
from .bbox import * # noqa: F401, F403
33
from .evaluation import * # noqa: F401, F403
4+
from .fp16 import * # noqa: F401, F403
45
from .mask import * # noqa: F401, F403
56
from .post_processing import * # noqa: F401, F403
67
from .utils import * # noqa: F401, F403

mmdet/core/fp16/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .deprecated_fp16_utils import \
2+
DeprecatedFp16OptimizerHook as Fp16OptimizerHook
3+
from .deprecated_fp16_utils import deprecated_auto_fp16 as auto_fp16
4+
from .deprecated_fp16_utils import deprecated_force_fp32 as force_fp32
5+
from .deprecated_fp16_utils import \
6+
deprecated_wrap_fp16_model as wrap_fp16_model
7+
8+
__all__ = ['auto_fp16', 'force_fp32', 'Fp16OptimizerHook', 'wrap_fp16_model']
+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import warnings
2+
3+
from mmcv.runner import (Fp16OptimizerHook, auto_fp16, force_fp32,
4+
wrap_fp16_model)
5+
6+
7+
class DeprecatedFp16OptimizerHook(Fp16OptimizerHook):
8+
"""A wrapper class for the FP16 optimizer hook. This class wraps
9+
:class:`Fp16OptimizerHook` in `mmcv.runner` and shows a warning that the
10+
:class:`Fp16OptimizerHook` from `mmdet.core` will be deprecated.
11+
12+
Refer to :class:`Fp16OptimizerHook` in `mmcv.runner` for more details.
13+
14+
Args:
15+
loss_scale (float): Scale factor multiplied with loss.
16+
"""
17+
18+
def __init__(*args, **kwargs):
19+
super().__init__(*args, **kwargs)
20+
warnings.warn(
21+
'Importing Fp16OptimizerHook from "mmdet.core" will be '
22+
'deprecated in the future. Please import them from "mmcv.runner" '
23+
'instead')
24+
25+
26+
def deprecated_auto_fp16(*args, **kwargs):
27+
warnings.warn(
28+
'Importing auto_fp16 from "mmdet.core" will be '
29+
'deprecated in the future. Please import them from "mmcv.runner" '
30+
'instead')
31+
return auto_fp16(*args, **kwargs)
32+
33+
34+
def deprecated_force_fp32(*args, **kwargs):
35+
warnings.warn(
36+
'Importing force_fp32 from "mmdet.core" will be '
37+
'deprecated in the future. Please import them from "mmcv.runner" '
38+
'instead')
39+
return force_fp32(*args, **kwargs)
40+
41+
42+
def deprecated_wrap_fp16_model(*args, **kwargs):
43+
warnings.warn(
44+
'Importing wrap_fp16_model from "mmdet.core" will be '
45+
'deprecated in the future. Please import them from "mmcv.runner" '
46+
'instead')
47+
wrap_fp16_model(*args, **kwargs)

tools/benchmark.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
from mmcv import Config
66
from mmcv.cnn import fuse_conv_bn
77
from mmcv.parallel import MMDataParallel
8-
from mmcv.runner import load_checkpoint
9-
from mmcv.runner.fp16_utils import wrap_fp16_model
8+
from mmcv.runner import load_checkpoint, wrap_fp16_model
109

1110
from mmdet.datasets import (build_dataloader, build_dataset,
1211
replace_ImageToTensor)

tools/test.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,14 @@
77
from mmcv import Config, DictAction
88
from mmcv.cnn import fuse_conv_bn
99
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
10-
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
10+
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
11+
wrap_fp16_model)
1112

1213
from mmdet.apis import multi_gpu_test, single_gpu_test
1314
from mmdet.datasets import (build_dataloader, build_dataset,
1415
replace_ImageToTensor)
1516
from mmdet.models import build_detector
1617

17-
try:
18-
from mmcv.runner import wrap_fp16_model
19-
except ImportError:
20-
from mmcv.runner.fp16_utils import wrap_fp16_model
21-
2218

2319
def parse_args():
2420
parser = argparse.ArgumentParser(

tools/test_robustness.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
import torch
1010
import torch.distributed as dist
1111
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
12-
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
12+
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
13+
wrap_fp16_model)
1314
from pycocotools.coco import COCO
1415
from pycocotools.cocoeval import COCOeval
1516
from robustness_eval import get_results
@@ -20,11 +21,6 @@
2021
from mmdet.datasets import build_dataloader, build_dataset
2122
from mmdet.models import build_detector
2223

23-
try:
24-
from mmcv.runner import wrap_fp16_model
25-
except ImportError:
26-
from mmcv.runner.fp16_utils import wrap_fp16_model
27-
2824

2925
def coco_eval_with_return(result_files,
3026
result_types,

0 commit comments

Comments
 (0)