|
30 | 30 | import com.apple.foundationdb.linear.Quantizer; |
31 | 31 | import com.apple.foundationdb.linear.RealVector; |
32 | 32 | import com.apple.foundationdb.linear.StoredVecsIterator; |
| 33 | +import com.apple.foundationdb.rabitq.EncodedRealVector; |
33 | 34 | import com.apple.foundationdb.test.TestDatabaseExtension; |
34 | 35 | import com.apple.foundationdb.test.TestExecutors; |
35 | 36 | import com.apple.foundationdb.test.TestSubspaceExtension; |
|
83 | 84 | import java.util.stream.LongStream; |
84 | 85 | import java.util.stream.Stream; |
85 | 86 |
|
| 87 | +import static com.apple.foundationdb.linear.RealVectorTest.createRandomDoubleVector; |
86 | 88 | import static com.apple.foundationdb.linear.RealVectorTest.createRandomHalfVector; |
87 | 89 | import static org.assertj.core.api.Assertions.within; |
88 | 90 |
|
@@ -285,6 +287,103 @@ void testBasicInsert(final long seed, final boolean useInlining, final boolean e |
285 | 287 | Assertions.assertThat(readIds.size()).isBetween(10, 50); |
286 | 288 | } |
287 | 289 |
|
| 290 | + @ParameterizedTest() |
| 291 | + @RandomSeedSource({0x0fdbL, 0x5ca1eL, 123456L, 78910L, 1123581321345589L}) |
| 292 | + void testBasicInsertWithRaBitQEncodings(final long seed) { |
| 293 | + final Random random = new Random(seed); |
| 294 | + final Metric metric = Metric.EUCLIDEAN_METRIC; |
| 295 | + |
| 296 | + final AtomicLong nextNodeIdAtomic = new AtomicLong(0L); |
| 297 | + final int numDimensions = 128; |
| 298 | + final HNSW hnsw = new HNSW(rtSubspace.getSubspace(), TestExecutors.defaultThreadPool(), |
| 299 | + HNSW.newConfigBuilder().setMetric(metric) |
| 300 | + .setUseRaBitQ(true) |
| 301 | + .setRaBitQNumExBits(5) |
| 302 | + .setSampleVectorStatsProbability(1.0d) // every vector is sampled |
| 303 | + .setMaintainStatsProbability(1.0d) // for every vector we maintain the stats |
| 304 | + .setStatsThreshold(950) // after 950 vectors we enable RaBitQ |
| 305 | + .setM(32).setMMax(32).setMMax0(64).build(numDimensions), |
| 306 | + OnWriteListener.NOOP, OnReadListener.NOOP); |
| 307 | + |
| 308 | + final int k = 499; |
| 309 | + final DoubleRealVector queryVector = createRandomDoubleVector(random, numDimensions); |
| 310 | + final Map<Tuple, RealVector> dataMap = Maps.newHashMap(); |
| 311 | + final TreeSet<PrimaryKeyVectorAndDistance> recordsOrderedByDistance = |
| 312 | + new TreeSet<>(Comparator.comparing(PrimaryKeyVectorAndDistance::getDistance)); |
| 313 | + |
| 314 | + for (int i = 0; i < 1000;) { |
| 315 | + i += basicInsertBatch(hnsw, 100, nextNodeIdAtomic, new TestOnReadListener(), |
| 316 | + tr -> { |
| 317 | + final var primaryKey = createNextPrimaryKey(nextNodeIdAtomic); |
| 318 | + final DoubleRealVector dataVector = createRandomDoubleVector(random, numDimensions); |
| 319 | + final double distance = metric.distance(dataVector, queryVector); |
| 320 | + dataMap.put(primaryKey, dataVector); |
| 321 | + |
| 322 | + final PrimaryKeyVectorAndDistance record = |
| 323 | + new PrimaryKeyVectorAndDistance(primaryKey, dataVector, distance); |
| 324 | + recordsOrderedByDistance.add(record); |
| 325 | + if (recordsOrderedByDistance.size() > k) { |
| 326 | + recordsOrderedByDistance.pollLast(); |
| 327 | + } |
| 328 | + return record; |
| 329 | + }); |
| 330 | + } |
| 331 | + |
| 332 | + // |
| 333 | + // If we fetch the current state back from the db some vectors are regular vectors and some vectors are |
| 334 | + // RaBitQ encoded. Since that information is not surfaced through the API, we need to scan layer 0, get |
| 335 | + // all vectors directly from disk (encoded/not-encoded, transformed/not-transformed) in order to check |
| 336 | + // that transformations/reconstructions are applied properly. |
| 337 | + // |
| 338 | + final Map<Tuple, RealVector> fromDBMap = Maps.newHashMap(); |
| 339 | + hnsw.scanLayer(db, 0, 100, |
| 340 | + node -> fromDBMap.put(node.getPrimaryKey(), |
| 341 | + node.asCompactNode().getVector().getUnderlyingVector())); |
| 342 | + |
| 343 | + // |
| 344 | + // Still run a kNN search to make sure that recall is satisfactory. |
| 345 | + // |
| 346 | + final List<? extends ResultEntry> results = |
| 347 | + db.run(tr -> |
| 348 | + hnsw.kNearestNeighborsSearch(tr, k, 500, true, queryVector).join()); |
| 349 | + |
| 350 | + final ImmutableSet<Tuple> trueNN = |
| 351 | + recordsOrderedByDistance.stream() |
| 352 | + .map(PrimaryKeyAndVector::getPrimaryKey) |
| 353 | + .collect(ImmutableSet.toImmutableSet()); |
| 354 | + |
| 355 | + int recallCount = 0; |
| 356 | + int exactVectorCount = 0; |
| 357 | + int encodedVectorCount = 0; |
| 358 | + for (final ResultEntry resultEntry : results) { |
| 359 | + if (trueNN.contains(resultEntry.getPrimaryKey())) { |
| 360 | + recallCount ++; |
| 361 | + } |
| 362 | + |
| 363 | + final RealVector originalVector = dataMap.get(resultEntry.getPrimaryKey()); |
| 364 | + Assertions.assertThat(originalVector).isNotNull(); |
| 365 | + final RealVector fromDBVector = fromDBMap.get(resultEntry.getPrimaryKey()); |
| 366 | + Assertions.assertThat(fromDBVector).isNotNull(); |
| 367 | + if (!(fromDBVector instanceof EncodedRealVector)) { |
| 368 | + Assertions.assertThat(originalVector).isEqualTo(fromDBVector); |
| 369 | + exactVectorCount ++; |
| 370 | + final double distance = metric.distance(originalVector, |
| 371 | + Objects.requireNonNull(resultEntry.getVector())); |
| 372 | + Assertions.assertThat(distance).isCloseTo(0.0d, within(2E-12)); |
| 373 | + } else { |
| 374 | + encodedVectorCount ++; |
| 375 | + final double distance = metric.distance(originalVector, |
| 376 | + Objects.requireNonNull(resultEntry.getVector()).toDoubleRealVector()); |
| 377 | + Assertions.assertThat(distance).isCloseTo(0.0d, within(20.0d)); |
| 378 | + } |
| 379 | + } |
| 380 | + final double recall = (double)recallCount / (double)k; |
| 381 | + Assertions.assertThat(recall).isGreaterThan(0.9); |
| 382 | + // must have both kinds |
| 383 | + Assertions.assertThat(exactVectorCount).isGreaterThan(0); |
| 384 | + Assertions.assertThat(encodedVectorCount).isGreaterThan(0); |
| 385 | + } |
| 386 | + |
288 | 387 | private int basicInsertBatch(final HNSW hnsw, final int batchSize, |
289 | 388 | @Nonnull final AtomicLong nextNodeIdAtomic, @Nonnull final TestOnReadListener onReadListener, |
290 | 389 | @Nonnull final Function<Transaction, PrimaryKeyAndVector> insertFunction) { |
|
0 commit comments