|
6 | 6 | import torch
|
7 | 7 | import torch.distributed as dist
|
8 | 8 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
9 |
| -from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner, |
| 9 | +from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, |
10 | 10 | Fp16OptimizerHook, OptimizerHook, build_optimizer,
|
11 | 11 | build_runner, get_dist_info)
|
12 |
| -from mmcv.utils import build_from_cfg |
13 | 12 |
|
14 | 13 | from mmdet.core import DistEvalHook, EvalHook
|
15 | 14 | from mmdet.datasets import (build_dataloader, build_dataset,
|
@@ -162,9 +161,14 @@ def train_detector(model,
|
162 | 161 | optimizer_config = cfg.optimizer_config
|
163 | 162 |
|
164 | 163 | # register hooks
|
165 |
| - runner.register_training_hooks(cfg.lr_config, optimizer_config, |
166 |
| - cfg.checkpoint_config, cfg.log_config, |
167 |
| - cfg.get('momentum_config', None)) |
| 164 | + runner.register_training_hooks( |
| 165 | + cfg.lr_config, |
| 166 | + optimizer_config, |
| 167 | + cfg.checkpoint_config, |
| 168 | + cfg.log_config, |
| 169 | + cfg.get('momentum_config', None), |
| 170 | + custom_hooks_config=cfg.get('custom_hooks', None)) |
| 171 | + |
168 | 172 | if distributed:
|
169 | 173 | if isinstance(runner, EpochBasedRunner):
|
170 | 174 | runner.register_hook(DistSamplerSeedHook())
|
@@ -192,20 +196,6 @@ def train_detector(model,
|
192 | 196 | runner.register_hook(
|
193 | 197 | eval_hook(val_dataloader, **eval_cfg), priority='LOW')
|
194 | 198 |
|
195 |
| - # user-defined hooks |
196 |
| - if cfg.get('custom_hooks', None): |
197 |
| - custom_hooks = cfg.custom_hooks |
198 |
| - assert isinstance(custom_hooks, list), \ |
199 |
| - f'custom_hooks expect list type, but got {type(custom_hooks)}' |
200 |
| - for hook_cfg in cfg.custom_hooks: |
201 |
| - assert isinstance(hook_cfg, dict), \ |
202 |
| - 'Each item in custom_hooks expects dict type, but got ' \ |
203 |
| - f'{type(hook_cfg)}' |
204 |
| - hook_cfg = hook_cfg.copy() |
205 |
| - priority = hook_cfg.pop('priority', 'NORMAL') |
206 |
| - hook = build_from_cfg(hook_cfg, HOOKS) |
207 |
| - runner.register_hook(hook, priority=priority) |
208 |
| - |
209 | 199 | if cfg.resume_from:
|
210 | 200 | runner.resume(cfg.resume_from)
|
211 | 201 | elif cfg.load_from:
|
|
0 commit comments