diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index d96c3e7557..eb92b2ad70 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -60,6 +60,17 @@ class IndexService { */ virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector &ids, jlong idMapAddress); + /** + * Build a flat (IndexFlat) FAISS index from a set of vectors. + * + * @param numVectors Number of vectors to add to the index + * @param dim Dimension of each vector + * @param vectors Vector data as a flat array (size should be numVectors * dim) + * @param metricType Metric type for distance calculations (e.g., L2, IP) + * @return Memory address of the native IndexFlat object + */ + jlong buildFlatIndexFromVectors(int numVectors, int dim, const std::vector &vectors, faiss::MetricType metricType); + /** * Write index to disk * diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index e48e6faa91..6431a86429 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -23,6 +23,10 @@ namespace knn_jni { void InsertToIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jlong indexAddr, jint threadCount, IndexService *indexService); + // Build a flat (IndexFlat) FAISS index from a set of vectors. + // Returns the memory address of the created index. + jlong BuildFlatIndexFromVectors(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jfloatArray vectorsJ, jint numVectors, jint dimJ, jstring metricTypeJ, knn_jni::faiss_wrapper::IndexService *indexService); + void WriteIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jobject output, jlong indexAddr, IndexService *indexService); // Create an index with ids and vectors. Instead of creating a new index, this function creates the index diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index dce5801383..ada950207e 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -52,6 +52,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndex(J JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jlong indexAddress, jint threadCount); + +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: buildFlatIndexFromVectors + * Signature: ([FIILjava/lang/String;)J + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_buildFlatIndexFromVectors(JNIEnv *, jclass, jfloatArray, jint, jint, jstring); + /* * Class: org_opensearch_knn_jni_FaissService * Method: insertToBinaryIndex diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index 4999e31729..d253630112 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -16,12 +16,16 @@ #include "faiss/IndexIVFFlat.h" #include "faiss/IndexBinaryIVF.h" #include "faiss/IndexIDMap.h" +#include "faiss/IndexFlat.h" #include #include #include #include +#include +#include + namespace knn_jni { namespace faiss_wrapper { @@ -138,6 +142,33 @@ void IndexService::insertToIndex( idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); } +jlong IndexService::buildFlatIndexFromVectors( + int numVectors, + int dim, + const std::vector &vectors, + faiss::MetricType metricType) { + + if (vectors.empty()) { + throw std::runtime_error("Input vectors cannot be empty"); + } + + if ((int)vectors.size() != numVectors * dim) { + throw std::runtime_error("Mismatch between vector count/dimension and actual data length"); + } + + faiss::IndexFlat *index = nullptr; + + if (metricType == faiss::METRIC_INNER_PRODUCT) { + index = new faiss::IndexFlatIP(dim); + } else { + index = new faiss::IndexFlatL2(dim); + } + + index->add(numVectors, vectors.data()); + + return reinterpret_cast(index); +} + void IndexService::writeIndex( faiss::IOWriter* writer, jlong idMapAddress diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index c02c410c1f..c892b55738 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -183,6 +183,38 @@ void knn_jni::faiss_wrapper::InsertToIndex(knn_jni::JNIUtilInterface * jniUtil, indexService->insertToIndex(dim, numIds, threadCount, vectorsAddress, ids, index_ptr); } +jlong knn_jni::faiss_wrapper::BuildFlatIndexFromVectors(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jfloatArray vectorsJ, jint numVectors, + jint dimJ, jstring metricTypeJ, IndexService *indexService) { + + if (vectorsJ == nullptr) { + throw std::runtime_error("Vector data cannot be null"); + } + + if (dimJ <= 0 || numVectors <= 0) { + throw std::runtime_error("Invalid dimensions or number of vectors"); + } + + const char *metricTypeC = env->GetStringUTFChars(metricTypeJ, nullptr); + jsize totalLength = env->GetArrayLength(vectorsJ); + + if (totalLength != numVectors * dimJ) { + env->ReleaseStringUTFChars(metricTypeJ, metricTypeC); + throw std::runtime_error("Vector data length does not match numVectors * dimension"); + } + + jfloat* vectors = env->GetFloatArrayElements(vectorsJ, nullptr); + std::vector cppVectors(vectors, vectors + totalLength); + + faiss::MetricType metric = (strcmp(metricTypeC, "IP") == 0) ? faiss::METRIC_INNER_PRODUCT : faiss::METRIC_L2; + + jlong indexPtr = indexService->buildFlatIndexFromVectors(numVectors, dimJ, cppVectors, metric); + + env->ReleaseFloatArrayElements(vectorsJ, vectors, JNI_ABORT); + env->ReleaseStringUTFChars(metricTypeJ, metricTypeC); + + return indexPtr; +} + void knn_jni::faiss_wrapper::WriteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject output, jlong index_ptr, IndexService* indexService) { @@ -597,7 +629,7 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter } else { auto ivfReader = dynamic_cast(indexReader->index); auto ivfFlatReader = dynamic_cast(indexReader->index); - + if(ivfReader || ivfFlatReader) { int indexNprobe = ivfReader == nullptr ? ivfFlatReader->nprobe : ivfReader->nprobe; ivfParams.nprobe = commons::getIntegerMethodParameter(env, jniUtil, methodParams, NPROBES, indexNprobe); @@ -1228,4 +1260,4 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter } return results; -} +} \ No newline at end of file diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 836774402f..0dfae2ac62 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -19,6 +19,8 @@ #include "jni_util.h" #include "faiss_stream_support.h" + + static knn_jni::JNIUtil jniUtil; static const jint KNN_FAISS_JNI_VERSION = JNI_VERSION_1_1; @@ -96,6 +98,19 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JN } } +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_buildFlatIndexFromVectors( + JNIEnv *env, jclass cls, jfloatArray vectorsJ, jint numVectors, jint dimJ, jstring metricTypeJ) { + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + return knn_jni::faiss_wrapper::BuildFlatIndexFromVectors(&jniUtil, env, vectorsJ, numVectors, dimJ, metricTypeJ, &indexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + + return -1; +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jlong indexAddress, jint threadCount) diff --git a/jni/tests/faiss_index_service_test.cpp b/jni/tests/faiss_index_service_test.cpp index 127ca07b89..74289ea6d8 100644 --- a/jni/tests/faiss_index_service_test.cpp +++ b/jni/tests/faiss_index_service_test.cpp @@ -18,6 +18,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "commons.h" +#include using ::testing::NiceMock; using ::testing::Return; @@ -160,4 +161,122 @@ TEST(CreateByteIndexTest, BasicAssertions) { long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap); indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress); indexService.writeIndex(&fileIOWriter, indexAddress); +} + +//buildFlatIndexFromVectors tests + +// Helper: Create dummy data for float vectors +std::vector makeVectors(int num, int dim, float val = 1.0f) { + std::vector v(num * dim, val); + return v; +} + +/** + * Test that a Flat L2 index is successfully created from vectors. + * Checks that the returned pointer is not null and the number of vectors is correct. + */ +TEST(BuildFlatIndexFromVectorsTest, BuildsL2Index) { + int numVectors = 5, dim = 3; + std::vector data = makeVectors(numVectors, dim, 2.0f); + + std::unique_ptr mockFaissMethods(new MockFaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(mockFaissMethods)); + + jlong indexPtr = indexService.buildFlatIndexFromVectors(numVectors, dim, data, faiss::METRIC_L2); + + ASSERT_NE(indexPtr, 0); + auto* index = reinterpret_cast(indexPtr); + ASSERT_EQ(index->ntotal, numVectors); + delete index; +} + +/** + * Test that a Flat Inner Product index is successfully created from vectors. + * Checks that the returned pointer is not null and the number of vectors is correct. + */ +TEST(BuildFlatIndexFromVectorsTest, BuildsIPIndex) { + int numVectors = 4, dim = 2; + std::vector data = makeVectors(numVectors, dim, 3.0f); + + std::unique_ptr mockFaissMethods(new MockFaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(mockFaissMethods)); + + jlong indexPtr = indexService.buildFlatIndexFromVectors(numVectors, dim, data, faiss::METRIC_INNER_PRODUCT); + + ASSERT_NE(indexPtr, 0); + auto* index = reinterpret_cast(indexPtr); + ASSERT_EQ(index->ntotal, numVectors); + delete index; +} + +/** + * Test that providing empty vectors throws a runtime_error. + */ +TEST(BuildFlatIndexFromVectorsTest, ThrowsOnEmptyVectors) { + int numVectors = 10, dim = 4; + std::vector empty; + + std::unique_ptr mockFaissMethods(new MockFaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(mockFaissMethods)); + + EXPECT_THROW( + indexService.buildFlatIndexFromVectors(numVectors, dim, empty, faiss::METRIC_L2), + std::runtime_error + ); +} + +/** + * Test that providing a vector whose size does not match numVectors * dim throws a runtime_error. + */ +TEST(BuildFlatIndexFromVectorsTest, ThrowsOnMismatchedSize) { + int numVectors = 3, dim = 5; + std::vector badData(7, 1.0f); + + std::unique_ptr mockFaissMethods(new MockFaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(mockFaissMethods)); + + EXPECT_THROW( + indexService.buildFlatIndexFromVectors(numVectors, dim, badData, faiss::METRIC_L2), + std::runtime_error + ); +} + +/** + * Test that the vectors inserted into the flat index are preserved in order and value. + * This reconstructs each vector from the index and compares to the original input. + */ +TEST(BuildFlatIndexFromVectorsTest, IndexContainsInsertedVectorsInOrder) { + int numVectors = 5, dim = 3; + // Prepare 5 unique vectors + std::vector data = { + 1.0f, 2.0f, 3.0f, // vector 0 + 4.0f, 5.0f, 6.0f, // vector 1 + 7.0f, 8.0f, 9.0f, // vector 2 + 10.0f, 11.0f, 12.0f, // vector 3 + 13.0f, 14.0f, 15.0f // vector 4 + }; + + // Use L2 metric for this test + std::unique_ptr mockFaissMethods(new MockFaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(mockFaissMethods)); + + jlong indexPtr = indexService.buildFlatIndexFromVectors(numVectors, dim, data, faiss::METRIC_L2); + + ASSERT_NE(indexPtr, 0); + auto* index = reinterpret_cast(indexPtr); + ASSERT_EQ(index->ntotal, numVectors); + + // Check each vector in the index matches input and order + std::vector reconstructed(dim); + for (int i = 0; i < numVectors; ++i) { + index->reconstruct(i, reconstructed.data()); + for (int j = 0; j < dim; ++j) { + float expected = data[i * dim + j]; + ASSERT_FLOAT_EQ(reconstructed[j], expected) + << "Vector " << i << " element " << j << " mismatch"; + } + } + + // Clean up + delete index; } \ No newline at end of file diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 0dd9ac8366..ef13537043 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -38,8 +38,8 @@ const float rangeSearchRandomDataMax = 50; const float rangeSearchRadius = 20000; void createIndexIteratively( - knn_jni::JNIUtilInterface * JNIUtil, - JNIEnv *jniEnv, + knn_jni::JNIUtilInterface * JNIUtil, + JNIEnv *jniEnv, std::vector & ids, std::vector & vectors, int dim, @@ -73,13 +73,13 @@ void createIndexIteratively( } void createBinaryIndexIteratively( - knn_jni::JNIUtilInterface * JNIUtil, - JNIEnv *jniEnv, + knn_jni::JNIUtilInterface * JNIUtil, + JNIEnv *jniEnv, std::vector & ids, std::vector & vectors, int dim, jobject javaFileOutputMock, - std::unordered_map parametersMap, + std::unordered_map parametersMap, IndexService * indexService, int insertions = 10 ) { @@ -1336,4 +1336,4 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) { delete it; } } -} +} \ No newline at end of file diff --git a/jni/tests/faiss_wrapper_unit_test.cpp b/jni/tests/faiss_wrapper_unit_test.cpp index e23ca85284..1ba3470de0 100644 --- a/jni/tests/faiss_wrapper_unit_test.cpp +++ b/jni/tests/faiss_wrapper_unit_test.cpp @@ -486,3 +486,60 @@ namespace query_index_with_filter_test_ivf { ) ); } + +// Helper: Simulate float array as jfloatArray for JNI-style tests +static jfloatArray ToJFloatArray(std::vector& data) { + return reinterpret_cast(data.data()); +} + +// Helper: Simulate std::string as jstring for JNI-style tests +static jstring ToJString(std::string& str) { + return reinterpret_cast(&str); +} + +TEST(FaissBuildFlatIndexFromVectorsTest, ThrowsIfVectorsNull) { + NiceMock jniEnv; + NiceMock mockJNIUtil; + std::unique_ptr faissMethods( + new knn_jni::faiss_wrapper::FaissMethods() + ); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + jfloatArray vectorsJ = nullptr; + jint numVectors = 2; + jint dim = 2; + std::string metricType = "L2"; + jstring metricTypeJ = ToJString(metricType); + + EXPECT_THROW( + knn_jni::faiss_wrapper::BuildFlatIndexFromVectors( + &mockJNIUtil, &jniEnv, vectorsJ, numVectors, dim, metricTypeJ, &indexService), + std::runtime_error + ); +} + +TEST(FaissBuildFlatIndexFromVectorsTest, ThrowsIfInvalidDimsOrNumVectors) { + NiceMock jniEnv; + NiceMock mockJNIUtil; + std::unique_ptr faissMethods( + new knn_jni::faiss_wrapper::FaissMethods() + ); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + std::vector data(4, 1.0f); + jfloatArray vectorsJ = ToJFloatArray(data); + std::string metricType = "L2"; + jstring metricTypeJ = ToJString(metricType); + + // Zero dim + EXPECT_THROW( + knn_jni::faiss_wrapper::BuildFlatIndexFromVectors( + &mockJNIUtil, &jniEnv, vectorsJ, 2, 0, metricTypeJ, &indexService), + std::runtime_error + ); + + // Negative numVectors + EXPECT_THROW( + knn_jni::faiss_wrapper::BuildFlatIndexFromVectors( + &mockJNIUtil, &jniEnv, vectorsJ, -1, 2, metricTypeJ, &indexService), + std::runtime_error + ); +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java index 845884139a..7a888f76fd 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java @@ -30,6 +30,8 @@ import org.opensearch.repositories.Repository; import org.opensearch.repositories.RepositoryMissingException; import org.opensearch.repositories.blobstore.BlobStoreRepository; +import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.index.engine.KNNEngine; import java.io.IOException; import java.util.Map; @@ -155,10 +157,13 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException { RemoteIndexClient client = RemoteIndexClientFactory.getRemoteIndexClient(KNNSettings.getRemoteBuildServiceEndpoint()); RemoteBuildResponse remoteBuildResponse = submitBuild(repositoryContext, indexInfo, client); - // 3. Await vector build completion + // 3. Build flat index + buildFlatIndex(indexInfo); // this will return a pointer to send to readFromRepository in complete implementation + + // 4. Await vector build completion RemoteBuildStatusResponse remoteBuildStatusResponse = awaitIndexBuild(remoteBuildResponse, indexInfo, client); - // 4. Download index file and write to indexOutput + // 5. Download index file and write to indexOutput readFromRepository(indexInfo, repositoryContext, remoteBuildStatusResponse); success = true; return; @@ -218,6 +223,50 @@ private RemoteBuildResponse submitBuild(RepositoryContext repositoryContext, Bui } } + private void buildFlatIndex(BuildIndexParams indexInfo) throws IOException { + KNNVectorValues knnVectorValues = indexInfo.getKnnVectorValuesSupplier().get(); + int totalDocs = indexInfo.getTotalLiveDocs(); + Object firstVector = null; + int dimension; + int idx = 0; + float[] vectorData; + + if (knnVectorValues.nextDoc() == org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS) { + throw new IllegalStateException("No vectors to index"); + } + + // First vector, need to access first before getting values needed for loop + firstVector = knnVectorValues.getVector(); + if (firstVector instanceof float[] floatVector) { + dimension = floatVector.length; + vectorData = new float[totalDocs * dimension]; + System.arraycopy(floatVector, 0, vectorData, 0, dimension); + } else { + throw new IllegalArgumentException("Unknown vector type: " + firstVector.getClass()); + } + idx = 1; + + // Rest of the vectors + while (knnVectorValues.nextDoc() != org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS) { + Object vec = knnVectorValues.getVector(); + if (vec instanceof float[] floatVec) { + System.arraycopy(floatVec, 0, vectorData, idx * dimension, dimension); + } else { + throw new IllegalArgumentException("Unknown vector type: " + vec.getClass()); + } + idx++; + } + + String metricType = "L2"; + Object spaceType = indexInfo.getParameters().get("space_type"); + if (spaceType != null && spaceType.toString().toUpperCase().contains("IP")) { + metricType = "IP"; + } + + long indexPtr = JNIService.buildFlatIndexFromVectors(vectorData, idx, dimension, metricType); + JNIService.free(indexPtr, KNNEngine.FAISS); + } + /** * Awaits the vector build to complete * @return RemoteBuildStatusResponse containing the completed status response from the remote service. diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 0376ef19c8..3749f3a46b 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -103,6 +103,17 @@ class FaissService { */ public static native void insertToIndex(int[] ids, long vectorsAddress, int dim, long indexAddress, int threadCount); + /** + * Builds a FAISS IndexFlat index from vectors. + * + * @param vectors vector data + * @param numVectors number of vectors + * @param dimension dimension of the vectors + * @param metricType either "L2" or "IP" + * @return pointer to the native IndexFlat object + */ + public static native long buildFlatIndexFromVectors(float[] vectors, int numVectors, int dimension, String metricType); + /** * Inserts to a faiss index. The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 2681032e80..bf478185e3 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -90,6 +90,19 @@ public static void insertToIndex( ); } + /** + * Builds a flat FAISS index from the given float vectors. + * + * @param vectors Array of float vectors (size: numVectors * dimension) + * @param numVectors Number of vectors + * @param dimension Dimension of each vector + * @param metricType Metric type for the index ("L2", "IP", etc.) + * @return Native memory address of the created index + */ + public static long buildFlatIndexFromVectors(float[] vectors, int numVectors, int dimension, String metricType) { + return FaissService.buildFlatIndexFromVectors(vectors, numVectors, dimension, metricType); + } + /** * Writes a faiss index to disk. *