Skip to content

Commit bbcb3c6

Browse files
committed
[NeoML] CProblemSourceLayer uses new CShuffler
Signed-off-by: Kirill Golikov <kirill.golikov@abbyy.com>
1 parent 52d234f commit bbcb3c6

1 file changed

Lines changed: 37 additions & 46 deletions

File tree

NeoML/src/Dnn/Layers/ModelWrapperLayer.cpp

Lines changed: 37 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -26,9 +26,8 @@ limitations under the License.
2626
namespace NeoML {
2727

2828
// Shuffles the elements array.
29-
static void shuffle( CArray<int>& elements, unsigned seed )
29+
static void shuffle( CArray<int>& elements, CRandom& random )
3030
{
31-
CRandom random( seed );
3231
CShuffler indexGenerator( random, elements.Size() );
3332
CArray<int> oldElements;
3433
elements.CopyTo( oldElements );
@@ -42,41 +41,26 @@ static void shuffle( CArray<int>& elements, unsigned seed )
4241
// Shuffles the elements of an array and returns them one by one.
4342
// Unlike CShuffler, it returns the elements of the array, not the indices of the elements,
4443
// and also supports cyclicity (when the end of the sequence is reached, it will be shuffled again).
45-
class CShuffledGenerator final {
44+
template <typename T>
45+
class CShuffledElements final {
4646
public:
47-
CShuffledGenerator( const CArray<int>& _elements, unsigned seed );
48-
49-
CRandom& Random() { return random; }
47+
CShuffledElements( CArray<int>&& _elements, CRandom& _random ) :
48+
random( _random ), elements( std::move( _elements ) ), shuffler( random, elements.Size() ) {}
49+
CShuffledElements( CShuffledElements&& ) = default;
50+
51+
CShuffledElements& operator=( CShuffledElements&& ) = default;
52+
5053
// The number of elements in the elements array.
5154
int Size() const { return elements.Size(); }
5255
// Generates the next element of the sequence.
53-
int Next();
56+
T Next() { if( shuffler.IsFinished() ) { shuffler.Reset(); } return elements[shuffler.Next()]; }
5457

5558
private:
56-
CRandom random;
57-
CArray<int> elements; // The shuffled elements array.
58-
int position = 0; // The position in the elements array.
59+
CRandom& random; // The random generator
60+
CArray<T> elements; // The shuffled elements array
61+
CShuffler shuffler; // The index shuffled generator
5962
};
6063

61-
CShuffledGenerator::CShuffledGenerator( const CArray<int>& _elements, unsigned seed ) :
62-
random( seed )
63-
{
64-
_elements.CopyTo( elements );
65-
shuffle( elements, seed );
66-
}
67-
68-
int CShuffledGenerator::Next()
69-
{
70-
const int result = elements[position++];
71-
NeoAssert( position <= Size() );
72-
// reached the end - mix the elements.
73-
if( position == Size() ) {
74-
shuffle( elements, random.Next() );
75-
position = 0;
76-
}
77-
return result;
78-
}
79-
8064
//---------------------------------------------------------------------------------------------------------------------
8165

8266
// CBalancedPairBatchGenerator
@@ -100,15 +84,28 @@ class CBalancedPairBatchGenerator : public IShuffledBatchGenerator {
10084
void DeleteUnseenElement( int index ) override { unseenElementsIndices.Delete( index ); }
10185

10286
private:
87+
CRandom random;
10388
// The dictionary of "label -> generator of elements indexes with this label".
104-
CMap<int, CPtrOwner<CShuffledGenerator>> labelToIndexGenerator;
89+
CMap<int, CShuffledElements<int>> labelToIndexGenerator;
10590
// The labels generator
106-
CPtrOwner<CShuffledGenerator> labelGenerator;
91+
CShuffledElements<int> labelGenerator;
10792
// Indexes of elements, which aren't in any batch, to detect epoch's end
10893
CHashTable<int> unseenElementsIndices;
94+
95+
CArray<int> getLabels( const IProblem& );
10996
};
11097

111-
CBalancedPairBatchGenerator::CBalancedPairBatchGenerator( const IProblem& problem, unsigned seed )
98+
CBalancedPairBatchGenerator::CBalancedPairBatchGenerator( const IProblem& problem, unsigned seed ) :
99+
random( seed ),
100+
labelGenerator( getLabels( problem ), random )
101+
{
102+
unseenElementsIndices.SetBufferSize( problem.GetVectorCount() );
103+
for( int i = 0; i < problem.GetVectorCount(); ++i ) {
104+
unseenElementsIndices.Add( i );
105+
}
106+
}
107+
108+
CArray<int> CBalancedPairBatchGenerator::getLabels( const IProblem& problem )
112109
{
113110
CMap<int, CArray<int>> labelToIndexes;
114111
for( int i = 0; i < problem.GetVectorCount(); ++i ) {
@@ -118,16 +115,10 @@ CBalancedPairBatchGenerator::CBalancedPairBatchGenerator( const IProblem& proble
118115
CArray<int> labelsUnique;
119116
labelsUnique.SetBufferSize( labelToIndexes.Size() );
120117
for( auto& item : labelToIndexes ) {
121-
labelToIndexGenerator.CreateNewValue( item.Key ) = new CShuffledGenerator( item.Value, seed );
118+
labelToIndexGenerator.Add( item.Key, CShuffledElements<int>( std::move( item.Value ), random ) );
122119
labelsUnique.Add( item.Key );
123120
}
124-
125-
labelGenerator = new CShuffledGenerator( labelsUnique, seed );
126-
127-
unseenElementsIndices.SetBufferSize( problem.GetVectorCount() );
128-
for( int i = 0; i < problem.GetVectorCount(); ++i ) {
129-
unseenElementsIndices.Add( i );
130-
}
121+
return labelsUnique;
131122
}
132123

133124
CArray<int> CBalancedPairBatchGenerator::GenerateBatchIndexes( int batchSize )
@@ -143,12 +134,12 @@ CArray<int> CBalancedPairBatchGenerator::GenerateBatchIndexes( int batchSize )
143134
const int numOfSingleClassSamples = static_cast<int>( round( idealNumOfSingleClassSamples ) );
144135

145136
// Sample a random class and collect numOfSingleClassSamples elements from it into a batch.
146-
const int majorLabel = labelGenerator->Next();
147-
NeoAssert( numOfSingleClassSamples <= labelToIndexGenerator.Get( majorLabel )->Size() );
137+
const int majorLabel = labelGenerator.Next();
138+
NeoAssert( numOfSingleClassSamples <= labelToIndexGenerator.Get( majorLabel ).Size() );
148139
// Collect numOfSingleClassSamples elements of the sampled class into the batch.
149140
CArray<int> batchIndexes;
150141
for( int i = 0; i < numOfSingleClassSamples; ++i ) {
151-
batchIndexes.Add( labelToIndexGenerator.Get( majorLabel )->Next() );
142+
batchIndexes.Add( labelToIndexGenerator.Get( majorLabel ).Next() );
152143
}
153144

154145
// The number of remaining elements, also the number of remaining classes, because for the remaining elements
@@ -164,18 +155,18 @@ CArray<int> CBalancedPairBatchGenerator::GenerateBatchIndexes( int batchSize )
164155
// Will never loop forever if numOfOtherClasses + 1 <= totalNumOfClasses, because the generator shuffles classes without duplicates.
165156
while( usedClasses.Size() < numOfOtherClasses + 1 ) {
166157
// Sample the class.
167-
const int label = labelGenerator->Next();
158+
const int label = labelGenerator.Next();
168159
// Skip the class if it has already been sampled.
169160
if( usedClasses.Has( label ) ) {
170161
continue;
171162
}
172163
usedClasses.Add( label );
173164
// Sample an element from this class.
174-
const int negativeSample = labelToIndexGenerator.Get( label )->Next();
165+
const int negativeSample = labelToIndexGenerator.Get( label ).Next();
175166
batchIndexes.Add( negativeSample );
176167
}
177168
// Just in case, we mix the batch before giving it to the model so that it doesn’t accidentally learn the structure of the batch.
178-
shuffle( batchIndexes, labelGenerator->Random().Next() );
169+
shuffle( batchIndexes, random );
179170

180171
NeoAssert( batchIndexes.Size() == batchSize );
181172
return batchIndexes;

0 commit comments

Comments
 (0)