@@ -12,36 +12,14 @@ def __len__(self):
12
12
raise NotImplementedError
13
13
14
14
15
- class BatchSampler (Sampler ):
16
- def __init__ (self , sampler , batch_size , drop_last ):
17
- self .sampler = sampler
18
- self .batch_size = batch_size
19
- self .drop_last = drop_last
20
-
21
- def __iter__ (self ):
22
- batch = []
23
- for idx in self .sampler :
24
- batch .append (idx )
25
- if len (batch ) == self .batch_size :
26
- yield batch
27
- batch = []
28
- if len (batch ) > 0 and not self .drop_last :
29
- yield batch
30
-
31
- def __len__ (self ):
32
- if self .drop_last :
33
- return len (self .sampler ) // self .batch_size
34
- else :
35
- return (len (self .sampler ) + self .batch_size - 1 ) // self .batch_size
36
-
37
-
38
15
class SequentialSampler (Sampler ):
39
16
40
17
def __init__ (self , data_source ):
41
18
super (SequentialSampler , self ).__init__ (data_source )
42
19
self .data_source = data_source
43
20
44
21
def __iter__ (self ):
22
+ # 返回迭代器,不然无法 for .. in ..
45
23
return iter (range (len (self .data_source )))
46
24
47
25
def __len__ (self ):
@@ -51,8 +29,11 @@ def __len__(self):
51
29
class RandomSampler (Sampler ):
52
30
def __init__ (self , data_source , replacement = False , num_samples = None ):
53
31
super (RandomSampler , self ).__init__ (data_source )
32
+ # 数据集
54
33
self .data_source = data_source
34
+ # 是否有放回抽象
55
35
self .replacement = replacement
36
+ # 采样长度,一般等于 data_source 长度
56
37
self ._num_samples = num_samples
57
38
58
39
if self ._num_samples is not None and not replacement :
@@ -64,19 +45,50 @@ def __init__(self, data_source, replacement=False, num_samples=None):
64
45
"value, but got num_samples={}" .format (self .num_samples ))
65
46
66
47
@property
67
- def num_samples (self ) -> int :
48
+ def num_samples (self ):
68
49
if self ._num_samples is None :
69
50
return len (self .data_source )
70
51
return self ._num_samples
71
52
72
53
def __iter__ (self ):
73
54
n = len (self .data_source )
55
+ # 通过 yield 关键字返回迭代器对象
74
56
if self .replacement :
57
+ # 有放回抽样
58
+ # 可以直接写 yield from torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist()
59
+ # 之所以按照每次生成32个,可能是因为想减少重复抽样概率 ?
75
60
for _ in range (self .num_samples // 32 ):
76
61
yield from torch .randint (high = n , size = (32 ,), dtype = torch .int64 ).tolist ()
77
62
yield from torch .randint (high = n , size = (self .num_samples % 32 ,), dtype = torch .int64 ).tolist ()
78
63
else :
64
+ # 无放回抽样
79
65
yield from torch .randperm (n ).tolist ()
80
66
81
67
def __len__ (self ):
82
68
return self .num_samples
69
+
70
+
71
+ class BatchSampler (Sampler ):
72
+ def __init__ (self , sampler , batch_size , drop_last ):
73
+ self .sampler = sampler
74
+ self .batch_size = batch_size
75
+ self .drop_last = drop_last
76
+
77
+ def __iter__ (self ):
78
+ batch = []
79
+ # 调用 sampler 内部的迭代器对象
80
+ for idx in self .sampler :
81
+ batch .append (idx )
82
+ # 如果已经得到了 batch 个 索引,则可以通过 yield 关键字生成生成器返回,得到迭代器对象
83
+ if len (batch ) == self .batch_size :
84
+ yield batch
85
+ batch = []
86
+ if len (batch ) > 0 and not self .drop_last :
87
+ yield batch
88
+
89
+ def __len__ (self ):
90
+ if self .drop_last :
91
+ # 如果最后的索引数不够一个 batch,则抛弃
92
+ return len (self .sampler ) // self .batch_size
93
+ else :
94
+ return (len (self .sampler ) + self .batch_size - 1 ) // self .batch_size
0 commit comments