Skip to content

Commit 20892e6

Browse files
committed
完成第一版
1 parent b2f3e67 commit 20892e6

File tree

7 files changed

+62
-20
lines changed

7 files changed

+62
-20
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .dataset import Dataset
22
from .sampler import Sampler, BatchSampler, SequentialSampler, RandomSampler
33
from .dataloader import DataLoader
4+
from .collate import default_collate
45

5-
__all__ = ['Dataset', 'Sampler', 'BatchSampler', 'SequentialSampler', 'RandomSampler', 'DataLoader']
6+
__all__ = ['Dataset', 'Sampler', 'BatchSampler', 'SequentialSampler', 'RandomSampler', 'DataLoader', 'default_collate']

libv1/collate.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# -*- coding: utf-8 -*-
2+
# ======================================================
3+
# @Time : 20-12-26 下午4:42
4+
# @Author : huang ha
5+
# @Email :
6+
# @File : collate.py
7+
# @Comment:
8+
# ======================================================
9+
import torch
10+
11+
12+
def default_collate(batch):
13+
elem = batch[0]
14+
elem_type = type(elem)
15+
if isinstance(elem, torch.Tensor):
16+
return torch.stack(batch, 0)
17+
elif elem_type.__module__ == 'numpy':
18+
return default_collate([torch.as_tensor(b) for b in batch])
19+
else:
20+
raise NotImplementedError

lib/dataloader.py renamed to libv1/dataloader.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from torch.utils.data._utils.collate import default_collate
2-
32
from .sampler import BatchSampler, SequentialSampler, RandomSampler
43

54

@@ -36,15 +35,18 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
3635
self.batch_size = batch_size
3736
self.drop_last = drop_last
3837
self.sampler = sampler
39-
self.batch_sampler = batch_sampler
38+
self.batch_sampler = iter(batch_sampler)
4039

4140
if collate_fn is None:
4241
collate_fn = default_collate
43-
4442
self.collate_fn = collate_fn
4543

46-
@property
47-
def _auto_collation(self):
48-
return self.batch_sampler is not None
44+
def __next__(self):
45+
index = next(self.batch_sampler)
46+
data = [self.dataset[idx] for idx in index]
47+
data=self.collate_fn(data)
48+
return data
4949

50+
def __iter__(self):
51+
return self
5052

File renamed without changes.

lib/sampler.py renamed to libv1/sampler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ def __len__(self):
1313

1414

1515
class BatchSampler(Sampler):
16-
def __init__(self, sampler, batch_size, drop_last, data_source):
17-
super(BatchSampler, self).__init__(data_source)
16+
def __init__(self, sampler, batch_size, drop_last):
1817
self.sampler = sampler
1918
self.batch_size = batch_size
2019
self.drop_last = drop_last

main.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
1-
# This is a sample Python script.
1+
from libv1 import DataLoader, default_collate
22

3-
# Press Shift+F10 to execute it or replace it with your code.
4-
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
53

6-
import torch.utils.data
7-
def print_hi(name):
8-
# Use a breakpoint in the code line below to debug your script.
9-
print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint.
4+
def demo_test_1():
5+
from simplev1_datatset import SimpleV1Dataset
6+
simple_dataset = SimpleV1Dataset()
7+
dataloader = DataLoader(simple_dataset, batch_size=2, collate_fn=default_collate)
8+
for data in dataloader:
9+
print(data)
1010

1111

12-
# Press the green button in the gutter to run the script.
1312
if __name__ == '__main__':
14-
print_hi('PyCharm')
15-
16-
# See PyCharm help at https://www.jetbrains.com/help/pycharm/
13+
demo_test_1()

simplev1_datatset.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# -*- coding: utf-8 -*-
2+
# ======================================================
3+
# @Time : 20-12-25 下午9:49
4+
# @Author : huang ha
5+
# @Email :
6+
# @File : simple1_datatset.py
7+
# @Comment:
8+
# ======================================================
9+
10+
from libv1 import Dataset
11+
import numpy as np
12+
13+
14+
class SimpleV1Dataset(Dataset):
15+
def __init__(self):
16+
# 伪造数据
17+
self.imgs = np.arange(0, 16).reshape(8, 2)
18+
19+
def __getitem__(self, index):
20+
return self.imgs[index]
21+
22+
def __len__(self):
23+
return self.imgs.shape[0]

0 commit comments

Comments
 (0)