import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
+ import java.util.function.Supplier;
import java.util.stream.Collectors;
+ import java.util.stream.IntStream;
import java.util.stream.LongStream;

import static org.neo4j.gds.embeddings.graphsage.GraphSageHelper.embeddingsComputationGraph;

public class GraphSageModelTrainer {
    private final long randomSeed;
    private final boolean useWeights;
-   private final double learningRate;
-   private final double tolerance;
-   private final int negativeSampleWeight;
-   private final int concurrency;
-   private final int epochs;
-   private final int maxIterations;
-   private final int maxSearchDepth;
    private final Function<Graph, List<LayerConfig>> layerConfigsFunction;
    private final FeatureFunction featureFunction;
    private final Collection<Weights<Matrix>> labelProjectionWeights;
    private final ExecutorService executor;
    private final ProgressTracker progressTracker;
-   private final int batchSize;
+   private final GraphSageTrainConfig config;

    public GraphSageModelTrainer(GraphSageTrainConfig config, ExecutorService executor, ProgressTracker progressTracker) {
        this(config, executor, progressTracker, new SingleLabelFeatureFunction(), Collections.emptyList());
@@ -94,14 +89,7 @@ public GraphSageModelTrainer(
        Collection<Weights<Matrix>> labelProjectionWeights
    ) {
        this.layerConfigsFunction = graph -> config.layerConfigs(firstLayerColumns(config, graph));
-       this.batchSize = config.batchSize();
-       this.learningRate = config.learningRate();
-       this.tolerance = config.tolerance();
-       this.negativeSampleWeight = config.negativeSampleWeight();
-       this.concurrency = config.concurrency();
-       this.epochs = config.epochs();
-       this.maxIterations = config.maxIterations();
-       this.maxSearchDepth = config.searchDepth();
+       this.config = config;
        this.featureFunction = featureFunction;
        this.labelProjectionWeights = labelProjectionWeights;
        this.executor = executor;
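
Note: the eight copied primitives are replaced by a single reference to the immutable GraphSageTrainConfig, so every read reflects the config by construction. A minimal sketch of the pattern, with hypothetical Config/Trainer names:

interface Config {
    double learningRate();
    int epochs();
}

class Trainer {
    private final Config config;   // one immutable reference instead of one copied field per value

    Trainer(Config config) {
        this.config = config;
    }

    void train() {
        // read each value where it is used, rather than caching a copy per field
        for (int epoch = 1; epoch <= config.epochs(); epoch++) {
            step(config.learningRate());
        }
    }

    private void step(double learningRate) {
        // gradient update would go here
    }
}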
@@ -139,21 +127,29 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {

        var batchTasks = PartitionUtils.rangePartitionWithBatchSize(
            graph.nodeCount(),
-           batchSize,
+           config.batchSize(),
            batch -> createBatchTask(graph, features, layers, weights, batch)
        );
+       var random = new Random(randomSeed);
+       Supplier<List<BatchTask>> batchTaskSampler = () -> IntStream.range(0, config.batchesPerIteration(graph.nodeCount()))
+           .mapToObj(__ -> batchTasks.get(random.nextInt(batchTasks.size())))
+           .collect(Collectors.toList());

        progressTracker.endSubTask("Prepare batches");

+       progressTracker.beginSubTask("Train model");
+
        boolean converged = false;
        var iterationLossesPerEpoch = new ArrayList<List<Double>>();
-
-       progressTracker.beginSubTask("Train model");
+       var prevEpochLoss = Double.NaN;
+       int epochs = config.epochs();

        for (int epoch = 1; epoch <= epochs && !converged; epoch++) {
            progressTracker.beginSubTask("Epoch");
-           var epochResult = trainEpoch(batchTasks, weights);
-           iterationLossesPerEpoch.add(epochResult.losses());
+           var epochResult = trainEpoch(batchTaskSampler, weights, prevEpochLoss);
+           List<Double> epochLosses = epochResult.losses();
+           iterationLossesPerEpoch.add(epochLosses);
+           prevEpochLoss = epochLosses.get(epochLosses.size() - 1);
            converged = epochResult.converged();
            progressTracker.endSubTask("Epoch");
        }
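
Note: batchTaskSampler draws a fresh subset of the pre-partitioned tasks for every iteration, uniformly and with replacement. A self-contained sketch of the same sampling scheme; the String tasks, the batchesPerIteration value, and the seed are stand-ins rather than the GDS types:

import java.util.List;
import java.util.Random;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class BatchSampling {
    public static void main(String[] args) {
        List<String> batchTasks = List.of("batch-0", "batch-1", "batch-2", "batch-3");
        int batchesPerIteration = 2;       // stand-in for config.batchesPerIteration(graph.nodeCount())
        Random random = new Random(42L);   // seeded like the trainer, so runs are reproducible

        // Each get() is a new draw *with replacement*: one iteration may
        // see the same batch twice while another batch is skipped entirely.
        Supplier<List<String>> batchTaskSampler = () -> IntStream.range(0, batchesPerIteration)
            .mapToObj(i -> batchTasks.get(random.nextInt(batchTasks.size())))
            .collect(Collectors.toList());

        System.out.println(batchTaskSampler.get());
        System.out.println(batchTaskSampler.get());
    }
}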
@@ -188,43 +184,52 @@ private BatchTask createBatchTask(
            useWeights ? localGraph::relationshipProperty : UNWEIGHTED,
            embeddingVariable,
            totalBatch,
-           negativeSampleWeight
+           config.negativeSampleWeight()
        );

-       return new BatchTask(lossFunction, weights, tolerance, progressTracker);
+       return new BatchTask(lossFunction, weights, progressTracker);
    }

-   private EpochResult trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Tensor<?>>> weights) {
-       var updater = new AdamOptimizer(weights, learningRate);
+   private EpochResult trainEpoch(
+       Supplier<List<BatchTask>> sampledBatchTaskSupplier,
+       List<Weights<? extends Tensor<?>>> weights,
+       double prevEpochLoss
+   ) {
+       var updater = new AdamOptimizer(weights, config.learningRate());

        int iteration = 1;
        var iterationLosses = new ArrayList<Double>();
+       double prevLoss = prevEpochLoss;
        var converged = false;

-       for (;iteration <= maxIterations; iteration++) {
+       int maxIterations = config.maxIterations();
+       for (; iteration <= maxIterations; iteration++) {
            progressTracker.beginSubTask("Iteration");

+           var sampledBatchTasks = sampledBatchTaskSupplier.get();
+
            // run forward + maybe backward for each Batch
-           ParallelUtil.runWithConcurrency(concurrency, batchTasks, executor);
-           var avgLoss = batchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
+           ParallelUtil.runWithConcurrency(config.concurrency(), sampledBatchTasks, executor);
+           var avgLoss = sampledBatchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
            iterationLosses.add(avgLoss);
+           progressTracker.logMessage(formatWithLocale("LOSS: %.10f", avgLoss));

-           converged = batchTasks.stream().allMatch(task -> task.converged);
-           if (converged) {
-               progressTracker.endSubTask();
+           if (Math.abs(prevLoss - avgLoss) < config.tolerance()) {
+               converged = true;
+               progressTracker.endSubTask("Iteration");
                break;
            }

-           var batchedGradients = batchTasks
+           prevLoss = avgLoss;
+
+           var batchedGradients = sampledBatchTasks
                .stream()
                .map(BatchTask::weightGradients)
                .collect(Collectors.toList());

            var meanGradients = averageTensors(batchedGradients);

            updater.update(meanGradients);
-
-           progressTracker.logMessage(formatWithLocale("LOSS: %.10f", avgLoss));
            progressTracker.endSubTask("Iteration");
        }
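
Note: convergence is now decided once per iteration on the averaged batch loss, compared against the previous iteration's loss; the first comparison of an epoch is seeded with the last loss of the previous epoch, and with Double.NaN before any epoch has run. A small sketch of that stopping rule with made-up loss values:

public class ToleranceStop {
    public static void main(String[] args) {
        double[] avgLosses = {0.90, 0.55, 0.41, 0.4095, 0.40};
        double tolerance = 1e-2;        // stand-in for config.tolerance()
        double prevLoss = Double.NaN;   // Math.abs(NaN - x) is NaN, and NaN < tolerance is false,
                                        // so the first iteration can never converge spuriously

        for (double avgLoss : avgLosses) {
            if (Math.abs(prevLoss - avgLoss) < tolerance) {
                System.out.println("converged at loss " + avgLoss);
                break;
            }
            prevLoss = avgLoss;
        }
    }
}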
@@ -243,34 +248,23 @@ static class BatchTask implements Runnable {
        private final Variable<Scalar> lossFunction;
        private final List<Weights<? extends Tensor<?>>> weightVariables;
        private List<? extends Tensor<?>> weightGradients;
-       private final double tolerance;
        private final ProgressTracker progressTracker;
-       private boolean converged;
-       private double prevLoss;
+       private double loss;

        BatchTask(
            Variable<Scalar> lossFunction,
            List<Weights<? extends Tensor<?>>> weightVariables,
-           double tolerance,
            ProgressTracker progressTracker
        ) {
            this.lossFunction = lossFunction;
            this.weightVariables = weightVariables;
-           this.tolerance = tolerance;
            this.progressTracker = progressTracker;
        }

        @Override
        public void run() {
-           if (converged) { // Don't try to go further
-               return;
-           }
-
            var localCtx = new ComputationContext();
-           var loss = localCtx.forward(lossFunction).value();
-
-           converged = Math.abs(prevLoss - loss) < tolerance;
-           prevLoss = loss;
+           loss = localCtx.forward(lossFunction).value();

            localCtx.backward(lossFunction);
            weightGradients = weightVariables.stream().map(localCtx::gradient).collect(Collectors.toList());
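
Note: with the per-task tolerance state gone, BatchTask reduces to a plain Runnable: one forward pass stores the scalar loss, one backward pass stores the gradients, and all convergence bookkeeping lives in trainEpoch. A stripped-down sketch of that shape; DoubleSupplier stands in for localCtx.forward(lossFunction).value():

import java.util.function.DoubleSupplier;

class SketchBatchTask implements Runnable {
    private final DoubleSupplier forwardPass;   // stand-in for the computation graph's forward pass
    private volatile double loss;

    SketchBatchTask(DoubleSupplier forwardPass) {
        this.forwardPass = forwardPass;
    }

    @Override
    public void run() {
        // no convergence check or early return any more: just compute and record
        loss = forwardPass.getAsDouble();
        // the backward pass and gradient collection would follow here
    }

    double loss() {
        return loss;
    }
}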
@@ -279,7 +273,7 @@ public void run() {
        }

        public double loss() {
-           return prevLoss;
+           return loss;
        }

        List<? extends Tensor<?>> weightGradients() {
@@ -312,7 +306,7 @@ LongStream neighborBatch(Graph graph, Partition batch, long batchLocalSeed) {
            // sample a neighbor for each batchNode
            batch.consume(nodeId -> {
                // randomWalk with at most maxSearchDepth steps and only save last node
-               int searchDepth = localRandom.nextInt(maxSearchDepth) + 1;
+               int searchDepth = localRandom.nextInt(config.searchDepth()) + 1;
                AtomicLong currentNode = new AtomicLong(nodeId);
                while (searchDepth > 0) {
                    NeighborhoodSampler neighborhoodSampler = new NeighborhoodSampler(currentNode.get() + searchDepth);
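
Note: nextInt(bound) is uniform over [0, bound), so the + 1 makes the sampled walk length uniform over [1, searchDepth]. A tiny sketch of the draw, where maxSearchDepth stands in for config.searchDepth():

import java.util.concurrent.ThreadLocalRandom;

public class WalkDepth {
    public static void main(String[] args) {
        int maxSearchDepth = 5;   // stand-in for config.searchDepth()
        var localRandom = ThreadLocalRandom.current();

        // uniform in [1, maxSearchDepth]: at least one step, at most maxSearchDepth
        int searchDepth = localRandom.nextInt(maxSearchDepth) + 1;
        System.out.println("walk length: " + searchDepth);
    }
}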