Skip to content

Commit 29c1cc5

Browse files
Czm369ZwwWayne
authored andcommitted
[Refactor] Remove some code in mmdet/apis/train.py (open-mmlab#6576)
* remove some code about custom hooks in apis/train.py * files were modified by yapf
1 parent c0b2e80 commit 29c1cc5

File tree

1 file changed

+9
-19
lines changed

1 file changed

+9
-19
lines changed

mmdet/apis/train.py

+9-19
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
import torch
77
import torch.distributed as dist
88
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
9-
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
9+
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
1010
Fp16OptimizerHook, OptimizerHook, build_optimizer,
1111
build_runner, get_dist_info)
12-
from mmcv.utils import build_from_cfg
1312

1413
from mmdet.core import DistEvalHook, EvalHook
1514
from mmdet.datasets import (build_dataloader, build_dataset,
@@ -162,9 +161,14 @@ def train_detector(model,
162161
optimizer_config = cfg.optimizer_config
163162

164163
# register hooks
165-
runner.register_training_hooks(cfg.lr_config, optimizer_config,
166-
cfg.checkpoint_config, cfg.log_config,
167-
cfg.get('momentum_config', None))
164+
runner.register_training_hooks(
165+
cfg.lr_config,
166+
optimizer_config,
167+
cfg.checkpoint_config,
168+
cfg.log_config,
169+
cfg.get('momentum_config', None),
170+
custom_hooks_config=cfg.get('custom_hooks', None))
171+
168172
if distributed:
169173
if isinstance(runner, EpochBasedRunner):
170174
runner.register_hook(DistSamplerSeedHook())
@@ -192,20 +196,6 @@ def train_detector(model,
192196
runner.register_hook(
193197
eval_hook(val_dataloader, **eval_cfg), priority='LOW')
194198

195-
# user-defined hooks
196-
if cfg.get('custom_hooks', None):
197-
custom_hooks = cfg.custom_hooks
198-
assert isinstance(custom_hooks, list), \
199-
f'custom_hooks expect list type, but got {type(custom_hooks)}'
200-
for hook_cfg in cfg.custom_hooks:
201-
assert isinstance(hook_cfg, dict), \
202-
'Each item in custom_hooks expects dict type, but got ' \
203-
f'{type(hook_cfg)}'
204-
hook_cfg = hook_cfg.copy()
205-
priority = hook_cfg.pop('priority', 'NORMAL')
206-
hook = build_from_cfg(hook_cfg, HOOKS)
207-
runner.register_hook(hook, priority=priority)
208-
209199
if cfg.resume_from:
210200
runner.resume(cfg.resume_from)
211201
elif cfg.load_from:

0 commit comments

Comments
 (0)