Commit 255e450

Merge pull request #5313 from FlorentinD/gs-sample-batch-per-iteration
GraphSage sample batch per iteration (drastically shorten runtime)
2 parents 2c584e7 + f2c47a9 · commit 255e450
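The change in a nutshell: training no longer runs every batch in every iteration. The batch tasks are still prepared once up front, but each iteration now draws a random sample of them (sized by the new batchesPerIteration value), which is what drastically shortens the runtime. Below is a minimal, self-contained sketch of that sampling pattern; the names are illustrative stand-ins, not the GDS types.

    import java.util.List;
    import java.util.Random;
    import java.util.function.Supplier;
    import java.util.stream.Collectors;
    import java.util.stream.IntStream;

    public class BatchSamplingSketch {
        public static void main(String[] args) {
            // Stand-ins for the precomputed batch tasks.
            List<String> allBatches = List.of("batch-0", "batch-1", "batch-2", "batch-3", "batch-4");
            int batchesPerIteration = 2;     // derived from the sampling ratio in the real config
            Random random = new Random(42L); // seeded, so the sampled batches are reproducible

            // Sampling with replacement, mirroring the Supplier introduced in this commit.
            Supplier<List<String>> batchTaskSampler = () -> IntStream.range(0, batchesPerIteration)
                .mapToObj(i -> allBatches.get(random.nextInt(allBatches.size())))
                .collect(Collectors.toList());

            for (int iteration = 1; iteration <= 3; iteration++) {
                System.out.println("iteration " + iteration + " trains on " + batchTaskSampler.get());
            }
        }
    }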

6 files changed, +139 −81 lines changed

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer.java

+43 −49
@@ -57,7 +57,9 @@
 import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Function;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import java.util.stream.LongStream;

 import static org.neo4j.gds.embeddings.graphsage.GraphSageHelper.embeddingsComputationGraph;
@@ -68,19 +70,12 @@
 public class GraphSageModelTrainer {
     private final long randomSeed;
     private final boolean useWeights;
-    private final double learningRate;
-    private final double tolerance;
-    private final int negativeSampleWeight;
-    private final int concurrency;
-    private final int epochs;
-    private final int maxIterations;
-    private final int maxSearchDepth;
     private final Function<Graph, List<LayerConfig>> layerConfigsFunction;
     private final FeatureFunction featureFunction;
     private final Collection<Weights<Matrix>> labelProjectionWeights;
     private final ExecutorService executor;
     private final ProgressTracker progressTracker;
-    private final int batchSize;
+    private final GraphSageTrainConfig config;

     public GraphSageModelTrainer(GraphSageTrainConfig config, ExecutorService executor, ProgressTracker progressTracker) {
         this(config, executor, progressTracker, new SingleLabelFeatureFunction(), Collections.emptyList());
@@ -94,14 +89,7 @@ public GraphSageModelTrainer(
         Collection<Weights<Matrix>> labelProjectionWeights
     ) {
         this.layerConfigsFunction = graph -> config.layerConfigs(firstLayerColumns(config, graph));
-        this.batchSize = config.batchSize();
-        this.learningRate = config.learningRate();
-        this.tolerance = config.tolerance();
-        this.negativeSampleWeight = config.negativeSampleWeight();
-        this.concurrency = config.concurrency();
-        this.epochs = config.epochs();
-        this.maxIterations = config.maxIterations();
-        this.maxSearchDepth = config.searchDepth();
+        this.config = config;
         this.featureFunction = featureFunction;
         this.labelProjectionWeights = labelProjectionWeights;
         this.executor = executor;
@@ -139,21 +127,29 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {

         var batchTasks = PartitionUtils.rangePartitionWithBatchSize(
             graph.nodeCount(),
-            batchSize,
+            config.batchSize(),
             batch -> createBatchTask(graph, features, layers, weights, batch)
         );
+        var random = new Random(randomSeed);
+        Supplier<List<BatchTask>> batchTaskSampler = () -> IntStream.range(0, config.batchesPerIteration(graph.nodeCount()))
+            .mapToObj(__ -> batchTasks.get(random.nextInt(batchTasks.size())))
+            .collect(Collectors.toList());

         progressTracker.endSubTask("Prepare batches");

+        progressTracker.beginSubTask("Train model");
+
         boolean converged = false;
         var iterationLossesPerEpoch = new ArrayList<List<Double>>();
-
-        progressTracker.beginSubTask("Train model");
+        var prevEpochLoss = Double.NaN;
+        int epochs = config.epochs();

         for (int epoch = 1; epoch <= epochs && !converged; epoch++) {
             progressTracker.beginSubTask("Epoch");
-            var epochResult = trainEpoch(batchTasks, weights);
-            iterationLossesPerEpoch.add(epochResult.losses());
+            var epochResult = trainEpoch(batchTaskSampler, weights, prevEpochLoss);
+            List<Double> epochLosses = epochResult.losses();
+            iterationLossesPerEpoch.add(epochLosses);
+            prevEpochLoss = epochLosses.get(epochLosses.size() - 1);
             converged = epochResult.converged();
             progressTracker.endSubTask("Epoch");
         }
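Note on the loop above: prevEpochLoss starts as Double.NaN and is updated to the last iteration loss of each finished epoch, so it seeds the convergence check of the next epoch's first iteration. Because every comparison involving NaN is false in Java, the very first iteration of training can never report convergence. A tiny illustration (the loss value is made up):

    public class NaNToleranceCheck {
        public static void main(String[] args) {
            double prevEpochLoss = Double.NaN; // value before the first epoch
            double firstIterationLoss = 78.30; // illustrative, not from a real run
            double tolerance = 1e-10;

            // Math.abs(NaN - x) is NaN, and NaN < tolerance is false.
            boolean converged = Math.abs(prevEpochLoss - firstIterationLoss) < tolerance;
            System.out.println(converged);     // prints: false
        }
    }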
@@ -188,43 +184,52 @@ private BatchTask createBatchTask(
             useWeights ? localGraph::relationshipProperty : UNWEIGHTED,
             embeddingVariable,
             totalBatch,
-            negativeSampleWeight
+            config.negativeSampleWeight()
         );

-        return new BatchTask(lossFunction, weights, tolerance, progressTracker);
+        return new BatchTask(lossFunction, weights, progressTracker);
     }

-    private EpochResult trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Tensor<?>>> weights) {
-        var updater = new AdamOptimizer(weights, learningRate);
+    private EpochResult trainEpoch(
+        Supplier<List<BatchTask>> sampledBatchTaskSupplier,
+        List<Weights<? extends Tensor<?>>> weights,
+        double prevEpochLoss
+    ) {
+        var updater = new AdamOptimizer(weights, config.learningRate());

         int iteration = 1;
         var iterationLosses = new ArrayList<Double>();
+        double prevLoss = prevEpochLoss;
         var converged = false;

-        for (;iteration <= maxIterations; iteration++) {
+        int maxIterations = config.maxIterations();
+        for (; iteration <= maxIterations; iteration++) {
             progressTracker.beginSubTask("Iteration");

+            var sampledBatchTasks = sampledBatchTaskSupplier.get();
+
             // run forward + maybe backward for each Batch
-            ParallelUtil.runWithConcurrency(concurrency, batchTasks, executor);
-            var avgLoss = batchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
+            ParallelUtil.runWithConcurrency(config.concurrency(), sampledBatchTasks, executor);
+            var avgLoss = sampledBatchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
             iterationLosses.add(avgLoss);
+            progressTracker.logMessage(formatWithLocale("LOSS: %.10f", avgLoss));

-            converged = batchTasks.stream().allMatch(task -> task.converged);
-            if (converged) {
-                progressTracker.endSubTask();
+            if (Math.abs(prevLoss - avgLoss) < config.tolerance()) {
+                converged = true;
+                progressTracker.endSubTask("Iteration");
                 break;
             }

-            var batchedGradients = batchTasks
+            prevLoss = avgLoss;
+
+            var batchedGradients = sampledBatchTasks
                 .stream()
                 .map(BatchTask::weightGradients)
                 .collect(Collectors.toList());

             var meanGradients = averageTensors(batchedGradients);

             updater.update(meanGradients);
-
-            progressTracker.logMessage(formatWithLocale("LOSS: %.10f", avgLoss));
             progressTracker.endSubTask("Iteration");
         }
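Convergence is now decided once per iteration from the average loss over the sampled batch tasks compared against the previous iteration's average, instead of requiring every BatchTask to converge on its own. A small sketch of that check with made-up loss values:

    import java.util.List;

    public class IterationConvergenceSketch {
        public static void main(String[] args) {
            List<Double> sampledBatchLosses = List.of(74.1, 74.9, 74.5); // made-up per-batch losses
            double prevLoss = 74.5002;                                   // previous iteration's average
            double tolerance = 1e-3;

            double avgLoss = sampledBatchLosses.stream()
                .mapToDouble(Double::doubleValue)
                .average()
                .orElseThrow();
            System.out.printf("avgLoss = %.4f%n", avgLoss);              // 74.5000

            // |74.5002 - 74.5| = 2e-4, which is below the tolerance of 1e-3.
            boolean converged = Math.abs(prevLoss - avgLoss) < tolerance;
            System.out.println(converged);                               // prints: true
        }
    }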

@@ -243,34 +248,23 @@ static class BatchTask implements Runnable {
         private final Variable<Scalar> lossFunction;
         private final List<Weights<? extends Tensor<?>>> weightVariables;
         private List<? extends Tensor<?>> weightGradients;
-        private final double tolerance;
         private final ProgressTracker progressTracker;
-        private boolean converged;
-        private double prevLoss;
+        private double loss;

         BatchTask(
             Variable<Scalar> lossFunction,
             List<Weights<? extends Tensor<?>>> weightVariables,
-            double tolerance,
             ProgressTracker progressTracker
         ) {
             this.lossFunction = lossFunction;
             this.weightVariables = weightVariables;
-            this.tolerance = tolerance;
             this.progressTracker = progressTracker;
         }

         @Override
         public void run() {
-            if(converged) { // Don't try to go further
-                return;
-            }
-
             var localCtx = new ComputationContext();
-            var loss = localCtx.forward(lossFunction).value();
-
-            converged = Math.abs(prevLoss - loss) < tolerance;
-            prevLoss = loss;
+            loss = localCtx.forward(lossFunction).value();

             localCtx.backward(lossFunction);
             weightGradients = weightVariables.stream().map(localCtx::gradient).collect(Collectors.toList());
@@ -279,7 +273,7 @@ public void run() {
         }

         public double loss() {
-            return prevLoss;
+            return loss;
         }

         List<? extends Tensor<?>> weightGradients() {
@@ -312,7 +306,7 @@ LongStream neighborBatch(Graph graph, Partition batch, long batchLocalSeed) {
         // sample a neighbor for each batchNode
         batch.consume(nodeId -> {
             // randomWalk with at most maxSearchDepth steps and only save last node
-            int searchDepth = localRandom.nextInt(maxSearchDepth) + 1;
+            int searchDepth = localRandom.nextInt(config.searchDepth()) + 1;
             AtomicLong currentNode = new AtomicLong(nodeId);
             while (searchDepth > 0) {
                 NeighborhoodSampler neighborhoodSampler = new NeighborhoodSampler(currentNode.get() + searchDepth);

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainConfig.java

+12 −0
@@ -120,6 +120,18 @@ default int maxIterations() {
         return 10;
     }

+    @Configuration.Key("batchSamplingRatio")
+    @Configuration.DoubleRange(min = 0, max = 1, minInclusive = false)
+    Optional<Double> maybeBatchSamplingRatio();
+
+    @Configuration.Ignore
+    @Value.Derived
+    default int batchesPerIteration(long nodeCount) {
+        var samplingRatio = maybeBatchSamplingRatio().orElse(Math.min(1.0, batchSize() * concurrency() / (double) nodeCount));
+        var totalNumberOfBatches = Math.ceil(nodeCount / (double) batchSize());
+        return (int) Math.ceil(samplingRatio * totalNumberOfBatches);
+    }
+
     @Value.Default
     default int searchDepth() {
         return 5;
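To make the derived batchesPerIteration value concrete, here is a worked example of the formula above with assumed numbers (not taken from the PR): with 1,600 nodes, a batch size of 100 and concurrency 4, the default sampling ratio is min(1, 100 * 4 / 1600) = 0.25, there are ceil(1600 / 100) = 16 batches in total, so each iteration samples ceil(0.25 * 16) = 4 of them.

    public class BatchesPerIterationExample {
        public static void main(String[] args) {
            long nodeCount = 1_600;
            int batchSize = 100;
            int concurrency = 4;

            // Default when batchSamplingRatio is not configured: min(1, batchSize * concurrency / nodeCount).
            double samplingRatio = Math.min(1.0, batchSize * concurrency / (double) nodeCount); // 0.25
            double totalNumberOfBatches = Math.ceil(nodeCount / (double) batchSize);            // 16.0
            int batchesPerIteration = (int) Math.ceil(samplingRatio * totalNumberOfBatches);    // 4

            System.out.println(batchesPerIteration); // prints: 4
        }
    }

A batchSamplingRatio of 1.0 makes each iteration draw as many batch tasks as there are batches (still sampled with replacement), which the new parameterized test further down exercises.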

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java

+57 −26
@@ -26,6 +26,7 @@
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
 import org.junit.jupiter.params.provider.ValueSource;
 import org.neo4j.gds.Orientation;
 import org.neo4j.gds.api.Graph;
@@ -34,7 +35,7 @@
 import org.neo4j.gds.core.utils.partition.PartitionUtils;
 import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
 import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
-import org.neo4j.gds.embeddings.graphsage.algo.ImmutableGraphSageTrainConfig;
+import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfigImpl;
 import org.neo4j.gds.extension.GdlExtension;
 import org.neo4j.gds.extension.GdlGraph;
 import org.neo4j.gds.extension.Inject;
@@ -77,7 +78,7 @@ class GraphSageModelTrainerTest {
     @Inject
     private Graph arrayGraph;
     private HugeObjectArray<double[]> features;
-    private ImmutableGraphSageTrainConfig.Builder configBuilder;
+    private GraphSageTrainConfigImpl.Builder configBuilder;


     @BeforeEach
@@ -87,7 +88,8 @@ void setUp() {

         Random random = new Random(19L);
         LongStream.range(0, nodeCount).forEach(n -> features.set(n, random.doubles(FEATURES_COUNT).toArray()));
-        configBuilder = ImmutableGraphSageTrainConfig.builder()
+        configBuilder = GraphSageTrainConfigImpl.builder()
+            .username("DUMMY")
             .featureProperties(Collections.nCopies(FEATURES_COUNT, "dummyProp"))
             .embeddingDimension(EMBEDDING_DIMENSION);
     }
@@ -202,7 +204,7 @@ void testLosses() {
             .embeddingDimension(12)
             .epochs(10)
             .tolerance(1e-10)
-            .addSampleSizes(5, 3)
+            .sampleSizes(List.of(5, 3))
             .batchSize(5)
             .maxIterations(100)
             .randomSeed(42L)
@@ -228,17 +230,17 @@ void testLosses() {
         assertThat(epochLosses).isInstanceOf(List.class);
         assertThat(((List<Double>) epochLosses).stream().mapToDouble(Double::doubleValue).toArray())
             .contains(new double[]{
-                    91.33327272,
-                    88.17940500,
-                    87.68340477,
-                    85.60797746,
-                    85.59108701,
-                    85.59007234,
-                    81.44403525,
-                    81.44260858,
-                    81.44349342,
-                    81.45612978
-                }, Offset.offset(1e-8)
+                    78.30,
+                    71.55,
+                    71.07,
+                    71.65,
+                    74.36,
+                    74.08,
+                    73.98,
+                    80.28,
+                    71.07,
+                    71.07
+                }, Offset.offset(0.05)
             );
     }

@@ -250,7 +252,7 @@ void testLossesWithPoolAggregator() {
             .aggregator(AggregatorType.POOL)
             .epochs(10)
             .tolerance(1e-10)
-            .addSampleSizes(5, 3)
+            .sampleSizes(List.of(5, 3))
             .batchSize(5)
             .maxIterations(100)
             .randomSeed(42L)
@@ -276,16 +278,16 @@ void testLossesWithPoolAggregator() {
         assertThat(epochLosses).isInstanceOf(List.class);
         assertThat(((List<Double>) epochLosses).stream().mapToDouble(Double::doubleValue).toArray())
             .contains(new double[]{
-                    90.53,
-                    83.29,
-                    74.75,
-                    74.61,
-                    74.68,
-                    74.54,
-                    74.46,
-                    74.47,
-                    74.41,
-                    74.41
+                    87.34,
+                    80.75,
+                    74.07,
+                    93.12,
+                    96.36,
+                    80.50,
+                    77.31,
+                    99.70,
+                    83.60,
+                    83.60
                 }, Offset.offset(0.05)
             );
     }
@@ -306,6 +308,35 @@ void testConvergence() {
         assertThat(trainMetrics.ranIterationsPerEpoch()).containsExactly(2);
     }

+    @ParameterizedTest
+    @CsvSource({
+        "0.01, true, 8",
+        "1.0, false, 10"
+    })
+    void batchesPerIteration(double batchSamplingRatio, boolean expectedConvergence, int expectedRanEpochs) {
+        var trainer = new GraphSageModelTrainer(
+            configBuilder.modelName("convergingModel:)")
+                .maybeBatchSamplingRatio(batchSamplingRatio)
+                .embeddingDimension(12)
+                .aggregator(AggregatorType.POOL)
+                .epochs(10)
+                .tolerance(1e-10)
+                .sampleSizes(List.of(5, 3))
+                .batchSize(5)
+                .maxIterations(100)
+                .randomSeed(42L)
+                .build(),
+            Pools.DEFAULT,
+            ProgressTracker.NULL_TRACKER
+        );
+
+        var trainResult = trainer.train(graph, features);
+
+        var trainMetrics = trainResult.metrics();
+        assertThat(trainMetrics.didConverge()).isEqualTo(expectedConvergence);
+        assertThat(trainMetrics.ranEpochs()).isEqualTo(expectedRanEpochs);
+    }
+
     @ParameterizedTest
     @ValueSource(longs = {20L, -100L, 30L})
     void seededSingleBatch(long seed) {
