Skip to content

Commit b2f3e67

Browse files
committed
第一次提交
1 parent 9ab2874 commit b2f3e67

File tree

6 files changed

+161
-2
lines changed

6 files changed

+161
-2
lines changed

.gitignore

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ dist/
1414
downloads/
1515
eggs/
1616
.eggs/
17-
lib/
18-
lib64/
1917
parts/
2018
sdist/
2119
var/

lib/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .dataset import Dataset
2+
from .sampler import Sampler, BatchSampler, SequentialSampler, RandomSampler
3+
from .dataloader import DataLoader
4+
5+
__all__ = ['Dataset', 'Sampler', 'BatchSampler', 'SequentialSampler', 'RandomSampler', 'DataLoader']

lib/dataloader.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from torch.utils.data._utils.collate import default_collate
2+
3+
from .sampler import BatchSampler, SequentialSampler, RandomSampler
4+
5+
6+
class DataLoader(object):
    """Minimal DataLoader: pairs a dataset with a batch sampler and a collate fn.

    Mirrors the argument contract of ``torch.utils.data.DataLoader`` for the
    subset of options it supports.

    Args:
        dataset: map-style dataset (supports ``__getitem__``/``__len__``).
        batch_size: number of samples per batch (default 1).
        shuffle: draw indices in random order instead of sequentially.
        sampler: custom index sampler; mutually exclusive with ``shuffle``.
        batch_sampler: custom batch-index sampler; mutually exclusive with
            ``batch_size``, ``shuffle``, ``sampler`` and ``drop_last``.
        collate_fn: merges a list of samples into a batch; defaults to
            ``default_collate``.
        drop_last: drop the final incomplete batch.

    Raises:
        ValueError: when mutually exclusive options are combined.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, collate_fn=None, drop_last=False):
        self.dataset = dataset

        # A user-supplied sampler already fixes the visiting order, so asking
        # for shuffle at the same time would contradict it.
        if shuffle and sampler is not None:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # batch_sampler subsumes batching, shuffling, sampling and
            # drop_last, so none of those knobs may be set alongside it.
            conflicting = (batch_size != 1 or shuffle
                           or sampler is not None or drop_last)
            if conflicting:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False

        # Fall back to a default element sampler when none was given.
        if sampler is None:
            sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)

        # A batch sampler always ends up installed: either the caller's, or a
        # default one built from the element sampler.
        if batch_sampler is None:
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler

        self.collate_fn = default_collate if collate_fn is None else collate_fn

    @property
    def _auto_collation(self):
        # True whenever batches of indices (rather than single indices) drive
        # fetching; with the constructor above this is always the case.
        return self.batch_sampler is not None
49+
50+

lib/dataset.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class Dataset(object):
    """Abstract map-style dataset.

    Subclasses must implement ``__getitem__`` (return the sample at a given
    index) and ``__len__`` (return the number of samples).
    """

    def __getitem__(self, index):
        # Subclasses supply the actual sample lookup.
        raise NotImplementedError

    def __len__(self):
        # Subclasses supply the actual size.
        raise NotImplementedError
7+

lib/sampler.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import torch
2+
3+
4+
class Sampler(object):
    """Abstract base for index samplers.

    Subclasses must implement ``__iter__`` (yield dataset indices) and
    ``__len__`` (number of indices yielded).
    """

    def __init__(self, data_source):
        # The base class keeps no state; subclasses decide what to store.
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
13+
14+
15+
class BatchSampler(Sampler):
    """Wraps an element sampler and yields mini-batches of its indices.

    Bug fix: ``data_source`` used to be a *required* fourth constructor
    argument, but ``DataLoader`` builds the default batch sampler as
    ``BatchSampler(sampler, batch_size, drop_last)`` — three arguments — so
    every default construction raised ``TypeError``. The argument was never
    used anywhere, so it now defaults to ``None`` (any existing 4-argument
    caller still works).

    Args:
        sampler: iterable of indices to group into batches.
        batch_size: number of indices per batch.
        drop_last: drop the final batch if it is shorter than ``batch_size``.
        data_source: unused; kept only for backward compatibility.
    """

    def __init__(self, sampler, batch_size, drop_last, data_source=None):
        super(BatchSampler, self).__init__(data_source)
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # Emit the trailing short batch unless the caller opted out.
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            # Ceiling division: the final partial batch counts too.
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
37+
38+
39+
class SequentialSampler(Sampler):
    """Yields indices ``0 .. len(data_source) - 1`` in order.

    Args:
        data_source: sized collection whose indices are enumerated.
    """

    def __init__(self, data_source):
        super(SequentialSampler, self).__init__(data_source)
        self.data_source = data_source

    def __iter__(self):
        # Indices simply follow the data source's natural order.
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)
50+
51+
52+
class RandomSampler(Sampler):
    """Yields dataset indices in random order.

    With ``replacement=True`` it draws ``num_samples`` indices with
    replacement; otherwise it yields one full random permutation of the
    data source.

    Args:
        data_source: sized collection to sample indices for.
        replacement: sample with replacement when True.
        num_samples: number of draws; only valid with ``replacement=True``.

    Raises:
        ValueError: for invalid ``num_samples`` / ``replacement`` combinations.
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        super(RandomSampler, self).__init__(data_source)
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        # num_samples only makes sense with replacement: a permutation always
        # covers the whole data source exactly once.
        if not replacement and self._num_samples is not None:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self) -> int:
        # Default to the full data source size when not set explicitly.
        return len(self.data_source) if self._num_samples is None else self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            # Draw in chunks of 32 to amortize the per-call cost of randint,
            # then a final partial chunk for the remainder.
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64).tolist()
        else:
            yield from torch.randperm(n).tolist()

    def __len__(self):
        return self.num_samples

main.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# This is a sample Python script.
2+
3+
# Press Shift+F10 to execute it or replace it with your code.
4+
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
5+
6+
import torch.utils.data
7+
def print_hi(name):
    """Print a short greeting for *name* (PyCharm's sample function)."""
    # Use a breakpoint in the code line below to debug your script.
    greeting = f'Hi, {name}'
    print(greeting)  # Press Ctrl+F8 to toggle the breakpoint.
10+
11+
12+
# Press the green button in the gutter to run the script.
# Script entry point: only greets when executed directly, not on import.
if __name__ == '__main__':
    print_hi('PyCharm')
15+
16+
# See PyCharm help at https://www.jetbrains.com/help/pycharm/

0 commit comments

Comments
 (0)