diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index 2366a6d579..e24e1c5544 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -11,6 +11,7 @@ package org.opensearch.knn.index.codec.KNN990Codec; +import lombok.extern.slf4j.Slf4j; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.index.ByteVectorValues; @@ -22,12 +23,16 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOSupplier; import org.apache.lucene.util.IOUtils; import org.opensearch.common.UUIDs; import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.memoryoptsearch.VectorSearcher; +import org.opensearch.knn.memoryoptsearch.VectorSearcherFactory; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; @@ -37,23 +42,89 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; + +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapper.KNN_FIELD; /** * Vectors reader class for reading the flat vectors for native engines. The class provides methods for iterating * over the vectors and retrieving their values. */ +@Slf4j public class NativeEngines990KnnVectorsReader extends KnnVectorsReader { + private static final int RESERVE_TWICE_SPACE = 2; + private static final float SUFFICIENT_LOAD_FACTOR = 0.6f; private final FlatVectorsReader flatVectorsReader; private Map quantizationStateCacheKeyPerField; private SegmentReadState segmentReadState; private final List cacheKeys; + private Map vectorSearchers; public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) { this.flatVectorsReader = flatVectorsReader; this.segmentReadState = state; this.cacheKeys = getVectorCacheKeysFromSegmentReaderState(state); + this.vectorSearchers = new HashMap<>(RESERVE_TWICE_SPACE * segmentReadState.fieldInfos.size(), SUFFICIENT_LOAD_FACTOR); loadCacheKeyMap(); + + // + // TMP(KDY) : Dynamic update will be covered in part-7. Please refer to + // https://github.com/opensearch-project/k-NN/issues/2401#issuecomment-2699777824 + // + final boolean isMemoryOptimizedSearchEnabled = false; + if (isMemoryOptimizedSearchEnabled) { + loadMemoryOptimizedSearcher(); + } + } + + private IOSupplier getIndexFileNameIfMemoryOptimizedSearchSupported(final FieldInfo fieldInfo) { + // Skip non-knn fields. + final Map attributes = fieldInfo.attributes(); + if (attributes == null || attributes.containsKey(KNN_FIELD) == false) { + return null; + } + + // Get engine + final String engineName = attributes.getOrDefault(KNN_ENGINE, KNNEngine.DEFAULT.getName()); + final KNNEngine knnEngine = KNNEngine.getEngine(engineName); + + // Get memory optimized searcher from engine + final VectorSearcherFactory searcherFactory = knnEngine.getVectorSearcherFactory(); + if (searcherFactory == null) { + // It's not supported + return null; + } + + // Start creating searcher + final String fileName = KNNCodecUtil.getNativeEngineFileFromFieldInfo(fieldInfo, segmentReadState.segmentInfo); + if (fileName != null) { + return () -> searcherFactory.createVectorSearcher(segmentReadState.directory, fileName); + } + + // Not supported + return null; + } + + private void loadMemoryOptimizedSearcher() { + try { + for (FieldInfo fieldInfo : segmentReadState.fieldInfos) { + final IOSupplier searcherSupplier = getIndexFileNameIfMemoryOptimizedSearchSupported(fieldInfo); + if (searcherSupplier != null) { + final VectorSearcher searcher = Objects.requireNonNull(searcherSupplier.get()); + vectorSearchers.put(fieldInfo.getName(), searcher); + } + } + } catch (Exception e) { + // Close opened searchers first, then suppress + try { + IOUtils.closeWhileHandlingException(vectorSearchers.values()); + } catch (Exception closeException) { + log.error(closeException.getMessage(), closeException); + } + throw new RuntimeException(e); + } } /** @@ -135,6 +206,14 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits ((QuantizationConfigKNNCollector) knnCollector).setQuantizationState(quantizationState); return; } + + // Try with memory optimized searcher + final VectorSearcher memoryOptimizedSearcher = vectorSearchers.get(field); + if (memoryOptimizedSearcher != null) { + memoryOptimizedSearcher.search(target, knnCollector, acceptDocs); + return; + } + throw new UnsupportedOperationException("Search functionality using codec is not supported with Native Engine Reader"); } @@ -197,6 +276,9 @@ public void close() throws IOException { quantizationStateCacheManager.evict(cacheKey); } } + + // TODO(KDY) + // Close all memory optimized searchers. } private void loadCacheKeyMap() { diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java index 0bd4b0f27a..6d073250f5 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java @@ -8,6 +8,7 @@ import com.google.common.collect.ImmutableSet; import org.opensearch.common.ValidationException; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.memoryoptsearch.VectorSearcherFactory; import org.opensearch.knn.index.engine.faiss.Faiss; import org.opensearch.knn.index.engine.lucene.Lucene; import org.opensearch.knn.index.engine.nmslib.Nmslib; @@ -216,4 +217,9 @@ public ResolvedMethodContext resolveMethod( public boolean supportsRemoteIndexBuild() { return knnLibrary.supportsRemoteIndexBuild(); } + + @Override + public VectorSearcherFactory getVectorSearcherFactory() { + return knnLibrary.getVectorSearcherFactory(); + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java index 29e6442f48..818fb3dbca 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java @@ -7,6 +7,7 @@ import org.opensearch.common.ValidationException; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.memoryoptsearch.VectorSearcherFactory; import java.util.Collections; import java.util.List; @@ -140,11 +141,21 @@ default List mmapFileExtensions() { return Collections.emptyList(); } - /** + /* * Returns whether or not the engine implementation supports remote index build * @return true if remote index build is supported, false otherwise */ default boolean supportsRemoteIndexBuild() { return false; } + + /** + * Create a new vector searcher factory that compatible with on Lucene search API. + * @return New searcher factory that returns {@link org.opensearch.knn.memoryoptsearch.VectorSearcher} + * If it is not supported, it should return null. + * But, if it is supported, the factory shall not return null searcher. + */ + default VectorSearcherFactory getVectorSearcherFactory() { + return null; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java index d23a475aa7..4206f692e7 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java @@ -8,12 +8,14 @@ import com.google.common.collect.ImmutableMap; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.memoryoptsearch.VectorSearcherFactory; import org.opensearch.knn.index.engine.KNNMethod; import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodResolver; import org.opensearch.knn.index.engine.NativeLibrary; import org.opensearch.knn.index.engine.ResolvedMethodContext; +import org.opensearch.knn.memoryoptsearch.faiss.FaissMemoryOptimizedSearcherFactory; import java.util.Map; import java.util.function.Function; @@ -123,4 +125,9 @@ public ResolvedMethodContext resolveMethod( public boolean supportsRemoteIndexBuild() { return true; } + + @Override + public VectorSearcherFactory getVectorSearcherFactory() { + return new FaissMemoryOptimizedSearcherFactory(); + } } diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/VectorSearcher.java b/src/main/java/org/opensearch/knn/memoryoptsearch/VectorSearcher.java new file mode 100644 index 0000000000..cefaf0c7c0 --- /dev/null +++ b/src/main/java/org/opensearch/knn/memoryoptsearch/VectorSearcher.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch; + +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.util.Bits; + +import java.io.Closeable; +import java.io.IOException; + +/** + * This searcher performs vector search on non-Lucene index, for example FAISS index. + * Two search APIs will be compatible with Lucene, taking {@link KnnCollector} and {@link Bits}. + * In its implementation, it must collect top vectors that is similar to the given query. Make sure to transform the result to similarity + * value if internally calculates distance between. + */ +public interface VectorSearcher extends Closeable { + /** + * Return the k nearest neighbor documents as determined by comparison of their vector values for + * this field, to the given vector, by the field's similarity function. The score of each document + * is derived from the vector similarity in a way that ensures scores are positive and that a + * larger score corresponds to a higher ranking. + * + *

The search is allowed to be approximate, meaning the results are not guaranteed to be the + * true k closest neighbors. For large values of k (for example when k is close to the total + * number of documents), the search may also retrieve fewer than k documents. + * + * @param target the vector-valued float vector query + * @param knnCollector a KnnResults collector and relevant settings for gathering vector results + * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} + * if they are all allowed to match. + */ + void search(float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; + + /** + * Return the k nearest neighbor documents as determined by comparison of their vector values for + * this field, to the given vector, by the field's similarity function. The score of each document + * is derived from the vector similarity in a way that ensures scores are positive and that a + * larger score corresponds to a higher ranking. + * + *

The search is allowed to be approximate, meaning the results are not guaranteed to be the + * true k closest neighbors. For large values of k (for example when k is close to the total + * number of documents), the search may also retrieve fewer than k documents. + * + * @param target the vector-valued byte vector query + * @param knnCollector a KnnResults collector and relevant settings for gathering vector results + * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} + * if they are all allowed to match. + */ + void search(byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/VectorSearcherFactory.java b/src/main/java/org/opensearch/knn/memoryoptsearch/VectorSearcherFactory.java new file mode 100644 index 0000000000..d87cf7aa02 --- /dev/null +++ b/src/main/java/org/opensearch/knn/memoryoptsearch/VectorSearcherFactory.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch; + +import org.apache.lucene.store.Directory; + +import java.io.IOException; + +/** + * Factory to create {@link VectorSearcher}. + * Provided parameters will have {@link Directory} and a file name where implementation can rely on it to open an input stream. + */ +public interface VectorSearcherFactory { + /** + * Create a non-null {@link VectorSearcher} with given Lucene's {@link Directory}. + * + * @param directory Lucene's Directory. + * @param fileName Logical file name to load. + * @return It must return a non-null {@link VectorSearcher} + * @throws IOException + */ + VectorSearcher createVectorSearcher(Directory directory, String fileName) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissMemoryOptimizedSearcher.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissMemoryOptimizedSearcher.java new file mode 100644 index 0000000000..c3d07325b5 --- /dev/null +++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissMemoryOptimizedSearcher.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch.faiss; + +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.opensearch.knn.memoryoptsearch.VectorSearcher; + +import java.io.IOException; + +/** + * This searcher directly reads FAISS index file via the provided {@link IndexInput} then perform vector search on it. + */ +public class FaissMemoryOptimizedSearcher implements VectorSearcher { + private final IndexInput indexInput; + + public FaissMemoryOptimizedSearcher(IndexInput indexInput) { + this.indexInput = indexInput; + } + + @Override + public void search(float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + // TODO(KDY) : This will be covered in subsequent parts. + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public void search(byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + // TODO(KDY) : This will be covered in subsequent parts. + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public void close() throws IOException { + indexInput.close(); + } +} diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissMemoryOptimizedSearcherFactory.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissMemoryOptimizedSearcherFactory.java new file mode 100644 index 0000000000..73747c5c12 --- /dev/null +++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissMemoryOptimizedSearcherFactory.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch.faiss; + +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.ReadAdvice; +import org.opensearch.knn.memoryoptsearch.VectorSearcher; +import org.opensearch.knn.memoryoptsearch.VectorSearcherFactory; + +import java.io.IOException; + +/** + * This factory returns {@link VectorSearcher} that performs vector search directly on FAISS index. + * Note that we pass `RANDOM` as advice to prevent the underlying storage from performing read-ahead. Since vector search naturally accesses + * random vector locations, read-ahead does not improve performance. By passing the `RANDOM` context, we explicitly indicate that + * this searcher will access vectors randomly. + */ +public class FaissMemoryOptimizedSearcherFactory implements VectorSearcherFactory { + @Override + public VectorSearcher createVectorSearcher(final Directory directory, final String fileName) throws IOException { + final IndexInput indexInput = directory.openInput( + fileName, + new IOContext(IOContext.Context.DEFAULT, null, null, ReadAdvice.RANDOM) + ); + return new FaissMemoryOptimizedSearcher(indexInput); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReaderTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReaderTests.java new file mode 100644 index 0000000000..8ac925010f --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReaderTests.java @@ -0,0 +1,103 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN990Codec; + +import lombok.SneakyThrows; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IndexInput; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.memoryoptsearch.VectorSearcher; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.Collections; +import java.util.Map; +import java.util.Set; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapper.KNN_FIELD; + +public class NativeEngines990KnnVectorsReaderTests extends KNNTestCase { + @SneakyThrows + public void testWhenMemoryOptimizedSearchIsEnabled_emptyCase() { + // Prepare field infos + final FieldInfo[] fieldInfoArray = new FieldInfo[] {}; + final FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + + // Load vector searchers + final Map vectorSearchers = loadSearchers(fieldInfos, Collections.emptySet()); + assertTrue(vectorSearchers.isEmpty()); + } + + @SneakyThrows + public void testWhenMemoryOptimizedSearchIsEnabled_mixedCase() { + // Prepare field infos + // - field1: Non KNN field + // - field2: KNN field, but using Lucene engine + // - field3: KNN field, FAISS + // - field4: KNN field, FAISS + // - field5: KNN field, FAISS, but it does not have file for some reason. + final FieldInfo[] fieldInfoArray = new FieldInfo[] { + createFieldInfo("field1", null, 0), + createFieldInfo("field2", KNNEngine.LUCENE, 1), + createFieldInfo("field3", KNNEngine.FAISS, 2), + createFieldInfo("field4", KNNEngine.FAISS, 3), + createFieldInfo("field5", KNNEngine.FAISS, 4) }; + final FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + final Set filesInSegment = Set.of("_0_165_field3.faiss", "_0_165_field4.faiss"); + + // Load vector searchers + final Map vectorSearchers = loadSearchers(fieldInfos, filesInSegment); + + // Validate #searchers + assertEquals(2, vectorSearchers.size()); + } + + private static FieldInfo createFieldInfo(final String fieldName, final KNNEngine engine, final int fieldNo) { + final KNNCodecTestUtil.FieldInfoBuilder builder = KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName); + builder.fieldNumber(fieldNo); + if (engine != null) { + builder.addAttribute(KNN_FIELD, "true"); + builder.addAttribute(KNN_ENGINE, engine.getName()); + } + return builder.build(); + } + + @SneakyThrows + private static Map loadSearchers(final FieldInfos fieldInfos, final Set filesInSegment) { + // Prepare infra + final IndexInput mockIndexInput = mock(IndexInput.class); + final Directory mockDirectory = mock(Directory.class); + when(mockDirectory.openInput(any(), any())).thenReturn(mockIndexInput); + final SegmentInfo segmentInfo = mock(SegmentInfo.class); + when(segmentInfo.files()).thenReturn(filesInSegment); + when(segmentInfo.getId()).thenReturn((segmentInfo.hashCode() + "").getBytes()); + final SegmentReadState readState = new SegmentReadState(mockDirectory, segmentInfo, fieldInfos, null); + + // Create reader + final NativeEngines990KnnVectorsReader reader = new NativeEngines990KnnVectorsReader(readState, null); + final Class clazz = NativeEngines990KnnVectorsReader.class; + + // Call loadMemoryOptimizedSearcher() + final Method loadMethod = clazz.getDeclaredMethod("loadMemoryOptimizedSearcher"); + loadMethod.setAccessible(true); + loadMethod.invoke(reader); + + // Get searcher table + final Field tableField = clazz.getDeclaredField("vectorSearchers"); + tableField.setAccessible(true); + return (Map) tableField.get(reader); + } +} diff --git a/src/test/java/org/opensearch/knn/memoryoptsearch/FaissEngineTests.java b/src/test/java/org/opensearch/knn/memoryoptsearch/FaissEngineTests.java new file mode 100644 index 0000000000..4db4e28da0 --- /dev/null +++ b/src/test/java/org/opensearch/knn/memoryoptsearch/FaissEngineTests.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.memoryoptsearch; + +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IndexInput; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.engine.KNNEngine; + +import java.io.IOException; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class FaissEngineTests extends KNNTestCase { + public void testFaissEngineToReturnSearcher() throws IOException { + final VectorSearcherFactory factory = KNNEngine.FAISS.getVectorSearcherFactory(); + assertNotNull(factory); + + final IndexInput mockIndexInput = mock(IndexInput.class); + final Directory mockDirectory = mock(Directory.class); + when(mockDirectory.openInput(any(), any())).thenReturn(mockIndexInput); + final String fileName = "_0_165_target_field.faiss"; + try (VectorSearcher searcher = factory.createVectorSearcher(mockDirectory, fileName)) { + assertNotNull(searcher); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public void testNonFaissEngineToReturnNullSearcher() { + assertNull(KNNEngine.LUCENE.getVectorSearcherFactory()); + assertNull(KNNEngine.NMSLIB.getVectorSearcherFactory()); + } +}