@@ -26,9 +26,8 @@ limitations under the License.
2626namespace NeoML {
2727
2828// Shuffles the elements array.
29- static void shuffle ( CArray<int >& elements, unsigned seed )
29+ static void shuffle ( CArray<int >& elements, CRandom& random )
3030{
31- CRandom random ( seed );
3231 CShuffler indexGenerator ( random, elements.Size () );
3332 CArray<int > oldElements;
3433 elements.CopyTo ( oldElements );
@@ -42,41 +41,26 @@ static void shuffle( CArray<int>& elements, unsigned seed )
4241// Shuffles the elements of an array and returns them one by one.
4342// Unlike CShuffler, it returns the elements of the array, not the indices of the elements,
4443// and also supports cyclicity (when the end of the sequence is reached, it will be shuffled again).
45- class CShuffledGenerator final {
44+ template <typename T>
45+ class CShuffledElements final {
4646public:
47- CShuffledGenerator ( const CArray<int >& _elements, unsigned seed );
48-
49- CRandom& Random () { return random; }
47+ CShuffledElements ( CArray<int >&& _elements, CRandom& _random ) :
48+ random ( _random ), elements( std::move( _elements ) ), shuffler( random, elements.Size() ) {}
49+ CShuffledElements ( CShuffledElements&& ) = default ;
50+
51+ CShuffledElements& operator =( CShuffledElements&& ) = default ;
52+
5053 // The number of elements in the elements array.
5154 int Size () const { return elements.Size (); }
5255 // Generates the next element of the sequence.
53- int Next ();
56+ T Next () { if ( shuffler. IsFinished () ) { shuffler. Reset (); } return elements[shuffler. Next ()]; }
5457
5558private:
56- CRandom random;
57- CArray<int > elements; // The shuffled elements array.
58- int position = 0 ; // The position in the elements array.
59+ CRandom& random; // The random generator
60+ CArray<T > elements; // The shuffled elements array
61+ CShuffler shuffler ; // The index shuffled generator
5962};
6063
61- CShuffledGenerator::CShuffledGenerator ( const CArray<int >& _elements, unsigned seed ) :
62- random ( seed )
63- {
64- _elements.CopyTo ( elements );
65- shuffle ( elements, seed );
66- }
67-
68- int CShuffledGenerator::Next ()
69- {
70- const int result = elements[position++];
71- NeoAssert ( position <= Size () );
72- // reached the end - mix the elements.
73- if ( position == Size () ) {
74- shuffle ( elements, random.Next () );
75- position = 0 ;
76- }
77- return result;
78- }
79-
8064// ---------------------------------------------------------------------------------------------------------------------
8165
8266// CBalancedPairBatchGenerator
@@ -100,15 +84,28 @@ class CBalancedPairBatchGenerator : public IShuffledBatchGenerator {
10084 void DeleteUnseenElement ( int index ) override { unseenElementsIndices.Delete ( index ); }
10185
10286private:
87+ CRandom random;
10388 // The dictionary of "label -> generator of elements indexes with this label".
104- CMap<int , CPtrOwner<CShuffledGenerator >> labelToIndexGenerator;
89+ CMap<int , CShuffledElements< int >> labelToIndexGenerator;
10590 // The labels generator
106- CPtrOwner<CShuffledGenerator > labelGenerator;
91+ CShuffledElements< int > labelGenerator;
10792 // Indexes of elements, which aren't in any batch, to detect epoch's end
10893 CHashTable<int > unseenElementsIndices;
94+
95+ CArray<int > getLabels ( const IProblem& );
10996};
11097
111- CBalancedPairBatchGenerator::CBalancedPairBatchGenerator ( const IProblem& problem, unsigned seed )
98+ CBalancedPairBatchGenerator::CBalancedPairBatchGenerator ( const IProblem& problem, unsigned seed ) :
99+ random ( seed ),
100+ labelGenerator ( getLabels( problem ), random )
101+ {
102+ unseenElementsIndices.SetBufferSize ( problem.GetVectorCount () );
103+ for ( int i = 0 ; i < problem.GetVectorCount (); ++i ) {
104+ unseenElementsIndices.Add ( i );
105+ }
106+ }
107+
108+ CArray<int > CBalancedPairBatchGenerator::getLabels ( const IProblem& problem )
112109{
113110 CMap<int , CArray<int >> labelToIndexes;
114111 for ( int i = 0 ; i < problem.GetVectorCount (); ++i ) {
@@ -118,16 +115,10 @@ CBalancedPairBatchGenerator::CBalancedPairBatchGenerator( const IProblem& proble
118115 CArray<int > labelsUnique;
119116 labelsUnique.SetBufferSize ( labelToIndexes.Size () );
120117 for ( auto & item : labelToIndexes ) {
121- labelToIndexGenerator.CreateNewValue ( item.Key ) = new CShuffledGenerator ( item.Value , seed );
118+ labelToIndexGenerator.Add ( item.Key , CShuffledElements< int >( std::move ( item.Value ), random ) );
122119 labelsUnique.Add ( item.Key );
123120 }
124-
125- labelGenerator = new CShuffledGenerator ( labelsUnique, seed );
126-
127- unseenElementsIndices.SetBufferSize ( problem.GetVectorCount () );
128- for ( int i = 0 ; i < problem.GetVectorCount (); ++i ) {
129- unseenElementsIndices.Add ( i );
130- }
121+ return labelsUnique;
131122}
132123
133124CArray<int > CBalancedPairBatchGenerator::GenerateBatchIndexes ( int batchSize )
@@ -143,12 +134,12 @@ CArray<int> CBalancedPairBatchGenerator::GenerateBatchIndexes( int batchSize )
143134 const int numOfSingleClassSamples = static_cast <int >( round ( idealNumOfSingleClassSamples ) );
144135
145136 // Sample a random class and collect numOfSingleClassSamples elements from it into a batch.
146- const int majorLabel = labelGenerator-> Next ();
147- NeoAssert ( numOfSingleClassSamples <= labelToIndexGenerator.Get ( majorLabel )-> Size () );
137+ const int majorLabel = labelGenerator. Next ();
138+ NeoAssert ( numOfSingleClassSamples <= labelToIndexGenerator.Get ( majorLabel ). Size () );
148139 // Collect numOfSingleClassSamples elements of the sampled major class into the batch.
149140 CArray<int > batchIndexes;
150141 for ( int i = 0 ; i < numOfSingleClassSamples; ++i ) {
151- batchIndexes.Add ( labelToIndexGenerator.Get ( majorLabel )-> Next () );
142+ batchIndexes.Add ( labelToIndexGenerator.Get ( majorLabel ). Next () );
152143 }
153144
154145 // The number of remaining elements, also the number of remaining classes, because for the remaining elements
@@ -164,18 +155,18 @@ CArray<int> CBalancedPairBatchGenerator::GenerateBatchIndexes( int batchSize )
164155 // Will never loop if numOfOtherClasses + 1 <= totalNumOfClasses, because the generator shuffles classes without duplicates.
165156 while ( usedClasses.Size () < numOfOtherClasses + 1 ) {
166157 // Sample the class.
167- const int label = labelGenerator-> Next ();
158+ const int label = labelGenerator. Next ();
168159 // Skip the class if it has already been sampled.
169160 if ( usedClasses.Has ( label ) ) {
170161 continue ;
171162 }
172163 usedClasses.Add ( label );
173164 // Sample an element from this class.
174- const int negativeSample = labelToIndexGenerator.Get ( label )-> Next ();
165+ const int negativeSample = labelToIndexGenerator.Get ( label ). Next ();
175166 batchIndexes.Add ( negativeSample );
176167 }
177168 // Just in case, we mix the batch before giving it to the model so that it doesn’t accidentally learn the structure of the batch.
178- shuffle ( batchIndexes, labelGenerator-> Random (). Next () );
169+ shuffle ( batchIndexes, random );
179170
180171 NeoAssert ( batchIndexes.Size () == batchSize );
181172 return batchIndexes;
0 commit comments