1+ import numpy as np
2+
3+
def pnorm(p):
    """Return ``p`` as a probability map whose entries sum to one.

    Args:
        p (list | tuple): non-negative weights, not necessarily normalised.

    Returns:
        ``p`` itself (same object) when its sum is already close to 1,
        otherwise a new list with every entry divided by the total.

    Raises:
        ValueError: if ``p`` is not a list/tuple, if any weight is negative,
            or if the total is not a positive finite number (this includes
            an empty ``p``).
    """
    if not isinstance(p, (list, tuple)):
        raise ValueError(f'probability map {p} must be of type (list, tuple), not {type(p)} ')

    ptot = np.sum(p)

    # A negative weight can never be turned into a probability by rescaling.
    if any(w < 0 for w in p):
        raise ValueError("probability map must not contain negative weights, "
                         "but got {}".format(p))

    # Dividing by a zero, negative, NaN or infinite total would silently
    # produce NaNs or a bogus distribution, so fail loudly instead.  The
    # chained comparison is False for NaN as well.
    if not (0 < ptot < float("inf")):
        raise ValueError("probability map must have a positive finite sum, "
                         "but got sum={}".format(ptot))

    if not np.allclose(ptot, 1):
        p = [i / ptot for i in p]
    return p
11+
12+
def multinomial(num_samples, p):
    """Draw one multinomial sample of ``num_samples`` trials.

    ``p`` is passed through ``pnorm`` first, so it may be an unnormalised
    list or tuple of weights.

    The result is whatever ``np.random.multinomial`` returns for the
    normalised weights: per-category *counts* (one entry per element of
    ``p``), not a sequence of sampled indices.
    """
    return np.random.multinomial(num_samples, pnorm(p))
17+
18+
class Sampler:
    r"""Common base class of every sampler in this module.

    A concrete sampler must implement :meth:`__iter__`, which yields the
    indices (or keys) of dataset elements in the order they should be
    visited.  It should normally also implement :meth:`__len__`, returning
    the number of items one iteration produces.

    .. note:: :meth:`__len__` is optional for plain iteration, but any code
              that computes a length from a sampler (for example
              :meth:`BatchSampler.__len__`) needs it.
    """

    def __init__(self, data_source):
        """Accept ``data_source`` for interface compatibility; the base class stores nothing."""

    def __iter__(self):
        raise NotImplementedError

    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    #
    # This base class represents an iterable of data whose subclasses only
    # optionally implement `__len__`.  We deliberately do NOT provide a
    # default implementation here, because both straightforward defaults
    # are problematic:
    #
    #   + `return NotImplemented`:
    #     Calling `len(subclass_instance)` raises:
    #       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
    #
    #   + `raise NotImplementedError()`:
    #     This prevents some fallback behavior from triggering.  E.g., the
    #     built-in `list(X)` tries to call `len(X)` first and takes a
    #     different code path if the method is missing or `NotImplemented`
    #     is returned, whereas a raised `NotImplementedError` propagates and
    #     makes the call fail even though it could have used `__iter__` to
    #     complete.
    #
    # Thus, the only two sensible things to do are
    #
    #   + **not** provide a default `__len__`.
    #
    #   + raise a `TypeError` instead, which is what Python uses when users
    #     call a method that is not defined on an object.
    #     (@ssnl verifies that this works on at least Python 3.7.)
61+
62+
class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from.  If it exposes a
            callable ``keys`` attribute (mapping-style dataset) its keys are
            yielded in the order ``keys()`` returns them; otherwise the
            positions ``0 .. len(data_source) - 1`` are yielded.
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Mapping-style datasets are addressed by key.  Sequence-style
        # datasets (lists, ranges, ...) have no ``keys()`` and are addressed
        # by position instead of failing with AttributeError.
        keys = getattr(self.data_source, "keys", None)
        if callable(keys):
            return iter(keys())
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)
77+
78+
class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.

    Arguments:
        data_source (Dataset): dataset to sample from.  Mapping-style
            datasets (exposing ``keys()``) yield keys; any other sized
            dataset yields positions ``0 .. len(data_source) - 1``.
        replacement (bool): samples are drawn with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.

    Raises:
        ValueError: if ``replacement`` is not a bool, if ``num_samples`` is
            given together with ``replacement=False``, or if the effective
            number of samples is not a positive (non-bool) integer.
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        # ``bool`` is a subclass of ``int``; reject it explicitly so that
        # ``num_samples=True`` is not silently treated as 1 (same rule as
        # WeightedRandomSampler and BatchSampler).
        resolved = self.num_samples
        if not isinstance(resolved, int) or isinstance(resolved, bool) or resolved <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(resolved))

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def _keys(self):
        """Return a list snapshot of the keys (or positions) that can be sampled."""
        keys = getattr(self.data_source, "keys", None)
        if callable(keys):
            return list(keys())
        return list(range(len(self.data_source)))

    def __iter__(self):
        keys = self._keys()
        # Size everything from the snapshot that is actually indexed below.
        n = len(keys)
        if self.replacement:
            choose = np.random.randint(low=0, high=n, size=(self.num_samples,),
                                       dtype=np.int64).tolist()
        else:
            choose = np.random.permutation(n).tolist()
        return (keys[x] for x in choose)

    def __len__(self):
        return self.num_samples
124+
125+
class SubsetRandomSampler(Sampler):
    r"""Yields the entries of a fixed index sequence in random order, each exactly once.

    Arguments:
        indices (sequence): the indices to draw from; must support ``len()``
            and integer indexing.
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        # The shuffled order is fixed eagerly, when the iterator is created;
        # the elements themselves are looked up lazily.
        order = np.random.permutation(len(self.indices))
        return (self.indices[position] for position in order)

    def __len__(self):
        return len(self.indices)
141+
142+
class WeightedRandomSampler(Sampler):
    r"""Samples indices from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Sampling is always done with replacement: every one of the
    ``num_samples`` yielded indices is drawn independently according to the
    (normalised) weights.

    Args:
        weights (sequence) : a sequence of weights, not necessary summing up to one
        num_samples (int): number of samples to draw

    Raises:
        ValueError: if ``num_samples`` is not a positive (non-bool) integer.

    Example:
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5))  # random
        [4, 4, 1, 4, 5]
    """

    def __init__(self, weights, num_samples):
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        self.weights = tuple(weights)
        self.num_samples = num_samples

    def __iter__(self):
        # ``multinomial`` returns how many times each index was drawn
        # (one count per weight), not the drawn indices themselves.
        counts = multinomial(self.num_samples, self.weights)
        # Expand the counts into the drawn indices: index i repeated counts[i] times.
        drawn = np.repeat(np.arange(len(counts)), counts)
        # Randomise the order so the sequence is not sorted by index.
        return iter(np.random.permutation(drawn).tolist())

    def __len__(self):
        return self.num_samples
171+
172+
class BatchSampler(Sampler):
    r"""Groups the indices produced by another sampler into mini-batches (lists).

    Args:
        sampler (Sampler): Base sampler whose output is batched.
        batch_size (int): Number of indices per mini-batch.
        drop_last (bool): If ``True``, a trailing batch with fewer than
            ``batch_size`` indices is discarded instead of being yielded.

    Raises:
        ValueError: if ``sampler`` is not a :class:`Sampler`, ``batch_size``
            is not a positive (non-bool) integer, or ``drop_last`` is not a bool.

    Example:
        >>> data = dict.fromkeys(range(10))
        >>> list(BatchSampler(SequentialSampler(data), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(data), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))

        # bool is an int subclass, so it has to be excluded explicitly.
        size_ok = (isinstance(batch_size, int)
                   and not isinstance(batch_size, bool)
                   and batch_size > 0)
        if not size_ok:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))

        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))

        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        pending = []
        for index in self.sampler:
            pending.append(index)
            if len(pending) != self.batch_size:
                continue
            yield pending
            pending = []
        # Whatever is left over forms a short final batch.
        if pending and not self.drop_last:
            yield pending

    def __len__(self):
        full_batches, leftover = divmod(len(self.sampler), self.batch_size)
        if self.drop_last or not leftover:
            return full_batches
        return full_batches + 1
0 commit comments