Skip to content

Commit 2dc59a9

Browse files
committed
adding one more test
1 parent b0b03ff commit 2dc59a9

File tree

3 files changed

+106
-7
lines changed

3 files changed

+106
-7
lines changed

fdb-extensions/src/main/java/com/apple/foundationdb/async/hnsw/InliningStorageAdapter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ private NodeReferenceWithVector neighborFromTuples(@Nonnull final AffineOperator
189189
// Transform the raw vector that was just fetched into the internal coordinate system. If we do not have
190190
// a need to transform coordinates, this transform is the identity transformation. Vectors are always stored
191191
// in the internal coordinate system in use at the time the vector is written. If that coordinate system changes
192-
// afterward, for instance RaBitQ by enabling RaBitQ, subsequent reads of vectors that were written prior to
192+
// afterward, for instance by enabling RaBitQ, subsequent reads of vectors that were written prior to
193193
// the coordinate system change need to be transformed when they are read back.
194194
//
195195
final Transformed<RealVector> neighborVector =

fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/DataRecordsTest.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ void testCompactNode(final long randomSeed) {
5656

5757
final CompactNode compactNode1 = compactNode(new Random(dependentRandomSeed));
5858
final CompactNode compactNode1Clone = compactNode(new Random(dependentRandomSeed));
59-
Assertions.assertThat(compactNode1.toString()).isEqualTo(compactNode1Clone.toString());
59+
Assertions.assertThat(compactNode1).hasToString(compactNode1Clone.toString());
6060

6161
final CompactNode compactNode2 = compactNode(random);
62-
Assertions.assertThat(compactNode1.toString()).isNotEqualTo(compactNode2.toString());
62+
Assertions.assertThat(compactNode1).doesNotHaveToString(compactNode2.toString());
6363

6464
Assertions.assertThatThrownBy(compactNode1::asInliningNode).isInstanceOf(IllegalStateException.class);
6565
}
@@ -72,10 +72,10 @@ void testInliningNode(final long randomSeed) {
7272

7373
final InliningNode inliningNode1 = inliningNode(new Random(dependentRandomSeed));
7474
final InliningNode inliningNode1Clone = inliningNode(new Random(dependentRandomSeed));
75-
Assertions.assertThat(inliningNode1.toString()).isEqualTo(inliningNode1Clone.toString());
75+
Assertions.assertThat(inliningNode1).hasToString(inliningNode1Clone.toString());
7676

7777
final InliningNode inliningNode2 = inliningNode(random);
78-
Assertions.assertThat(inliningNode1.toString()).isNotEqualTo(inliningNode2.toString());
78+
Assertions.assertThat(inliningNode1).doesNotHaveToString(inliningNode2.toString());
7979

8080
Assertions.assertThatThrownBy(inliningNode1::asCompactNode).isInstanceOf(IllegalStateException.class);
8181
}
@@ -126,11 +126,11 @@ private static <T> void assertHashCodeEqualsToString(final long randomSeed, fina
126126
final T t1Clone = createFunction.apply(new Random(dependentRandomSeed));
127127
Assertions.assertThat(t1.hashCode()).isEqualTo(t1Clone.hashCode());
128128
Assertions.assertThat(t1).isEqualTo(t1Clone);
129-
Assertions.assertThat(t1.toString()).isEqualTo(t1Clone.toString());
129+
Assertions.assertThat(t1).hasToString(t1Clone.toString());
130130

131131
final T t2 = createFunction.apply(random);
132132
Assertions.assertThat(t1).isNotEqualTo(t2);
133-
Assertions.assertThat(t1.toString()).isNotEqualTo(t2.toString());
133+
Assertions.assertThat(t1).doesNotHaveToString(t2.toString());
134134
}
135135

136136
@Nonnull

fdb-extensions/src/test/java/com/apple/foundationdb/async/hnsw/HNSWTest.java

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import com.apple.foundationdb.linear.Quantizer;
3131
import com.apple.foundationdb.linear.RealVector;
3232
import com.apple.foundationdb.linear.StoredVecsIterator;
33+
import com.apple.foundationdb.rabitq.EncodedRealVector;
3334
import com.apple.foundationdb.test.TestDatabaseExtension;
3435
import com.apple.foundationdb.test.TestExecutors;
3536
import com.apple.foundationdb.test.TestSubspaceExtension;
@@ -83,6 +84,7 @@
8384
import java.util.stream.LongStream;
8485
import java.util.stream.Stream;
8586

87+
import static com.apple.foundationdb.linear.RealVectorTest.createRandomDoubleVector;
8688
import static com.apple.foundationdb.linear.RealVectorTest.createRandomHalfVector;
8789
import static org.assertj.core.api.Assertions.within;
8890

@@ -285,6 +287,103 @@ void testBasicInsert(final long seed, final boolean useInlining, final boolean e
285287
Assertions.assertThat(readIds.size()).isBetween(10, 50);
286288
}
287289

290+
@ParameterizedTest()
291+
@RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L})
292+
void testBasicInsertWithRaBitQEncodings(final long seed) {
293+
final Random random = new Random(seed);
294+
final Metric metric = Metric.EUCLIDEAN_METRIC;
295+
296+
final AtomicLong nextNodeIdAtomic = new AtomicLong(0L);
297+
final int numDimensions = 128;
298+
final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(),
299+
HNSW.newConfigBuilder().setMetric(metric)
300+
.setUseRaBitQ(true)
301+
.setRaBitQNumExBits(5)
302+
.setSampleVectorStatsProbability(1.0d) // every vector is sampled
303+
.setMaintainStatsProbability(1.0d) // for every vector we maintain the stats
304+
.setStatsThreshold(950) // after 950 vectors we enable RaBitQ
305+
.setM(32).setMMax(32).setMMax0(64).build(numDimensions),
306+
OnWriteListener.NOOP, OnReadListener.NOOP);
307+
308+
final int k = 499;
309+
final DoubleRealVector queryVector = createRandomDoubleVector(random, numDimensions);
310+
final Map<Tuple, RealVector> dataMap = Maps.newHashMap();
311+
final TreeSet<PrimaryKeyVectorAndDistance> recordsOrderedByDistance =
312+
new TreeSet<>(Comparator.comparing(PrimaryKeyVectorAndDistance::getDistance));
313+
314+
for (int i = 0; i < 1000;) {
315+
i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, new TestOnReadListener(),
316+
tr -> {
317+
final var primaryKey = createNextPrimaryKey(nextNodeIdAtomic);
318+
final DoubleRealVector dataVector = createRandomDoubleVector(random, numDimensions);
319+
final double distance = metric.distance(dataVector, queryVector);
320+
dataMap.put(primaryKey, dataVector);
321+
322+
final PrimaryKeyVectorAndDistance record =
323+
new PrimaryKeyVectorAndDistance(primaryKey, dataVector, distance);
324+
recordsOrderedByDistance.add(record);
325+
if (recordsOrderedByDistance.size() > k) {
326+
recordsOrderedByDistance.pollLast();
327+
}
328+
return record;
329+
});
330+
}
331+
332+
//
333+
// If we fetch the current state back from the db some vectors are regular vectors and some vectors are
334+
// RaBitQ encoded. Since that information is not surfaced through the API, we need to scan layer 0, get
335+
// all vectors directly from disk (encoded/not-encoded, transformed/not-transformed) in order to check
336+
// that transformations/reconstructions are applied properly.
337+
//
338+
final Map<Tuple, RealVector> fromDBMap = Maps.newHashMap();
339+
hnsw.scanLayer(db, 0, 100,
340+
node -> fromDBMap.put(node.getPrimaryKey(),
341+
node.asCompactNode().getVector().getUnderlyingVector()));
342+
343+
//
344+
// Still run a kNN search to make sure that recall is satisfactory.
345+
//
346+
final List<? extends ResultEntry> results =
347+
db.run(tr ->
348+
hnsw.kNearestNeighborsSearch(tr, k, 500, true, queryVector).join());
349+
350+
final ImmutableSet<Tuple> trueNN =
351+
recordsOrderedByDistance.stream()
352+
.map(PrimaryKeyAndVector::getPrimaryKey)
353+
.collect(ImmutableSet.toImmutableSet());
354+
355+
int recallCount = 0;
356+
int exactVectorCount = 0;
357+
int encodedVectorCount = 0;
358+
for (final ResultEntry resultEntry : results) {
359+
if (trueNN.contains(resultEntry.getPrimaryKey())) {
360+
recallCount ++;
361+
}
362+
363+
final RealVector originalVector = dataMap.get(resultEntry.getPrimaryKey());
364+
Assertions.assertThat(originalVector).isNotNull();
365+
final RealVector fromDBVector = fromDBMap.get(resultEntry.getPrimaryKey());
366+
Assertions.assertThat(fromDBVector).isNotNull();
367+
if (!(fromDBVector instanceof EncodedRealVector)) {
368+
Assertions.assertThat(originalVector).isEqualTo(fromDBVector);
369+
exactVectorCount ++;
370+
final double distance = metric.distance(originalVector,
371+
Objects.requireNonNull(resultEntry.getVector()));
372+
Assertions.assertThat(distance).isCloseTo(0.0d, within(2E-12));
373+
} else {
374+
encodedVectorCount ++;
375+
final double distance = metric.distance(originalVector,
376+
Objects.requireNonNull(resultEntry.getVector()).toDoubleRealVector());
377+
Assertions.assertThat(distance).isCloseTo(0.0d, within(20.0d));
378+
}
379+
}
380+
final double recall = (double)recallCount / (double)k;
381+
Assertions.assertThat(recall).isGreaterThan(0.9);
382+
// must have both kinds
383+
Assertions.assertThat(exactVectorCount).isGreaterThan(0);
384+
Assertions.assertThat(encodedVectorCount).isGreaterThan(0);
385+
}
386+
288387
private int basicInsertBatch(final HNSW hnsw, final int batchSize,
289388
@Nonnull final AtomicLong nextNodeIdAtomic, @Nonnull final TestOnReadListener onReadListener,
290389
@Nonnull final Function<Transaction, PrimaryKeyAndVector> insertFunction) {

0 commit comments

Comments
 (0)