Skip to content

Commit f1fbebd

Browse files
committed
完成第一版
1 parent 20892e6 commit f1fbebd

File tree

3 files changed

+98
-26
lines changed

3 files changed

+98
-26
lines changed

libv1/dataloader.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,13 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
4141
collate_fn = default_collate
4242
self.collate_fn = collate_fn
4343

44+
# Core of the loader: fetch indices, gather samples, collate into a batch.
def __next__(self):
    """Return the next collated batch.

    StopIteration propagates from the batch sampler when it is exhausted,
    which is what terminates a for-loop over the loader.
    """
    # The batch sampler yields one list of dataset indices per batch.
    batch_indices = next(self.batch_sampler)
    # Gather the individual samples for those indices.
    samples = [self.dataset[i] for i in batch_indices]
    # Merge the list of samples into a single batch structure.
    return self.collate_fn(samples)
4950

51+
# The loader is its own iterator because it implements __next__.
def __iter__(self):
    """Return self so the loader can be used directly in a for-loop."""
    return self
52-

libv1/sampler.py

+36-24
Original file line numberDiff line numberDiff line change
@@ -12,36 +12,14 @@ def __len__(self):
1212
raise NotImplementedError
1313

1414

15-
class BatchSampler(Sampler):
16-
def __init__(self, sampler, batch_size, drop_last):
17-
self.sampler = sampler
18-
self.batch_size = batch_size
19-
self.drop_last = drop_last
20-
21-
def __iter__(self):
22-
batch = []
23-
for idx in self.sampler:
24-
batch.append(idx)
25-
if len(batch) == self.batch_size:
26-
yield batch
27-
batch = []
28-
if len(batch) > 0 and not self.drop_last:
29-
yield batch
30-
31-
def __len__(self):
32-
if self.drop_last:
33-
return len(self.sampler) // self.batch_size
34-
else:
35-
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
36-
37-
3815
class SequentialSampler(Sampler):
3916

4017
def __init__(self, data_source):
    """Remember the dataset whose indices will be emitted in order."""
    super().__init__(data_source)
    self.data_source = data_source
4320

4421
def __iter__(self):
    # Yield indices 0..N-1 in order.  Being a generator, this returns an
    # iterator, which is what makes `for idx in sampler` work.
    yield from range(len(self.data_source))
4624

4725
def __len__(self):
@@ -51,8 +29,11 @@ def __len__(self):
5129
class RandomSampler(Sampler):
5230
def __init__(self, data_source, replacement=False, num_samples=None):
5331
super(RandomSampler, self).__init__(data_source)
32+
# 数据集
5433
self.data_source = data_source
34+
# 是否有放回抽象
5535
self.replacement = replacement
36+
# 采样长度,一般等于 data_source 长度
5637
self._num_samples = num_samples
5738

5839
if self._num_samples is not None and not replacement:
@@ -64,19 +45,50 @@ def __init__(self, data_source, replacement=False, num_samples=None):
6445
"value, but got num_samples={}".format(self.num_samples))
6546

6647
@property
def num_samples(self):
    """Number of indices to draw; falls back to the dataset's length."""
    return len(self.data_source) if self._num_samples is None else self._num_samples
7152

7253
def __iter__(self):
    """Generate sample indices, with or without replacement."""
    population = len(self.data_source)
    if not self.replacement:
        # Sampling without replacement: a random permutation of all indices.
        yield from torch.randperm(population).tolist()
        return
    # Sampling with replacement: draw the indices in chunks of 32, then a
    # final chunk for the remainder.  A single randint of num_samples draws
    # would be equivalent; the chunking mirrors the original implementation
    # (the original author speculated it might reduce repeated draws — unclear).
    full_chunks, remainder = divmod(self.num_samples, 32)
    for _ in range(full_chunks):
        yield from torch.randint(high=population, size=(32,), dtype=torch.int64).tolist()
    yield from torch.randint(high=population, size=(remainder,), dtype=torch.int64).tolist()
8066

8167
def __len__(self):
    """The sampler's length equals the number of indices it will draw."""
    return self.num_samples
69+
70+
71+
class BatchSampler(Sampler):
    """Wrap another sampler and yield its indices grouped into batches.

    Each yielded item is a list of at most ``batch_size`` indices; a final
    partial batch is emitted only when ``drop_last`` is False.
    """

    def __init__(self, sampler, batch_size, drop_last):
        # The underlying index sampler (e.g. SequentialSampler/RandomSampler).
        self.sampler = sampler
        # Number of indices per yielded batch.
        self.batch_size = batch_size
        # Whether to discard a trailing batch smaller than batch_size.
        self.drop_last = drop_last

    def __iter__(self):
        pending = []
        # Drain the wrapped sampler's iterator, grouping indices as they arrive.
        for index in self.sampler:
            pending.append(index)
            # A full group is handed out via yield, making this a generator.
            if len(pending) == self.batch_size:
                yield pending
                pending = []
        # Leftover indices form a short final batch unless drop_last is set.
        if pending and not self.drop_last:
            yield pending

    def __len__(self):
        # With drop_last, any trailing partial batch is discarded;
        # otherwise count it via ceiling division.
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size

main.py

+59
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,64 @@ def demo_test_1():
99
print(data)
1010

1111

12+
def demo_test_2():
    """Print which loader-related objects are Iterables vs Iterators."""
    from simplev1_datatset import SimpleV1Dataset
    from libv1 import SequentialSampler, RandomSampler
    # Fix: the Iterator/Iterable ABC aliases in `collections` were deprecated
    # since Python 3.3 and removed in 3.10 — they live in `collections.abc`.
    from collections.abc import Iterator, Iterable

    simple_dataset = SimpleV1Dataset()
    dataloader = DataLoader(simple_dataset, batch_size=2, collate_fn=default_collate)

    print(isinstance(simple_dataset, Iterable))
    print(isinstance(simple_dataset, Iterator))
    print(isinstance(iter(simple_dataset), Iterator))

    print(isinstance(SequentialSampler(simple_dataset), Iterable))
    print(isinstance(SequentialSampler(simple_dataset), Iterator))
    print(isinstance(iter(SequentialSampler(simple_dataset)), Iterator))

    # BatchSampler and RandomSampler have the same internal structure,
    # so the results are the same.
    print(isinstance(RandomSampler(simple_dataset), Iterable))
    print(isinstance(RandomSampler(simple_dataset), Iterator))
    print(isinstance(iter(RandomSampler(simple_dataset)), Iterator))

    print(isinstance(dataloader, Iterator))
35+
36+
def demo_test_3():
    """Illustrative pseudo-code: what a minimal DataLoader does internally.

    The class below is never instantiated (``img0`` etc. are placeholders);
    the function only defines it to sketch the index -> fetch -> collate flow.
    """
    class DataLoader(object):
        def __init__(self):
            # Placeholder dataset: 100 (image, target) pairs.
            self.dataset = [[img0, target0], [img1, target1], [img2, target2], ..., [img99, target99]]
            # Placeholder sampler: the sequential indices 0..99.
            self.sampler = [0, 1, 2, 3, 4, ..., 99]
            self.batch_size = 4
            # Cursor into self.sampler.
            self.index = 0

        def collate_fn(self, data):
            # Stack images and targets along a new batch dimension.
            batch_img = torch.stack(data[0], 0)
            batch_target = torch.stack(data[1], 0)
            return batch_img, batch_target

        def __next__(self):
            # 0. collect batch_size indices from the sampler
            batch_index = []
            for _ in range(self.batch_size):
                batch_index.append(self.sampler[self.index])
                self.index += 1

            # 1. fetch the batch_size samples for those indices
            data = [self.dataset[idx] for idx in batch_index]

            # 2. collate_fn joins them along the batch dimension
            return self.collate_fn(data)

        def __iter__(self):
            return self
67+
68+
1269
# Script entry point: run the basic demo; demo_test_2 is kept for manual use.
if __name__ == '__main__':
    demo_test_1()
    # demo_test_2()
72+

0 commit comments

Comments
 (0)