# builder.py -- forked from open-mmlab/mmdetection.
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform
import random
import warnings
from functools import partial
import numpy as np
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import TORCH_VERSION, Registry, build_from_cfg, digit_version
from torch.utils.data import DataLoader
from .samplers import (DistributedGroupSampler, DistributedSampler,
GroupSampler, InfiniteBatchSampler,
InfiniteGroupBatchSampler)
if platform.system() != 'Windows':
    # Raise the soft limit on open file descriptors to at least 4096, but
    # never above the hard limit.
    # See https://github.com/pytorch/pytorch/issues/973
    import resource

    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    base_soft_limit, hard_limit = rlimit[0], rlimit[1]
    # Start from the larger of 4096 and the current soft limit ...
    soft_limit = max(4096, base_soft_limit)
    # ... then clamp it to the hard limit.
    if soft_limit > hard_limit:
        soft_limit = hard_limit
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
# Registry named 'dataset'; `build_dataset` passes it to `build_from_cfg`
# when constructing a plain (non-wrapper) dataset from a config.
DATASETS = Registry('dataset')
# Registry named 'pipeline'; it is not referenced by the code visible in this
# file (presumably populated/used by the pipeline modules -- verify there).
PIPELINES = Registry('pipeline')
def _concat_dataset(cfg, default_args=None):
    """Build one dataset per entry of ``cfg['ann_file']`` and concatenate.

    Each sub-dataset gets a deep copy of ``cfg`` with ``ann_file`` replaced
    by the i-th annotation file. ``img_prefix``, ``seg_prefix`` and
    ``proposal_file`` are replaced by their i-th element only when they are
    given as a list/tuple; otherwise the copied value is kept as is.

    Args:
        cfg (dict): Dataset config whose ``ann_file`` is a sequence.
        default_args (dict, optional): Forwarded to ``build_dataset`` for
            every sub-dataset.

    Returns:
        ConcatDataset: Wrapper around the built datasets, created with
        ``cfg.get('separate_eval', True)``.
    """
    from .dataset_wrappers import ConcatDataset

    separate_eval = cfg.get('separate_eval', True)
    # Fields that may be specified per annotation file.
    indexed_fields = {
        key: cfg.get(key, None)
        for key in ('img_prefix', 'seg_prefix', 'proposal_file')
    }

    datasets = []
    for idx, ann_file in enumerate(cfg['ann_file']):
        sub_cfg = copy.deepcopy(cfg)
        # 'separate_eval' is not a valid key for common datasets.
        sub_cfg.pop('separate_eval', None)
        sub_cfg['ann_file'] = ann_file
        for key, values in indexed_fields.items():
            if isinstance(values, (list, tuple)):
                sub_cfg[key] = values[idx]
        datasets.append(build_dataset(sub_cfg, default_args))

    return ConcatDataset(datasets, separate_eval)
def build_dataset(cfg, default_args=None):
    """Build a dataset, or a dataset wrapper, from a config.

    Dispatch order:

    1. A list/tuple of configs is built element-wise and wrapped in
       ``ConcatDataset``.
    2. ``type`` equal to ``'ConcatDataset'``, ``'RepeatDataset'``,
       ``'ClassBalancedDataset'`` or ``'MultiImageMixDataset'`` builds the
       inner dataset(s) recursively and wraps them.
    3. A config whose ``ann_file`` is a list/tuple is expanded by
       ``_concat_dataset``.
    4. Anything else is built with ``build_from_cfg`` from ``DATASETS``.

    Args:
        cfg (dict | list | tuple): Dataset config, or a sequence of them.
        default_args (dict, optional): Default arguments forwarded to every
            nested ``build_dataset`` / ``build_from_cfg`` call.

    Returns:
        The constructed dataset or dataset wrapper.
    """
    from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
                                   MultiImageMixDataset, RepeatDataset)

    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
    elif cfg['type'] == 'ConcatDataset':
        dataset = ConcatDataset(
            [build_dataset(c, default_args) for c in cfg['datasets']],
            cfg.get('separate_eval', True))
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    elif cfg['type'] == 'ClassBalancedDataset':
        dataset = ClassBalancedDataset(
            build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
    elif cfg['type'] == 'MultiImageMixDataset':
        # Work on a copy so the caller's config is not mutated when the
        # nested 'dataset' entry is replaced by the built object.
        cp_cfg = copy.deepcopy(cfg)
        # Forward default_args like every other wrapper branch does, so
        # caller-supplied defaults also reach the wrapped dataset.
        cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'], default_args)
        cp_cfg.pop('type')
        dataset = MultiImageMixDataset(**cp_cfg)
    elif isinstance(cfg.get('ann_file'), (list, tuple)):
        dataset = _concat_dataset(cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

    return dataset
def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     seed=None,
                     runner_type='EpochBasedRunner',
                     persistent_workers=False,
                     **kwargs):
    """Build PyTorch DataLoader.

    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.

    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: True.
        shuffle (bool): Whether to shuffle the data at every epoch.
            Default: True.
        seed (int, Optional): Seed to be used. Default: None.
        runner_type (str): Type of runner. Default: `EpochBasedRunner`
        persistent_workers (bool): If True, the data loader will not shutdown
            the worker processes after a dataset has been consumed once.
            This allows to maintain the workers `Dataset` instances alive.
            This argument is only valid when PyTorch>=1.7.0. Default: False.
        kwargs: any keyword argument to be used to initialize DataLoader.
            ``pin_memory`` may be passed here and defaults to False. The
            arguments this function computes itself (``batch_size``,
            ``sampler``, ``batch_sampler``, ``num_workers``, ``collate_fn``,
            ``worker_init_fn``) must not be passed.

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()

    if dist:
        # When model is :obj:`DistributedDataParallel`,
        # `batch_size` of :obj:`dataloader` is the
        # number of training samples on each GPU.
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        # When model is obj:`DataParallel`
        # the batch size is samples on all the GPUS
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    if runner_type == 'IterBasedRunner':
        # this is a batch sampler, which can yield
        # a mini-batch indices each time.
        # it can be used in both `DataParallel` and
        # `DistributedDataParallel`
        if shuffle:
            batch_sampler = InfiniteGroupBatchSampler(
                dataset, batch_size, world_size, rank, seed=seed)
        else:
            batch_sampler = InfiniteBatchSampler(
                dataset,
                batch_size,
                world_size,
                rank,
                seed=seed,
                shuffle=False)
        # The batch sampler already yields whole batches, so DataLoader is
        # given its default batch_size and no per-sample sampler.
        batch_size = 1
        sampler = None
    else:
        if dist:
            # DistributedGroupSampler will definitely shuffle the data to
            # satisfy that images on each GPU are in the same group
            if shuffle:
                sampler = DistributedGroupSampler(
                    dataset, samples_per_gpu, world_size, rank, seed=seed)
            else:
                sampler = DistributedSampler(
                    dataset, world_size, rank, shuffle=False, seed=seed)
        else:
            sampler = GroupSampler(dataset,
                                   samples_per_gpu) if shuffle else None
        batch_sampler = None

    # Workers are only seeded explicitly when the user supplied a seed.
    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed) if seed is not None else None

    if (TORCH_VERSION != 'parrots'
            and digit_version(TORCH_VERSION) >= digit_version('1.7.0')):
        kwargs['persistent_workers'] = persistent_workers
    elif persistent_workers is True:
        warnings.warn('persistent_workers is invalid because your pytorch '
                      'version is lower than 1.7.0')

    # Take `pin_memory` out of kwargs so a caller-supplied value overrides
    # the default instead of colliding with a hard-coded keyword
    # (which would raise "got multiple values for keyword argument").
    pin_memory = kwargs.pop('pin_memory', False)

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        pin_memory=pin_memory,
        worker_init_fn=init_fn,
        **kwargs)

    return data_loader
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed the random number generators of one DataLoader worker.

    The per-worker seed is ``num_workers * rank + worker_id + seed``. It is
    applied to NumPy's global RNG, Python's ``random`` module and PyTorch's
    default generator.

    Args:
        worker_id (int): Index of the worker within its DataLoader.
        num_workers (int): Number of workers of the DataLoader.
        rank (int): Rank of the current process.
        seed (int): User-provided base seed.
    """
    # Function-scope import: the module only imports DataLoader from torch
    # at file level.
    import torch

    # The seed of each worker equals to
    # num_worker * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    # Also tie torch's RNG to the user seed so torch-based randomness in the
    # worker is reproducible in the same way as numpy/random.
    torch.manual_seed(worker_seed)