Skip to content

Commit ebf7476

Browse files
authored
[Enhance] Add setup multi-processing both in train and test. (open-mmlab#7036)
* [Enhance] Add setup multi-processing both in train and test. * switch to torch mp
1 parent 75f26c8 commit ebf7476

File tree

5 files changed

+124
-39
lines changed

5 files changed

+124
-39
lines changed

mmdet/utils/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
from .collect_env import collect_env
33
from .logger import get_root_logger
44
from .misc import find_latest_checkpoint
5+
from .setup_env import setup_multi_processes
56

67
__all__ = [
7-
'get_root_logger',
8-
'collect_env',
9-
'find_latest_checkpoint',
8+
'get_root_logger', 'collect_env', 'find_latest_checkpoint',
9+
'setup_multi_processes'
1010
]

mmdet/utils/setup_env.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os
3+
import platform
4+
import warnings
5+
6+
import cv2
7+
import torch.multiprocessing as mp
8+
9+
10+
def setup_multi_processes(cfg):
11+
"""Setup multi-processing environment variables."""
12+
# set multi-process start method as `fork` to speed up the training
13+
if platform.system() != 'Windows':
14+
mp_start_method = cfg.get('mp_start_method', 'fork')
15+
current_method = mp.get_start_method(allow_none=True)
16+
if current_method is not None and current_method != mp_start_method:
17+
warnings.warn(
18+
f'Multi-processing start method `{mp_start_method}` is '
19+
f'different from the previous setting `{current_method}`.'
20+
f'It will be force set to `{mp_start_method}`. You can change '
21+
f'this behavior by changing `mp_start_method` in your config.')
22+
mp.set_start_method(mp_start_method, force=True)
23+
24+
# disable opencv multithreading to avoid system being overloaded
25+
opencv_num_threads = cfg.get('opencv_num_threads', 0)
26+
cv2.setNumThreads(opencv_num_threads)
27+
28+
# setup OMP threads
29+
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
30+
if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
31+
omp_num_threads = 1
32+
warnings.warn(
33+
f'Setting OMP_NUM_THREADS environment variable for each process '
34+
f'to be {omp_num_threads} in default, to avoid your system being '
35+
f'overloaded, please further tune the variable for optimal '
36+
f'performance in your application as needed.')
37+
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
38+
39+
# setup MKL threads
40+
if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
41+
mkl_num_threads = 1
42+
warnings.warn(
43+
f'Setting MKL_NUM_THREADS environment variable for each process '
44+
f'to be {mkl_num_threads} in default, to avoid your system being '
45+
f'overloaded, please further tune the variable for optimal '
46+
f'performance in your application as needed.')
47+
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)

tests/test_utils/test_setup_env.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import multiprocessing as mp
3+
import os
4+
import platform
5+
6+
import cv2
7+
from mmcv import Config
8+
9+
from mmdet.utils import setup_multi_processes
10+
11+
12+
def test_setup_multi_processes():
13+
# temp save system setting
14+
sys_start_mehod = mp.get_start_method(allow_none=True)
15+
sys_cv_threads = cv2.getNumThreads()
16+
# pop and temp save system env vars
17+
sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
18+
sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)
19+
20+
# test config without setting env
21+
config = dict(data=dict(workers_per_gpu=2))
22+
cfg = Config(config)
23+
setup_multi_processes(cfg)
24+
assert os.getenv('OMP_NUM_THREADS') == '1'
25+
assert os.getenv('MKL_NUM_THREADS') == '1'
26+
# when set to 0, the num threads will be 1
27+
assert cv2.getNumThreads() == 1
28+
if platform.system() != 'Windows':
29+
assert mp.get_start_method() == 'fork'
30+
31+
# test num workers <= 1
32+
os.environ.pop('OMP_NUM_THREADS')
33+
os.environ.pop('MKL_NUM_THREADS')
34+
config = dict(data=dict(workers_per_gpu=0))
35+
cfg = Config(config)
36+
setup_multi_processes(cfg)
37+
assert 'OMP_NUM_THREADS' not in os.environ
38+
assert 'MKL_NUM_THREADS' not in os.environ
39+
40+
# test manually set env var
41+
os.environ['OMP_NUM_THREADS'] = '4'
42+
config = dict(data=dict(workers_per_gpu=2))
43+
cfg = Config(config)
44+
setup_multi_processes(cfg)
45+
assert os.getenv('OMP_NUM_THREADS') == '4'
46+
47+
# test manually set opencv threads and mp start method
48+
config = dict(
49+
data=dict(workers_per_gpu=2),
50+
opencv_num_threads=4,
51+
mp_start_method='spawn')
52+
cfg = Config(config)
53+
setup_multi_processes(cfg)
54+
assert cv2.getNumThreads() == 4
55+
assert mp.get_start_method() == 'spawn'
56+
57+
# revert setting to avoid affecting other programs
58+
if sys_start_mehod:
59+
mp.set_start_method(sys_start_mehod, force=True)
60+
cv2.setNumThreads(sys_cv_threads)
61+
if sys_omp_threads:
62+
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
63+
else:
64+
os.environ.pop('OMP_NUM_THREADS')
65+
if sys_mkl_threads:
66+
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
67+
else:
68+
os.environ.pop('MKL_NUM_THREADS')

tools/test.py

+5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from mmdet.datasets import (build_dataloader, build_dataset,
1818
replace_ImageToTensor)
1919
from mmdet.models import build_detector
20+
from mmdet.utils import setup_multi_processes
2021

2122

2223
def parse_args():
@@ -128,6 +129,10 @@ def main():
128129
cfg = Config.fromfile(args.config)
129130
if args.cfg_options is not None:
130131
cfg.merge_from_dict(args.cfg_options)
132+
133+
# set multi-process settings
134+
setup_multi_processes(cfg)
135+
131136
# set cudnn_benchmark
132137
if cfg.get('cudnn_benchmark', False):
133138
torch.backends.cudnn.benchmark = True

tools/train.py

+1-36
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import argparse
33
import copy
4-
import multiprocessing as mp
54
import os
65
import os.path as osp
7-
import platform
86
import time
97
import warnings
108

11-
import cv2
129
import mmcv
1310
import torch
1411
from mmcv import Config, DictAction
@@ -19,7 +16,7 @@
1916
from mmdet.apis import init_random_seed, set_random_seed, train_detector
2017
from mmdet.datasets import build_dataset
2118
from mmdet.models import build_detector
22-
from mmdet.utils import collect_env, get_root_logger
19+
from mmdet.utils import collect_env, get_root_logger, setup_multi_processes
2320

2421

2522
def parse_args():
@@ -91,38 +88,6 @@ def parse_args():
9188
return args
9289

9390

94-
def setup_multi_processes(cfg):
95-
# set multi-process start method as `fork` to speed up the training
96-
if platform.system() != 'Windows':
97-
mp_start_method = cfg.get('mp_start_method', 'fork')
98-
mp.set_start_method(mp_start_method)
99-
100-
# disable opencv multithreading to avoid system being overloaded
101-
opencv_num_threads = cfg.get('opencv_num_threads', 0)
102-
cv2.setNumThreads(opencv_num_threads)
103-
104-
# setup OMP threads
105-
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
106-
if ('OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1):
107-
omp_num_threads = 1
108-
warnings.warn(
109-
f'Setting OMP_NUM_THREADS environment variable for each process '
110-
f'to be {omp_num_threads} in default, to avoid your system being '
111-
f'overloaded, please further tune the variable for optimal '
112-
f'performance in your application as needed.')
113-
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
114-
115-
# setup MKL threads
116-
if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
117-
mkl_num_threads = 1
118-
warnings.warn(
119-
f'Setting MKL_NUM_THREADS environment variable for each process '
120-
f'to be {mkl_num_threads} in default, to avoid your system being '
121-
f'overloaded, please further tune the variable for optimal '
122-
f'performance in your application as needed.')
123-
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
124-
125-
12691
def main():
12792
args = parse_args()
12893

0 commit comments

Comments
 (0)