
Commit 6215cc1

[Improvement] Add seed option for sampler (open-mmlab#4665)
1 parent: f52dac8

3 files changed (+21 -9 lines)


mmdet/datasets/builder.py (+3 -3)

@@ -106,11 +106,11 @@ def build_dataloader(dataset,
         # DistributedGroupSampler will definitely shuffle the data to satisfy
         # that images on each GPU are in the same group
         if shuffle:
-            sampler = DistributedGroupSampler(dataset, samples_per_gpu,
-                                              world_size, rank)
+            sampler = DistributedGroupSampler(
+                dataset, samples_per_gpu, world_size, rank, seed=seed)
         else:
             sampler = DistributedSampler(
-                dataset, world_size, rank, shuffle=False)
+                dataset, world_size, rank, shuffle=False, seed=seed)
         batch_size = samples_per_gpu
         num_workers = workers_per_gpu
     else:
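
For context, a hedged usage sketch of how the new argument is expected to reach the sampler when a training dataloader is built. Only samples_per_gpu, workers_per_gpu, shuffle and seed are visible in the hunk above; cfg, build_dataset and the dist flag are assumptions about the surrounding builder API rather than something this diff shows.

from mmdet.datasets import build_dataloader, build_dataset

dataset = build_dataset(cfg.data.train)  # cfg: hypothetical mmdet config object
data_loader = build_dataloader(
    dataset,
    samples_per_gpu=2,
    workers_per_gpu=2,
    dist=True,     # assumed switch for the distributed branch shown above
    shuffle=True,  # selects DistributedGroupSampler in that branch
    seed=42)       # now forwarded to the sampler via seed=seed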

mmdet/datasets/samplers/distributed_sampler.py (+11 -4)

@@ -6,15 +6,22 @@

 class DistributedSampler(_DistributedSampler):

-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
-        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
-        self.shuffle = shuffle
+    def __init__(self,
+                 dataset,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=True,
+                 seed=0):
+        super().__init__(
+            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+        # for the compatibility from PyTorch 1.3+
+        self.seed = seed if seed is not None else 0

     def __iter__(self):
         # deterministically shuffle based on epoch
         if self.shuffle:
             g = torch.Generator()
-            g.manual_seed(self.epoch)
+            g.manual_seed(self.epoch + self.seed)
             indices = torch.randperm(len(self.dataset), generator=g).tolist()
         else:
             indices = torch.arange(len(self.dataset)).tolist()
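
The change is easier to see in isolation. The standalone sketch below (plain PyTorch, not mmdet code) mirrors the seeding scheme in the patched __iter__: the permutation is a pure function of epoch + seed, so every rank given the same seed reproduces the identical order for a given epoch, while different epochs still shuffle differently.

import torch

def epoch_indices(num_samples, epoch, seed=0):
    # Same formula as the patch: the generator is seeded with epoch + seed.
    g = torch.Generator()
    g.manual_seed(epoch + seed)
    return torch.randperm(num_samples, generator=g).tolist()

# Identical (epoch, seed) pairs always give identical index orders ...
assert epoch_indices(10, epoch=3, seed=42) == epoch_indices(10, epoch=3, seed=42)
# ... while bumping the epoch (as set_epoch does each epoch) reshuffles.
print(epoch_indices(10, epoch=3, seed=42))
print(epoch_indices(10, epoch=4, seed=42))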

mmdet/datasets/samplers/group_sampler.py (+7 -2)

@@ -64,13 +64,17 @@ class DistributedGroupSampler(Sampler):
         num_replicas (optional): Number of processes participating in
             distributed training.
         rank (optional): Rank of the current process within num_replicas.
+        seed (int, optional): random seed used to shuffle the sampler if
+            ``shuffle=True``. This number should be identical across all
+            processes in the distributed group. Default: 0.
     """

     def __init__(self,
                  dataset,
                  samples_per_gpu=1,
                  num_replicas=None,
-                 rank=None):
+                 rank=None,
+                 seed=0):
         _rank, _num_replicas = get_dist_info()
         if num_replicas is None:
             num_replicas = _num_replicas
@@ -81,6 +85,7 @@ def __init__(self,
         self.num_replicas = num_replicas
         self.rank = rank
         self.epoch = 0
+        self.seed = seed if seed is not None else 0

         assert hasattr(self.dataset, 'flag')
         self.flag = self.dataset.flag
@@ -96,7 +101,7 @@ def __init__(self,
     def __iter__(self):
         # deterministically shuffle based on epoch
         g = torch.Generator()
-        g.manual_seed(self.epoch)
+        g.manual_seed(self.epoch + self.seed)

         indices = []
         for i, size in enumerate(self.group_sizes):
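
A hedged, self-contained sketch of constructing this sampler directly with the new argument. The constructor signature follows the hunks above and the import path mirrors this file's path; the set_epoch hook and the ToyDataset stand-in are assumptions for illustration, not something this diff shows.

import numpy as np
from mmdet.datasets.samplers.group_sampler import DistributedGroupSampler

class ToyDataset:
    """Hypothetical stand-in: the sampler only needs __len__ and a `flag`
    array assigning each sample to a group (see the assert in the hunk above)."""

    def __init__(self, n=10):
        self.flag = np.array([i % 2 for i in range(n)], dtype=np.uint8)

    def __len__(self):
        return len(self.flag)

dataset = ToyDataset()
# As the added docstring says, the seed must be identical on every rank:
# identical seeds give one identical global shuffle that is then sharded.
sampler = DistributedGroupSampler(
    dataset, samples_per_gpu=2, num_replicas=2, rank=0, seed=42)
sampler.set_epoch(0)  # assumed hook; mixes the epoch into manual_seed(epoch + seed)
print(list(sampler))  # indices this rank would load in epoch 0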
