diff --git a/jni/cmake/init-faiss.cmake b/jni/cmake/init-faiss.cmake index 7e9622a9c8..a58493ea42 100644 --- a/jni/cmake/init-faiss.cmake +++ b/jni/cmake/init-faiss.cmake @@ -22,6 +22,7 @@ if(NOT DEFINED APPLY_LIB_PATCHES OR "${APPLY_LIB_PATCHES}" STREQUAL true) list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0004-Custom-patch-to-support-binary-vector.patch") list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0005-Custom-patch-to-support-multi-vector-IndexHNSW-search_level_0.patch") list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0006-Add-nested-search-support-for-IndexBinaryHNSWCagra.patch") + list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0007-Custom-patch-to-support-dedup-flatvector-faiss-file.patch") # Get patch id of the last commit execute_process(COMMAND sh -c "git --no-pager show HEAD | git patch-id --stable" OUTPUT_VARIABLE PATCH_ID_OUTPUT_FROM_COMMIT WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss) diff --git a/jni/include/faiss_stream_support.h b/jni/include/faiss_stream_support.h index eb1b2a404c..1ef1fcab62 100644 --- a/jni/include/faiss_stream_support.h +++ b/jni/include/faiss_stream_support.h @@ -32,28 +32,101 @@ namespace stream { * This will then indirectly call the mediator component and eventually read required bytes from Lucene's IndexInput. */ class FaissOpenSearchIOReader final : public faiss::IOReader { - public: - explicit FaissOpenSearchIOReader(NativeEngineIndexInputMediator *_mediator) - : faiss::IOReader(), - mediator(knn_jni::util::ParameterCheck::require_non_null(_mediator, "mediator")) { - name = "FaissOpenSearchIOReader"; - } +public: + explicit FaissOpenSearchIOReader(NativeEngineIndexInputMediator* _mediator) + : faiss::IOReader(), + mediator(knn_jni::util::ParameterCheck::require_non_null(_mediator, "mediator")) { - size_t operator()(void *ptr, size_t size, size_t nitems) final { - const auto readBytes = size * nitems; - if (readBytes > 0) { - // Mediator calls IndexInput, then copy read bytes to `ptr`. - mediator->copyBytes(readBytes, (uint8_t *) ptr); + name = "FaissOpenSearchIOReader"; } - return nitems; - } - int filedescriptor() final { - throw std::runtime_error("filedescriptor() is not supported in FaissOpenSearchIOReader."); - } + ~FaissOpenSearchIOReader() override { + JNIEnv* env = mediator->getEnv(); + if (vectorReaderGlobalRef && env) { + env->DeleteGlobalRef(vectorReaderGlobalRef); + } + } - private: - NativeEngineIndexInputMediator *mediator; + size_t operator()(void* ptr, size_t size, size_t nitems) override { + const auto bytes = size * nitems; + mediator->copyBytes(bytes, static_cast(ptr)); + return nitems; + } + + bool copy(void* dest, int expectedByteSize, bool isFloat) override { + JNIEnv* env = mediator->getEnv(); + if (env == nullptr) return false; + + jobject readStream = mediator->getJavaObject(); + if (!vectorReaderGlobalRef) { + jclass streamClass = env->GetObjectClass(readStream); + if (env->ExceptionCheck() || streamClass == nullptr) return false; + + jmethodID getVectorsMid = env->GetMethodID( + streamClass, + "getFullPrecisionVectors", + "()Lorg/opensearch/knn/index/store/VectorReader;" + ); + if (env->ExceptionCheck() || !getVectorsMid) return false; + + jobject vectorReader = env->CallObjectMethod(readStream, getVectorsMid); + if (env->ExceptionCheck() || vectorReader == nullptr) return false; + + vectorReaderGlobalRef = env->NewGlobalRef(vectorReader); + if (env->ExceptionCheck() || vectorReaderGlobalRef == nullptr) return false; + } + + if (isFloat) { + jclass vectorReaderClass = env->GetObjectClass(vectorReaderGlobalRef); + jmethodID nextFloatMid = env->GetMethodID(vectorReaderClass, "nextFloatVector", "()[F"); + if (env->ExceptionCheck() || !nextFloatMid) return false; + + jfloatArray vector = (jfloatArray) env->CallObjectMethod(vectorReaderGlobalRef, nextFloatMid); + if (env->ExceptionCheck() || vector == nullptr) return false; + + jsize length = env->GetArrayLength(vector); + jfloat* elems = env->GetFloatArrayElements(vector, nullptr); + + JNIReleaseElements release_elems([=]() { + env->ReleaseFloatArrayElements(vector, elems, JNI_ABORT); + }); + + int vectorByteSize = sizeof(float) * length; + if (vectorByteSize != expectedByteSize) return false; + + std::memcpy(dest, elems, vectorByteSize); + return true; + + } else { + jclass vectorReaderClass = env->GetObjectClass(vectorReaderGlobalRef); + jmethodID nextByteMid = env->GetMethodID(vectorReaderClass, "nextByteVector", "()[B"); + if (env->ExceptionCheck() || !nextByteMid) return false; + + jbyteArray vector = (jbyteArray) env->CallObjectMethod(vectorReaderGlobalRef, nextByteMid); + if (env->ExceptionCheck() || vector == nullptr) return false; + + jsize length = env->GetArrayLength(vector); + jbyte* elems = env->GetByteArrayElements(vector, nullptr); + + JNIReleaseElements release_elems([=]() { + env->ReleaseByteArrayElements(vector, elems, JNI_ABORT); + }); + + int vectorByteSize = sizeof(float) * length; + if (vectorByteSize != expectedByteSize) return false; + + float* floatDest = static_cast(dest); + for (int i = 0; i < length; ++i) { + floatDest[i] = static_cast(elems[i]); + } + + return true; + } + } + +private: + NativeEngineIndexInputMediator* mediator; + jobject vectorReaderGlobalRef = nullptr; }; // class FaissOpenSearchIOReader diff --git a/jni/include/native_engines_stream_support.h b/jni/include/native_engines_stream_support.h index 07f97f3ac9..095b8a9d0c 100644 --- a/jni/include/native_engines_stream_support.h +++ b/jni/include/native_engines_stream_support.h @@ -44,6 +44,18 @@ class NativeEngineIndexInputMediator { remainingBytesMethod(getRemainingBytesMethod(_jni_interface, _env)) { } + JNIEnv* getEnv() const { + return env; + } + + jobject getJavaObject() const { + return indexInput; + } + + JNIUtilInterface* getJNIUtil() const { + return jni_interface; + } + void copyBytes(int64_t nbytes, uint8_t * RESTRICT destination) { auto jclazz = getIndexInputWithBufferClass(jni_interface, env); diff --git a/jni/patches/faiss/0007-Custom-patch-to-support-dedup-flatvector-faiss-file.patch b/jni/patches/faiss/0007-Custom-patch-to-support-dedup-flatvector-faiss-file.patch new file mode 100644 index 0000000000..e653635f63 --- /dev/null +++ b/jni/patches/faiss/0007-Custom-patch-to-support-dedup-flatvector-faiss-file.patch @@ -0,0 +1,890 @@ +From e0b1fb7696cde79fb108ce2ca20cd32935eab50e Mon Sep 17 00:00:00 2001 +From: sobhu17 +Date: Fri, 8 Aug 2025 09:29:31 -0700 +Subject: [PATCH] change faiss to support dedup optimization + +During index loading, without this optimization the `.faiss` file +contains the flat vector section. With this patch, the `.faiss` file +will no longer store flat vectors; instead, vectors are loaded from the +`.vec` file into the C++ side using our custom `VectorReader`, which +streams the vectors one by one. +To enable this, vector writing to the `.faiss` file is disabled, and +the loading path is updated to fetch vectors via `VectorReader` at +index load time. + +To enable this: +- In `write_index()`, we skip writing vectors to the `.faiss` file for + FP32 and byte vector types. Other cases follow the existing baseline + write path. +- In `read_index()`, we load vectors via the Java-side `VectorReader` + to support deduplication optimization, ensuring that vectors are read + from the `.vec` file instead of being embedded in `.faiss`. + +This approach reduces `.faiss` file size and enables deduplication while +keeping vector loading flexible and driven from the Java side. + +--- + faiss/impl/index_read.cpp | 123 +++++++++++++++++++---------- + faiss/impl/index_read_utils.h | 7 +- + faiss/impl/index_write.cpp | 141 ++++++++++++++++++---------------- + faiss/impl/io.h | 5 ++ + faiss/index_io.h | 25 +++--- + 5 files changed, 182 insertions(+), 119 deletions(-) + +diff --git a/faiss/impl/index_read.cpp b/faiss/impl/index_read.cpp +index 89f05c757..9024adfca 100644 +--- a/faiss/impl/index_read.cpp ++++ b/faiss/impl/index_read.cpp +@@ -194,18 +194,21 @@ void read_xb_vector(VectorT& target, IOReader* f) { + * Read + **************************************************************/ + +-void read_index_header(Index* idx, IOReader* f) { +- READ1(idx->d); +- READ1(idx->ntotal); +- idx_t dummy; +- READ1(dummy); +- READ1(dummy); +- READ1(idx->is_trained); +- READ1(idx->metric_type); +- if (idx->metric_type > 1) { +- READ1(idx->metric_arg); +- } +- idx->verbose = false; ++bool read_index_header(Index* idx, IOReader* f) { ++ READ1(idx->d); ++ READ1(idx->ntotal); ++ idx_t dummy; ++ READ1(dummy); ++ READ1(dummy); ++ const bool dedup_applied = dummy == 0x7FFFFFFFFFFFFFFFULL; ++ READ1(idx->is_trained); ++ READ1(idx->metric_type); ++ if (idx->metric_type > 1) { ++ READ1(idx->metric_arg); ++ } ++ idx->verbose = false; ++ ++ return dedup_applied; + } + + VectorTransform* read_VectorTransform(IOReader* f) { +@@ -317,7 +320,7 @@ static void read_ArrayInvertedLists_sizes( + } + } + +-InvertedLists* read_InvertedLists(IOReader* f, int io_flags) { ++InvertedLists* read_InvertedLists(IOReader* f, uint64_t io_flags) { + uint32_t h; + READ1(h); + if (h == fourcc("il00")) { +@@ -364,7 +367,7 @@ InvertedLists* read_InvertedLists(IOReader* f, int io_flags) { + } + } + +-void read_InvertedLists(IndexIVF* ivf, IOReader* f, int io_flags) { ++void read_InvertedLists(IndexIVF* ivf, IOReader* f, uint64_t io_flags) { + InvertedLists* ils = read_InvertedLists(f, io_flags); + if (ils) { + FAISS_THROW_IF_NOT(ils->nlist == ivf->nlist); +@@ -427,7 +430,7 @@ static void read_AdditiveQuantizer(AdditiveQuantizer* aq, IOReader* f) { + static void read_ResidualQuantizer( + ResidualQuantizer* rq, + IOReader* f, +- int io_flags) { ++ uint64_t io_flags) { + read_AdditiveQuantizer(rq, f); + READ1(rq->train_type); + READ1(rq->max_beam_size); +@@ -464,7 +467,7 @@ static void read_ProductAdditiveQuantizer( + static void read_ProductResidualQuantizer( + ProductResidualQuantizer* prq, + IOReader* f, +- int io_flags) { ++ uint64_t io_flags) { + read_ProductAdditiveQuantizer(prq, f); + + for (size_t i = 0; i < prq->nsplits; i++) { +@@ -601,11 +604,12 @@ void read_direct_map(DirectMap* dm, IOReader* f) { + void read_ivf_header( + IndexIVF* ivf, + IOReader* f, ++ uint64_t io_flags, + std::vector>* ids) { + read_index_header(ivf, f); + READ1(ivf->nlist); + READ1(ivf->nprobe); +- ivf->quantizer = read_index(f); ++ ivf->quantizer = read_index(f, io_flags); + ivf->own_fields = true; + if (ids) { // used in legacy "Iv" formats + ids->resize(ivf->nlist); +@@ -632,7 +636,7 @@ ArrayInvertedLists* set_array_invlist( + return ail; + } + +-static IndexIVFPQ* read_ivfpq(IOReader* f, uint32_t h, int io_flags) { ++static IndexIVFPQ* read_ivfpq(IOReader* f, uint32_t h, uint64_t io_flags) { + bool legacy = h == fourcc("IvQR") || h == fourcc("IvPQ"); + + IndexIVFPQR* ivfpqr = h == fourcc("IvQR") || h == fourcc("IwQR") +@@ -641,7 +645,7 @@ static IndexIVFPQ* read_ivfpq(IOReader* f, uint32_t h, int io_flags) { + IndexIVFPQ* ivpq = ivfpqr ? ivfpqr : new IndexIVFPQ(); + + std::vector> ids; +- read_ivf_header(ivpq, f, legacy ? &ids : nullptr); ++ read_ivf_header(ivpq, f, io_flags, legacy ? &ids : nullptr); + READ1(ivpq->by_residual); + READ1(ivpq->code_size); + read_ProductQuantizer(&ivpq->pq, f); +@@ -674,7 +678,7 @@ static IndexIVFPQ* read_ivfpq(IOReader* f, uint32_t h, int io_flags) { + + int read_old_fmt_hack = 0; + +-Index* read_index(IOReader* f, int io_flags) { ++Index* read_index(IOReader* f, uint64_t io_flags) { + Index* idx = nullptr; + uint32_t h; + READ1(h); +@@ -691,9 +695,23 @@ Index* read_index(IOReader* f, int io_flags) { + } else { + idxf = new IndexFlat(); + } +- read_index_header(idxf, f); ++ const bool dedup_applied = read_index_header(idxf, f); + idxf->code_size = idxf->d * sizeof(float); +- read_xb_vector(idxf->codes, f); ++ if(dedup_applied){ ++ idxf->codes.resize(idxf->ntotal * idxf->code_size); ++ if (idxf->codes.size() > 0) { ++ float* dataPtr = reinterpret_cast(idxf->codes.data()); ++ size_t dim = idxf->d; ++ for (size_t i = 0; i < idxf->ntotal; i++) { ++ if (!f->copy(dataPtr + i * dim, dim * sizeof(float), true)) { ++ throw std::runtime_error("Failed to load flat vectors via IOReader::copy at index " + std::to_string(i)); ++ } ++ } ++ } ++ }else{ ++ read_xb_vector(idxf->codes, f); ++ } ++ + FAISS_THROW_IF_NOT( + idxf->codes.size() == idxf->ntotal * idxf->code_size); + // leak! +@@ -860,7 +878,7 @@ Index* read_index(IOReader* f, int io_flags) { + } else { + ivaqfs = new IndexIVFProductResidualQuantizerFastScan(); + } +- read_ivf_header(ivaqfs, f); ++ read_ivf_header(ivaqfs, f, io_flags); + + if (is_LSQ) { + read_LocalSearchQuantizer((LocalSearchQuantizer*)ivaqfs->aq, f); +@@ -896,7 +914,7 @@ Index* read_index(IOReader* f, int io_flags) { + } else if (h == fourcc("IvFl") || h == fourcc("IvFL")) { // legacy + IndexIVFFlat* ivfl = new IndexIVFFlat(); + std::vector> ids; +- read_ivf_header(ivfl, f, &ids); ++ read_ivf_header(ivfl, f, io_flags, &ids); + ivfl->code_size = ivfl->d * sizeof(float); + ArrayInvertedLists* ail = set_array_invlist(ivfl, ids); + +@@ -915,7 +933,7 @@ Index* read_index(IOReader* f, int io_flags) { + idx = ivfl; + } else if (h == fourcc("IwFd")) { + IndexIVFFlatDedup* ivfl = new IndexIVFFlatDedup(); +- read_ivf_header(ivfl, f); ++ read_ivf_header(ivfl, f, io_flags); + ivfl->code_size = ivfl->d * sizeof(float); + { + std::vector tab; +@@ -929,15 +947,38 @@ Index* read_index(IOReader* f, int io_flags) { + idx = ivfl; + } else if (h == fourcc("IwFl")) { + IndexIVFFlat* ivfl = new IndexIVFFlat(); +- read_ivf_header(ivfl, f); ++ read_ivf_header(ivfl, f, io_flags); + ivfl->code_size = ivfl->d * sizeof(float); + read_InvertedLists(ivfl, f, io_flags); + idx = ivfl; + } else if (h == fourcc("IxSQ")) { + IndexScalarQuantizer* idxs = new IndexScalarQuantizer(); +- read_index_header(idxs, f); ++ const bool dedup_applied = read_index_header(idxs, f); + read_ScalarQuantizer(&idxs->sq, f); +- read_vector(idxs->codes, f); ++ if((dedup_applied) && (idxs->sq.qtype != 4)){ ++ idxs->code_size = idxs->sq.code_size; ++ idxs->codes.resize(idxs->ntotal * idxs->code_size); ++ size_t dim = idxs->d; ++ std::vector floatBuffer(dim); ++ ++ for (size_t i = 0; i < idxs->ntotal; ++i) { ++ // Convert byte[] from Java → float[] into floatBuffer ++ if (!f->copy(floatBuffer.data(), sizeof(float) * dim, false)) { ++ throw std::runtime_error("Failed to load byte vector at index " + std::to_string(i)); ++ } ++ ++ // Quantize with FAISS compute_codes() ++ idxs->sq.compute_codes( ++ floatBuffer.data(), ++ idxs->codes.data() + i * idxs->code_size, ++ 1 ++ ); ++ } ++ FAISS_THROW_IF_NOT(idxs->codes.size() == idxs->ntotal * idxs->code_size); ++ }else{ ++ read_vector(idxs->codes, f); ++ } ++ + idxs->code_size = idxs->sq.code_size; + idx = idxs; + } else if (h == fourcc("IxLa")) { +@@ -953,7 +994,7 @@ Index* read_index(IOReader* f, int io_flags) { + } else if (h == fourcc("IvSQ")) { // legacy + IndexIVFScalarQuantizer* ivsc = new IndexIVFScalarQuantizer(); + std::vector> ids; +- read_ivf_header(ivsc, f, &ids); ++ read_ivf_header(ivsc, f, io_flags, &ids); + read_ScalarQuantizer(&ivsc->sq, f); + READ1(ivsc->code_size); + ArrayInvertedLists* ail = set_array_invlist(ivsc, ids); +@@ -962,7 +1003,7 @@ Index* read_index(IOReader* f, int io_flags) { + idx = ivsc; + } else if (h == fourcc("IwSQ") || h == fourcc("IwSq")) { + IndexIVFScalarQuantizer* ivsc = new IndexIVFScalarQuantizer(); +- read_ivf_header(ivsc, f); ++ read_ivf_header(ivsc, f, io_flags); + read_ScalarQuantizer(&ivsc->sq, f); + READ1(ivsc->code_size); + if (h == fourcc("IwSQ")) { +@@ -988,7 +1029,7 @@ Index* read_index(IOReader* f, int io_flags) { + } else { + iva = new IndexIVFProductResidualQuantizer(); + } +- read_ivf_header(iva, f); ++ read_ivf_header(iva, f, io_flags); + READ1(iva->code_size); + if (is_LSQ) { + read_LocalSearchQuantizer((LocalSearchQuantizer*)iva->aq, f); +@@ -1007,7 +1048,7 @@ Index* read_index(IOReader* f, int io_flags) { + idx = iva; + } else if (h == fourcc("IwSh")) { + IndexIVFSpectralHash* ivsp = new IndexIVFSpectralHash(); +- read_ivf_header(ivsp, f); ++ read_ivf_header(ivsp, f, io_flags); + ivsp->vt = read_VectorTransform(f); + ivsp->own_fields = true; + READ1(ivsp->nbit); +@@ -1189,7 +1230,7 @@ Index* read_index(IOReader* f, int io_flags) { + + } else if (h == fourcc("IwPf")) { + IndexIVFPQFastScan* ivpq = new IndexIVFPQFastScan(); +- read_ivf_header(ivpq, f); ++ read_ivf_header(ivpq, f, io_flags); + READ1(ivpq->by_residual); + READ1(ivpq->code_size); + READ1(ivpq->bbs); +@@ -1235,7 +1276,7 @@ Index* read_index(IOReader* f, int io_flags) { + idx = idxq; + } else if (h == fourcc("Iwrq")) { + IndexIVFRaBitQ* ivrq = new IndexIVFRaBitQ(); +- read_ivf_header(ivrq, f); ++ read_ivf_header(ivrq, f, io_flags); + read_RaBitQuantizer(&ivrq->rabitq, f); + READ1(ivrq->code_size); + READ1(ivrq->by_residual); +@@ -1252,7 +1293,7 @@ Index* read_index(IOReader* f, int io_flags) { + return idx; + } + +-Index* read_index(FILE* f, int io_flags) { ++Index* read_index(FILE* f, uint64_t io_flags) { + if ((io_flags & IO_FLAG_MMAP_IFC) == IO_FLAG_MMAP_IFC) { + // enable mmap-supporting IOReader + auto owner = std::make_shared(f); +@@ -1264,7 +1305,7 @@ Index* read_index(FILE* f, int io_flags) { + } + } + +-Index* read_index(const char* fname, int io_flags) { ++Index* read_index(const char* fname, uint64_t io_flags) { + if ((io_flags & IO_FLAG_MMAP_IFC) == IO_FLAG_MMAP_IFC) { + // enable mmap-supporting IOReader + auto owner = std::make_shared(fname); +@@ -1287,7 +1328,7 @@ VectorTransform* read_VectorTransform(const char* fname) { + * Read binary indexes + **************************************************************/ + +-static void read_InvertedLists(IndexBinaryIVF* ivf, IOReader* f, int io_flags) { ++static void read_InvertedLists(IndexBinaryIVF* ivf, IOReader* f, uint64_t io_flags) { + InvertedLists* ils = read_InvertedLists(f, io_flags); + FAISS_THROW_IF_NOT( + !ils || +@@ -1370,7 +1411,7 @@ static void read_binary_multi_hash_map( + } + } + +-IndexBinary* read_index_binary(IOReader* f, int io_flags) { ++IndexBinary* read_index_binary(IOReader* f, uint64_t io_flags) { + IndexBinary* idx = nullptr; + uint32_t h; + READ1(h); +@@ -1454,7 +1495,7 @@ IndexBinary* read_index_binary(IOReader* f, int io_flags) { + return idx; + } + +-IndexBinary* read_index_binary(FILE* f, int io_flags) { ++IndexBinary* read_index_binary(FILE* f, uint64_t io_flags) { + if ((io_flags & IO_FLAG_MMAP_IFC) == IO_FLAG_MMAP_IFC) { + // enable mmap-supporting IOReader + auto owner = std::make_shared(f); +@@ -1466,7 +1507,7 @@ IndexBinary* read_index_binary(FILE* f, int io_flags) { + } + } + +-IndexBinary* read_index_binary(const char* fname, int io_flags) { ++IndexBinary* read_index_binary(const char* fname, uint64_t io_flags) { + if ((io_flags & IO_FLAG_MMAP_IFC) == IO_FLAG_MMAP_IFC) { + // enable mmap-supporting IOReader + auto owner = std::make_shared(fname); +@@ -1479,4 +1520,4 @@ IndexBinary* read_index_binary(const char* fname, int io_flags) { + } + } + +-} // namespace faiss ++} // namespace faiss +\ No newline at end of file +diff --git a/faiss/impl/index_read_utils.h b/faiss/impl/index_read_utils.h +index 543f48126..bef324e4e 100644 +--- a/faiss/impl/index_read_utils.h ++++ b/faiss/impl/index_read_utils.h +@@ -19,13 +19,14 @@ namespace faiss { + struct ProductQuantizer; + struct ScalarQuantizer; + +-void read_index_header(Index* idx, IOReader* f); ++bool read_index_header(Index* idx, IOReader* f); + void read_direct_map(DirectMap* dm, IOReader* f); + void read_ivf_header( + IndexIVF* ivf, + IOReader* f, ++ uint64_t io_flags, + std::vector>* ids = nullptr); +-void read_InvertedLists(IndexIVF* ivf, IOReader* f, int io_flags); ++void read_InvertedLists(IndexIVF* ivf, IOReader* f, uint64_t io_flags); + ArrayInvertedLists* set_array_invlist( + IndexIVF* ivf, + std::vector>& ids); +@@ -34,4 +35,4 @@ void read_ScalarQuantizer(ScalarQuantizer* ivsc, IOReader* f); + + } // namespace faiss + +-#endif ++#endif +\ No newline at end of file +diff --git a/faiss/impl/index_write.cpp b/faiss/impl/index_write.cpp +index def7dcd2b..d321693e1 100644 +--- a/faiss/impl/index_write.cpp ++++ b/faiss/impl/index_write.cpp +@@ -77,17 +77,17 @@ namespace faiss { + /************************************************************* + * Write + **************************************************************/ +-static void write_index_header(const Index* idx, IOWriter* f) { +- WRITE1(idx->d); +- WRITE1(idx->ntotal); +- idx_t dummy = 1 << 20; +- WRITE1(dummy); +- WRITE1(dummy); +- WRITE1(idx->is_trained); +- WRITE1(idx->metric_type); +- if (idx->metric_type > 1) { +- WRITE1(idx->metric_arg); +- } ++static void write_index_header(const Index* idx, IOWriter* f, const bool dedup_vector_enabled=false) { ++ WRITE1(idx->d); ++ WRITE1(idx->ntotal); ++ idx_t dummy = dedup_vector_enabled ? 0x7FFFFFFFFFFFFFFFULL : 1 << 20; ++ WRITE1(dummy); ++ WRITE1(dummy); ++ WRITE1(idx->is_trained); ++ WRITE1(idx->metric_type); ++ if (idx->metric_type > 1) { ++ WRITE1(idx->metric_arg); ++ } + } + + void write_VectorTransform(const VectorTransform* vt, IOWriter* f) { +@@ -387,17 +387,18 @@ static void write_direct_map(const DirectMap* dm, IOWriter* f) { + } + } + +-static void write_ivf_header(const IndexIVF* ivf, IOWriter* f) { +- write_index_header(ivf, f); ++static void write_ivf_header(const IndexIVF* ivf, IOWriter* f, uint64_t io_flags) { ++ write_index_header(ivf, f, !(io_flags & DEDUPE_VECTORS_OPT_DISABLED)); + WRITE1(ivf->nlist); + WRITE1(ivf->nprobe); + // subclasses write by_residual (some of them support only one setting of + // by_residual). +- write_index(ivf->quantizer, f); ++ write_index(ivf->quantizer, f, io_flags); + write_direct_map(&ivf->direct_map, f); + } + +-void write_index(const Index* idx, IOWriter* f, int io_flags) { ++void write_index(const Index* idx, IOWriter* f, uint64_t io_flags) { ++ const bool dedup_vector_enabled = !(io_flags & DEDUPE_VECTORS_OPT_DISABLED); + if (idx == nullptr) { + // eg. for a storage component of HNSW that is set to nullptr + uint32_t h = fourcc("null"); +@@ -408,12 +409,15 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + : idxf->metric_type == METRIC_L2 ? "IxF2" + : "IxFl"); + WRITE1(h); +- write_index_header(idx, f); +- WRITEXBVECTOR(idxf->codes); ++ write_index_header(idx, f, dedup_vector_enabled); ++ ++ if(!dedup_vector_enabled){ ++ WRITEXBVECTOR(idxf->codes); ++ } + } else if (const IndexLSH* idxl = dynamic_cast(idx)) { + uint32_t h = fourcc("IxHe"); + WRITE1(h); +- write_index_header(idx, f); ++ write_index_header(idx, f, dedup_vector_enabled); + WRITE1(idxl->nbits); + WRITE1(idxl->rotate_data); + WRITE1(idxl->train_thresholds); +@@ -425,7 +429,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + } else if (const IndexPQ* idxp = dynamic_cast(idx)) { + uint32_t h = fourcc("IxPq"); + WRITE1(h); +- write_index_header(idx, f); ++ write_index_header(idx, f, dedup_vector_enabled); + write_ProductQuantizer(&idxp->pq, f); + WRITEVECTOR(idxp->codes); + // search params -- maybe not useful to store? +@@ -437,7 +441,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + dynamic_cast(idx)) { + uint32_t h = fourcc("IxRq"); + WRITE1(h); +- write_index_header(idx, f); ++ write_index_header(idx, f, dedup_vector_enabled); + write_ResidualQuantizer(&idxr->rq, f); + WRITE1(idxr->code_size); + WRITEVECTOR(idxr->codes); +@@ -446,7 +450,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + dynamic_cast(idx)) { + uint32_t h = fourcc("IxLS"); + WRITE1(h); +- write_index_header(idx, f); ++ write_index_header(idx, f, dedup_vector_enabled); + write_LocalSearchQuantizer(&idxr_2->lsq, f); + WRITE1(idxr_2->code_size); + WRITEVECTOR(idxr_2->codes); +@@ -465,7 +469,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + idx)) { + uint32_t h = fourcc("IxPL"); + WRITE1(h); +- write_index_header(idx, f); ++ write_index_header(idx, f, dedup_vector_enabled); + write_ProductLocalSearchQuantizer(&idxpl->plsq, f); + WRITE1(idxpl->code_size); + WRITEVECTOR(idxpl->codes); +@@ -496,7 +500,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + WRITE1(h); + } + +- write_index_header(idxaqfs, f); ++ write_index_header(idxaqfs, f, dedup_vector_enabled); + + if (idxlsqfs) { + write_LocalSearchQuantizer(&idxlsqfs->lsq, f); +@@ -552,7 +556,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + WRITE1(h); + } + +- write_ivf_header(ivaqfs, f); ++ write_ivf_header(ivaqfs, f, 1ull<<63); + + if (ivlsqfs) { + write_LocalSearchQuantizer(&ivlsqfs->lsq, f); +@@ -586,15 +590,15 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + dynamic_cast(idx)) { + uint32_t h = fourcc("ImRQ"); + WRITE1(h); +- write_index_header(idx, f); ++ write_index_header(idx, f, dedup_vector_enabled); + write_ResidualQuantizer(&idxr_2->rq, f); + WRITE1(idxr_2->beam_factor); + } else if ( + const Index2Layer* idxp_2 = dynamic_cast(idx)) { + uint32_t h = fourcc("Ix2L"); + WRITE1(h); +- write_index_header(idx, f); +- write_index(idxp_2->q1.quantizer, f); ++ write_index_header(idx, f, dedup_vector_enabled); ++ write_index(idxp_2->q1.quantizer, f, io_flags); + WRITE1(idxp_2->q1.nlist); + WRITE1(idxp_2->q1.quantizer_trains_alone); + write_ProductQuantizer(&idxp_2->pq, f); +@@ -607,9 +611,16 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + dynamic_cast(idx)) { + uint32_t h = fourcc("IxSQ"); + WRITE1(h); +- write_index_header(idx, f); ++ if ((idxs->sq.qtype == 4)) { ++ write_index_header(idx, f, false); ++ } ++ else{ ++ write_index_header(idx, f, dedup_vector_enabled); ++ } + write_ScalarQuantizer(&idxs->sq, f); +- WRITEVECTOR(idxs->codes); ++ if((!dedup_vector_enabled) || (idxs->sq.qtype == 4)){ ++ WRITEVECTOR(idxs->codes); ++ } + } else if ( + const IndexLattice* idxl_2 = + dynamic_cast(idx)) { +@@ -619,14 +630,14 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + WRITE1(idxl_2->nsq); + WRITE1(idxl_2->scale_nbit); + WRITE1(idxl_2->zn_sphere_codec.r2); +- write_index_header(idx, f); ++ write_index_header(idx, f, dedup_vector_enabled); + WRITEVECTOR(idxl_2->trained); + } else if ( + const IndexIVFFlatDedup* ivfl = + dynamic_cast(idx)) { + uint32_t h = fourcc("IwFd"); + WRITE1(h); +- write_ivf_header(ivfl, f); ++ write_ivf_header(ivfl, f, 1ull<<63); + { + std::vector tab(2 * ivfl->instances.size()); + long i = 0; +@@ -643,14 +654,14 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + dynamic_cast(idx)) { + uint32_t h = fourcc("IwFl"); + WRITE1(h); +- write_ivf_header(ivfl_2, f); ++ write_ivf_header(ivfl_2, f, 1ull<<63); + write_InvertedLists(ivfl_2->invlists, f); + } else if ( + const IndexIVFScalarQuantizer* ivsc = + dynamic_cast(idx)) { + uint32_t h = fourcc("IwSq"); + WRITE1(h); +- write_ivf_header(ivsc, f); ++ write_ivf_header(ivsc, f, 1ull<<63); + write_ScalarQuantizer(&ivsc->sq, f); + WRITE1(ivsc->code_size); + WRITE1(ivsc->by_residual); +@@ -672,7 +683,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + } + + WRITE1(h); +- write_ivf_header(iva, f); ++ write_ivf_header(iva, f, 1ull<<63); + WRITE1(iva->code_size); + if (is_LSQ) { + write_LocalSearchQuantizer((LocalSearchQuantizer*)iva->aq, f); +@@ -693,7 +704,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + dynamic_cast(idx)) { + uint32_t h = fourcc("IwSh"); + WRITE1(h); +- write_ivf_header(ivsp, f); ++ write_ivf_header(ivsp, f, 1ull<<63); + write_VectorTransform(ivsp->vt, f); + WRITE1(ivsp->nbit); + WRITE1(ivsp->period); +@@ -705,7 +716,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + + uint32_t h = fourcc(ivfpqr ? "IwQR" : "IwPQ"); + WRITE1(h); +- write_ivf_header(ivpq, f); ++ write_ivf_header(ivpq, f, 1ull<<63); + WRITE1(ivpq->by_residual); + WRITE1(ivpq->code_size); + write_ProductQuantizer(&ivpq->pq, f); +@@ -720,14 +731,14 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + dynamic_cast(idx)) { + uint32_t h = fourcc("IwIQ"); + WRITE1(h); +- write_index_header(indep, f); +- write_index(indep->quantizer, f); ++ write_index_header(indep, f, dedup_vector_enabled); ++ write_index(indep->quantizer, f, io_flags); + bool has_vt = indep->vt != nullptr; + WRITE1(has_vt); + if (has_vt) { + write_VectorTransform(indep->vt, f); + } +- write_index(indep->index_ivf, f); ++ write_index(indep->index_ivf, f, io_flags); + if (auto index_ivfpq = dynamic_cast(indep->index_ivf)) { + WRITE1(index_ivfpq->use_precomputed_table); + } +@@ -736,27 +747,27 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + dynamic_cast(idx)) { + uint32_t h = fourcc("IxPT"); + WRITE1(h); +- write_index_header(ixpt, f); ++ write_index_header(ixpt, f, dedup_vector_enabled); + int nt = ixpt->chain.size(); + WRITE1(nt); + for (int i = 0; i < nt; i++) { + write_VectorTransform(ixpt->chain[i], f); + } +- write_index(ixpt->index, f); ++ write_index(ixpt->index, f, io_flags); + } else if ( + const MultiIndexQuantizer* imiq = + dynamic_cast(idx)) { + uint32_t h = fourcc("Imiq"); + WRITE1(h); +- write_index_header(imiq, f); ++ write_index_header(imiq, f, dedup_vector_enabled); + write_ProductQuantizer(&imiq->pq, f); + } else if ( + const IndexRefine* idxrf = dynamic_cast(idx)) { + uint32_t h = fourcc("IxRF"); + WRITE1(h); +- write_index_header(idxrf, f); +- write_index(idxrf->base_index, f); +- write_index(idxrf->refine_index, f); ++ write_index_header(idxrf, f, dedup_vector_enabled); ++ write_index(idxrf->base_index, f, io_flags); ++ write_index(idxrf->refine_index, f, io_flags); + WRITE1(idxrf->k_factor); + } else if ( + const IndexIDMap* idxmap = dynamic_cast(idx)) { +@@ -764,8 +775,8 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + : fourcc("IxMp"); + // no need to store additional info for IndexIDMap2 + WRITE1(h); +- write_index_header(idxmap, f); +- write_index(idxmap->index, f); ++ write_index_header(idxmap, f, dedup_vector_enabled); ++ write_index(idxmap->index, f, io_flags); + WRITEVECTOR(idxmap->id_map); + } else if (const IndexHNSW* idxhnsw = dynamic_cast(idx)) { + uint32_t h = dynamic_cast(idx) ? fourcc("IHNf") +@@ -776,7 +787,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + : 0; + FAISS_THROW_IF_NOT(h != 0); + WRITE1(h); +- write_index_header(idxhnsw, f); ++ write_index_header(idxhnsw, f, dedup_vector_enabled); + if (h == fourcc("IHc2")) { + WRITE1(idxhnsw->keep_max_size_level0); + auto idx_hnsw_cagra = dynamic_cast(idxhnsw); +@@ -789,7 +800,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + uint32_t n4 = fourcc("null"); + WRITE1(n4); + } else { +- write_index(idxhnsw->storage, f); ++ write_index(idxhnsw->storage, f, io_flags); + } + } else if (const IndexNSG* idxnsg = dynamic_cast(idx)) { + uint32_t h = dynamic_cast(idx) ? fourcc("INSf") +@@ -798,7 +809,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + : 0; + FAISS_THROW_IF_NOT(h != 0); + WRITE1(h); +- write_index_header(idxnsg, f); ++ write_index_header(idxnsg, f, dedup_vector_enabled); + WRITE1(idxnsg->GK); + WRITE1(idxnsg->build_type); + WRITE1(idxnsg->nndescent_S); +@@ -806,7 +817,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + WRITE1(idxnsg->nndescent_L); + WRITE1(idxnsg->nndescent_iter); + write_NSG(&idxnsg->nsg, f); +- write_index(idxnsg->storage, f); ++ write_index(idxnsg->storage, f, io_flags); + } else if ( + const IndexNNDescent* idxnnd = + dynamic_cast(idx)) { +@@ -815,15 +826,15 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + uint32_t h = fourcc("INNf"); + FAISS_THROW_IF_NOT(h != 0); + WRITE1(h); +- write_index_header(idxnnd, f); ++ write_index_header(idxnnd, f, dedup_vector_enabled); + write_NNDescent(&idxnnd->nndescent, f); +- write_index(idxnnd->storage, f); ++ write_index(idxnnd->storage, f, io_flags); + } else if ( + const IndexPQFastScan* idxpqfs = + dynamic_cast(idx)) { + uint32_t h = fourcc("IPfs"); + WRITE1(h); +- write_index_header(idxpqfs, f); ++ write_index_header(idxpqfs, f, dedup_vector_enabled); + write_ProductQuantizer(&idxpqfs->pq, f); + WRITE1(idxpqfs->implem); + WRITE1(idxpqfs->bbs); +@@ -836,7 +847,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + dynamic_cast(idx)) { + uint32_t h = fourcc("IwPf"); + WRITE1(h); +- write_ivf_header(ivpq_2, f); ++ write_ivf_header(ivpq_2, f, 1ull<<63); + WRITE1(ivpq_2->by_residual); + WRITE1(ivpq_2->code_size); + WRITE1(ivpq_2->bbs); +@@ -851,21 +862,21 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + // IndexRowwiseMinmaxFloat + uint32_t h = fourcc("IRMf"); + WRITE1(h); +- write_index_header(imm, f); +- write_index(imm->index, f); ++ write_index_header(imm, f, dedup_vector_enabled); ++ write_index(imm->index, f, io_flags); + } else if ( + const IndexRowwiseMinMaxFP16* imm_2 = + dynamic_cast(idx)) { + // IndexRowwiseMinmaxHalf + uint32_t h = fourcc("IRMh"); + WRITE1(h); +- write_index_header(imm_2, f); +- write_index(imm_2->index, f); ++ write_index_header(imm_2, f, dedup_vector_enabled); ++ write_index(imm_2->index, f, io_flags); + } else if ( + const IndexRaBitQ* idxq = dynamic_cast(idx)) { + uint32_t h = fourcc("Ixrq"); + WRITE1(h); +- write_index_header(idx, f); ++ write_index_header(idx, f, dedup_vector_enabled); + write_RaBitQuantizer(&idxq->rabitq, f); + WRITEVECTOR(idxq->codes); + WRITEVECTOR(idxq->center); +@@ -875,7 +886,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + dynamic_cast(idx)) { + uint32_t h = fourcc("Iwrq"); + WRITE1(h); +- write_ivf_header(ivrq, f); ++ write_ivf_header(ivrq, f, 1ull<<63); + write_RaBitQuantizer(&ivrq->rabitq, f); + WRITE1(ivrq->code_size); + WRITE1(ivrq->by_residual); +@@ -886,12 +897,12 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) { + } + } + +-void write_index(const Index* idx, FILE* f, int io_flags) { ++void write_index(const Index* idx, FILE* f, uint64_t io_flags) { + FileIOWriter writer(f); + write_index(idx, &writer, io_flags); + } + +-void write_index(const Index* idx, const char* fname, int io_flags) { ++void write_index(const Index* idx, const char* fname, uint64_t io_flags) { + FileIOWriter writer(fname); + write_index(idx, &writer, io_flags); + } +@@ -1002,7 +1013,7 @@ void write_index_binary(const IndexBinary* idx, IOWriter* f) { + uint32_t h = fourcc("IBFf"); + WRITE1(h); + write_index_binary_header(idxff, f); +- write_index(idxff->index, f); ++ write_index(idxff->index, f, 1ull<<63); + } else if ( + const IndexBinaryHNSW* idxhnsw = + dynamic_cast(idx)) { +@@ -1071,4 +1082,4 @@ void write_index_binary(const IndexBinary* idx, const char* fname) { + write_index_binary(idx, &writer); + } + +-} // namespace faiss ++} // namespace faiss +\ No newline at end of file +diff --git a/faiss/impl/io.h b/faiss/impl/io.h +index ebd640fef..e80738f2c 100644 +--- a/faiss/impl/io.h ++++ b/faiss/impl/io.h +@@ -34,6 +34,11 @@ struct IOReader { + // return a file number that can be memory-mapped + virtual int filedescriptor(); + ++ // New method to stream full-precision vectors ++ virtual bool copy(void* dest, int expectedByteSize, bool isFloat) { ++ return false; // default fallback ++ } ++ + virtual ~IOReader() {} + }; + +diff --git a/faiss/index_io.h b/faiss/index_io.h +index 79c013dff..793e915cc 100644 +--- a/faiss/index_io.h ++++ b/faiss/index_io.h +@@ -1,3 +1,4 @@ ++ + /* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * +@@ -11,6 +12,9 @@ + #define FAISS_INDEX_IO_H + + #include ++#include ++#include ++#include + + /** I/O functions can read/write to a filename, a file handle or to an + * object that abstracts the medium. +@@ -32,10 +36,11 @@ struct InvertedLists; + + /// skip the storage for graph-based indexes + const int IO_FLAG_SKIP_STORAGE = 1; ++const uint64_t DEDUPE_VECTORS_OPT_DISABLED = (1ull << 63); + +-void write_index(const Index* idx, const char* fname, int io_flags = 0); +-void write_index(const Index* idx, FILE* f, int io_flags = 0); +-void write_index(const Index* idx, IOWriter* writer, int io_flags = 0); ++void write_index(const Index* idx, const char* fname, uint64_t io_flags = 0); ++void write_index(const Index* idx, FILE* f, uint64_t io_flags = 0); ++void write_index(const Index* idx, IOWriter* writer, uint64_t io_flags = 0); + + void write_index_binary(const IndexBinary* idx, const char* fname); + void write_index_binary(const IndexBinary* idx, FILE* f); +@@ -64,13 +69,13 @@ const int IO_FLAG_MMAP = IO_FLAG_SKIP_IVF_DATA | 0x646f0000; + // after OnDiskInvertedLists get properly updated. + const int IO_FLAG_MMAP_IFC = 1 << 9; + +-Index* read_index(const char* fname, int io_flags = 0); +-Index* read_index(FILE* f, int io_flags = 0); +-Index* read_index(IOReader* reader, int io_flags = 0); ++Index* read_index(const char* fname, uint64_t io_flags = 0); ++Index* read_index(FILE* f, uint64_t io_flags = 0); ++Index* read_index(IOReader* reader, uint64_t io_flags = 0); + +-IndexBinary* read_index_binary(const char* fname, int io_flags = 0); +-IndexBinary* read_index_binary(FILE* f, int io_flags = 0); +-IndexBinary* read_index_binary(IOReader* reader, int io_flags = 0); ++IndexBinary* read_index_binary(const char* fname, uint64_t io_flags = 0); ++IndexBinary* read_index_binary(FILE* f, uint64_t io_flags = 0); ++IndexBinary* read_index_binary(IOReader* reader, uint64_t io_flags = 0); + + void write_VectorTransform(const VectorTransform* vt, const char* fname); + void write_VectorTransform(const VectorTransform* vt, IOWriter* f); +@@ -85,7 +90,7 @@ void write_ProductQuantizer(const ProductQuantizer* pq, const char* fname); + void write_ProductQuantizer(const ProductQuantizer* pq, IOWriter* f); + + void write_InvertedLists(const InvertedLists* ils, IOWriter* f); +-InvertedLists* read_InvertedLists(IOReader* reader, int io_flags = 0); ++InvertedLists* read_InvertedLists(IOReader* reader, uint64_t io_flags = 0); + + } // namespace faiss + +-- +2.39.5 (Apple Git-154) + diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index d906e47266..2fe7a48fba 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -252,7 +252,7 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * // Create faiss index std::unique_ptr indexWriter; - indexWriter.reset(faiss::read_index(&vectorIoReader, 0)); + indexWriter.reset(faiss::read_index(&vectorIoReader, 1ull<<63)); auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get()); @@ -265,7 +265,7 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * // Write the index to disk knn_jni::stream::NativeEngineIndexOutputMediator mediator {jniUtil, env, output}; knn_jni::stream::FaissOpenSearchIOWriter writer {&mediator}; - faiss::write_index(&idMap, &writer); + faiss::write_index(&idMap, &writer, 1ull<<63); mediator.flush(); } @@ -458,7 +458,8 @@ jlong knn_jni::faiss_wrapper::LoadIndexWithStream(faiss::IOReader* ioReader) { faiss::read_index(ioReader, faiss::IO_FLAG_READ_ONLY | faiss::IO_FLAG_PQ_SKIP_SDC_TABLE - | faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE); + | faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE + | 1ull<<63); return (jlong) indexReader; } @@ -956,7 +957,7 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti // Now that indexWriter is trained, we just load the bytes into an array and return faiss::VectorIOWriter vectorIoWriter; - faiss::write_index(indexWriter.get(), &vectorIoWriter); + faiss::write_index(indexWriter.get(), &vectorIoWriter, 1ull<<63); // Wrap in smart pointer std::unique_ptr jbytesBuffer; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 6b8e52eb74..07b00a9ff3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -24,6 +24,7 @@ import org.opensearch.knn.plugin.stats.KNNGraphValue; import java.io.IOException; +import java.util.function.Supplier; import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine; import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; @@ -65,9 +66,18 @@ private boolean isKNNBinaryFieldRequired(FieldInfo field) { && KNNEngine.getEnginesThatCreateCustomSegmentFiles().stream().anyMatch(engine -> engine == knnEngine); } - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) throws IOException { + public KNNVectorValues addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) throws IOException { final VectorDataType vectorDataType = extractVectorDataType(field); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, valuesProducer.getBinary(field)); + + Supplier> knnVectorValuesSupplier = () -> { + try { + return KNNVectorValuesFactory.getVectorValues(vectorDataType, valuesProducer.getBinary(field)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }; + + final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); // For BDV it is fine to use knnVectorValues.totalLiveDocs() as we already run the full loop to calculate total // live docs @@ -76,6 +86,8 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, } else { NativeIndexWriter.getWriter(field, state).flushIndex(() -> knnVectorValues, (int) knnVectorValues.totalLiveDocs()); } + + return knnVectorValuesSupplier.get(); } /** diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index 70bee76d46..42070699a9 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -33,6 +33,7 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.CompressionLevel; import org.opensearch.knn.index.mapper.Mode; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.TestVectorValues; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.engine.MethodComponentContext; @@ -109,8 +110,9 @@ public void testAddBinaryField_withKNN() throws IOException { KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null) { @Override - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) { + public KNNVectorValues addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) { called[0] = true; + return null; } }; @@ -146,8 +148,9 @@ public void testAddBinaryField_withoutKNN() throws IOException { KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, state) { @Override - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) { + public KNNVectorValues addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) { called[0] = true; + return null; } }; @@ -248,7 +251,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException assertValidFooter(state.directory, expectedFile); // The document should be readable by nmslib - assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); + assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension, null); // The graph creation statistics should be updated assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); @@ -303,7 +306,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException assertValidFooter(state.directory, expectedFile); // The document should be readable by nmslib - assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); + assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension, null); // The graph creation statistics should be updated assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); @@ -360,7 +363,11 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException docsInSegment, dimension ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, false); + KNNVectorValues knnVectorValues = knn80DocValuesConsumer.addKNNBinaryField( + fieldInfoArray[0], + randomVectorDocValuesProducer, + false + ); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -370,7 +377,7 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException assertValidFooter(state.directory, expectedFile); // The document should be readable by faiss - assertLoadableByEngine(HNSW_METHODPARAMETERS, state, expectedFile, knnEngine, spaceType, dimension); + assertLoadableByEngine(HNSW_METHODPARAMETERS, state, expectedFile, knnEngine, spaceType, dimension, knnVectorValues); // The graph creation statistics should be updated assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); @@ -542,7 +549,7 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio assertValidFooter(state.directory, expectedFile); // The document should be readable by faiss - assertLoadableByEngine(HNSW_METHODPARAMETERS, state, expectedFile, knnEngine, spaceType, dimension); + assertLoadableByEngine(HNSW_METHODPARAMETERS, state, expectedFile, knnEngine, spaceType, dimension, null); // The graph creation statistics should be updated assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index 64c5371dbf..a147acb19e 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -25,6 +25,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.store.IndexInputWithBuffer; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.jni.JNIService; import java.io.IOException; @@ -199,10 +200,12 @@ public static void assertLoadableByEngine( String fileName, KNNEngine knnEngine, SpaceType spaceType, - int dimension + int dimension, + KNNVectorValues knnVectorValues ) { try (final IndexInput indexInput = state.directory.openInput(fileName, IOContext.DEFAULT)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexInputWithBuffer.setKnnVectorValues(knnVectorValues); long indexPtr = JNIService.loadIndex( indexInputWithBuffer, Maps.newHashMap(ImmutableMap.of(SPACE_TYPE, spaceType.getValue())), diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index 3f06e85e1a..36c79f3fe8 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -37,6 +37,7 @@ import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.store.IndexInputWithBuffer; import org.opensearch.knn.index.store.IndexOutputWithBuffer; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import java.io.IOException; import java.net.URL; @@ -632,8 +633,12 @@ public void testLoadIndex_when_io_exception_was_raised() { ); assertTrue(directory.fileLength(indexFileName1) > 0); + // KnnVectorValues ioexceptionKnnVectorValues = TestUtils.createInMemoryFloatVectorValues(vectors, vectors[0].length); + KNNVectorValues ioexceptionKnnVectorValues = TestUtils.createKNNFloatVectorValues(docIds, vectors); + final IndexInput raiseIOExceptionIndexInput = new RaisingIOExceptionIndexInput(); final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(raiseIOExceptionIndexInput); + indexInputWithBuffer.setKnnVectorValues(ioexceptionKnnVectorValues); try { JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); @@ -987,7 +992,12 @@ public void testLoadIndex_faiss_valid() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { + KNNVectorValues knnVectorValues = TestUtils.createKNNFloatVectorValues( + testData.indexData.docs, + testData.indexData.vectors + ); final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexInputWithBuffer.setKnnVectorValues(knnVectorValues); long pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); } catch (Throwable e) { @@ -1092,7 +1102,7 @@ public void testQueryIndex_faiss_invalid_badPointer() { expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, null, KNNEngine.FAISS, null, 0, null)); } - public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { + public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException, ReflectiveOperationException { Path tempDirPath = createTempDir(); String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; @@ -1110,7 +1120,12 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { final long pointer; try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { + KNNVectorValues knnVectorValues = TestUtils.createKNNFloatVectorValues( + testData.indexData.docs, + testData.indexData.vectors + ); final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexInputWithBuffer.setKnnVectorValues(knnVectorValues); pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); } catch (Throwable e) { @@ -1122,7 +1137,7 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { } } - public void testQueryIndex_faiss_streaming_invalid_nullQueryVector() throws IOException { + public void testQueryIndex_faiss_streaming_invalid_nullQueryVector() throws IOException, ReflectiveOperationException { Path tempDirPath = createTempDir(); String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; try (Directory directory = newFSDirectory(tempDirPath)) { @@ -1139,7 +1154,12 @@ public void testQueryIndex_faiss_streaming_invalid_nullQueryVector() throws IOEx final long pointer; try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { + KNNVectorValues knnVectorValues = TestUtils.createKNNFloatVectorValues( + testData.indexData.docs, + testData.indexData.vectors + ); final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexInputWithBuffer.setKnnVectorValues(knnVectorValues); pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); } catch (Throwable e) { @@ -1151,7 +1171,7 @@ public void testQueryIndex_faiss_streaming_invalid_nullQueryVector() throws IOEx } } - public void testQueryIndex_faiss_valid() throws IOException { + public void testQueryIndex_faiss_valid() throws IOException, ReflectiveOperationException { int k = 10; int efSearch = 100; @@ -1176,7 +1196,12 @@ public void testQueryIndex_faiss_valid() throws IOException { final long pointer; try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { + KNNVectorValues knnVectorValues = TestUtils.createKNNFloatVectorValues( + testData.indexData.docs, + testData.indexData.vectors + ); final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexInputWithBuffer.setKnnVectorValues(knnVectorValues); pointer = JNIService.loadIndex( indexInputWithBuffer, ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), @@ -1244,8 +1269,14 @@ public void testQueryIndex_faiss_streaming_valid() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.READONCE)) { + KNNVectorValues knnVectorValues = TestUtils.createKNNFloatVectorValues( + testData.indexData.docs, + testData.indexData.vectors + ); + IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexInputWithBuffer.setKnnVectorValues(knnVectorValues); long pointer = JNIService.loadIndex( - new IndexInputWithBuffer(indexInput), + indexInputWithBuffer, ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.FAISS ); @@ -1280,12 +1311,15 @@ public void testQueryIndex_faiss_streaming_valid() throws IOException { assertEquals(0, results.length); } // End for } // End try + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } } // End for } // End for } } - public void testQueryIndex_faiss_parentIds() throws IOException { + public void testQueryIndex_faiss_parentIds() throws IOException, ReflectiveOperationException { int k = 100; int efSearch = 100; @@ -1312,7 +1346,12 @@ public void testQueryIndex_faiss_parentIds() throws IOException { final long pointer; try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { + KNNVectorValues knnVectorValues = TestUtils.createKNNFloatVectorValues( + testDataNested.indexData.docs, + testDataNested.indexData.vectors + ); final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexInputWithBuffer.setKnnVectorValues(knnVectorValues); pointer = JNIService.loadIndex( indexInputWithBuffer, ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), @@ -1344,14 +1383,14 @@ public void testQueryIndex_faiss_parentIds() throws IOException { } } - public void testQueryIndex_faissCagra_parentIds() throws IOException { + public void testQueryIndex_faissCagra_parentIds() throws IOException, ReflectiveOperationException { doTestQueryIndex_faissCagra_parentIds(SpaceType.L2); doTestQueryIndex_faissCagra_parentIds(SpaceType.INNER_PRODUCT); doTestQueryIndex_faissCagra_parentIds(SpaceType.COSINESIMIL); } - private void doTestQueryIndex_faissCagra_parentIds(SpaceType spaceType) throws IOException { + private void doTestQueryIndex_faissCagra_parentIds(SpaceType spaceType) throws IOException, ReflectiveOperationException { int k = 100; int efSearch = 100; @@ -1363,7 +1402,12 @@ private void doTestQueryIndex_faissCagra_parentIds(SpaceType spaceType) throws I // This faiss graph binary was created with the IndexHNSWCagra index (base_level_only==true) containing the // test_vectors_nested_1000x128.json vectors try (IndexInput indexInput = loadHnswBinary("data/remoteindexbuild/faiss_hnsw_cagra_nested_float_1000_vectors_128_dims.bin")) { + KNNVectorValues knnVectorValues = TestUtils.createKNNFloatVectorValues( + testDataNested.indexData.docs, + testDataNested.indexData.vectors + ); final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexInputWithBuffer.setKnnVectorValues(knnVectorValues); pointer = JNIService.loadIndex( indexInputWithBuffer, ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), @@ -1419,8 +1463,14 @@ public void testQueryIndex_faiss_streaming_parentIds() throws IOException { assertTrue(directory.fileLength(indexFileName1) > 0); try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.READONCE)) { + KNNVectorValues knnVectorValues = TestUtils.createKNNFloatVectorValues( + testDataNested.indexData.docs, + testDataNested.indexData.vectors + ); + IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexInputWithBuffer.setKnnVectorValues(knnVectorValues); long pointer = JNIService.loadIndex( - new IndexInputWithBuffer(indexInput), + indexInputWithBuffer, ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.FAISS ); @@ -1442,6 +1492,9 @@ public void testQueryIndex_faiss_streaming_parentIds() throws IOException { assertEquals(results.length, parentIdSet.size()); } // End for } // End try + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } } // End for } // End for } @@ -1476,7 +1529,12 @@ public void testQueryBinaryIndex_faiss_valid() { final long pointer; try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { + KNNVectorValues knnVectorValues = TestUtils.createKNNFloatVectorValues( + testData.indexData.docs, + testData.indexData.vectors + ); final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexInputWithBuffer.setKnnVectorValues(knnVectorValues); pointer = JNIService.loadIndex( indexInputWithBuffer, ImmutableMap.of( @@ -1529,8 +1587,14 @@ public void testQueryBinaryIndex_faiss_streaming_valid() { assertTrue(directory.fileLength(indexFileName1) > 0); try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.READONCE)) { + KNNVectorValues knnVectorValues = TestUtils.createKNNFloatVectorValues( + testData.indexData.docs, + testData.indexData.vectors + ); + IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexInputWithBuffer.setKnnVectorValues(knnVectorValues); long pointer = JNIService.loadIndex( - new IndexInputWithBuffer(indexInput), + indexInputWithBuffer, ImmutableMap.of( INDEX_DESCRIPTION_PARAMETER, method, @@ -1628,7 +1692,7 @@ public void testFree_nmslib_valid() throws IOException { } } - public void testFree_faiss_valid() throws IOException { + public void testFree_faiss_valid() throws IOException, ReflectiveOperationException { Path tempDirPath = createTempDir(); String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; @@ -1646,7 +1710,12 @@ public void testFree_faiss_valid() throws IOException { final long pointer; try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.DEFAULT)) { + KNNVectorValues knnVectorValues = TestUtils.createKNNFloatVectorValues( + testDataNested.indexData.docs, + testDataNested.indexData.vectors + ); final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexInputWithBuffer.setKnnVectorValues(knnVectorValues); pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); } catch (Throwable e) { @@ -2012,7 +2081,12 @@ public void testIsIndexIVFPQL2() { String faissHNSWIndex = createFaissHNSWIndex(directory, SpaceType.L2); try (IndexInput indexInput = directory.openInput(faissHNSWIndex, IOContext.DEFAULT)) { + KNNVectorValues knnVectorValues = TestUtils.createKNNFloatVectorValues( + testDataNested.indexData.docs, + testDataNested.indexData.vectors + ); final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexInputWithBuffer.setKnnVectorValues(knnVectorValues); long faissHNSWAddress = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertFalse(JNIService.isSharedIndexStateRequired(faissHNSWAddress, KNNEngine.FAISS)); JNIService.free(faissHNSWAddress, KNNEngine.FAISS); diff --git a/src/test/java/org/opensearch/knn/memoryoptsearch/FaissMemoryOptimizedSearcherTests.java b/src/test/java/org/opensearch/knn/memoryoptsearch/FaissMemoryOptimizedSearcherTests.java index f87e2bf6c1..0eb8dbdd8e 100644 --- a/src/test/java/org/opensearch/knn/memoryoptsearch/FaissMemoryOptimizedSearcherTests.java +++ b/src/test/java/org/opensearch/knn/memoryoptsearch/FaissMemoryOptimizedSearcherTests.java @@ -8,10 +8,7 @@ import lombok.RequiredArgsConstructor; import lombok.SneakyThrows; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; -import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FieldInfos; -import org.apache.lucene.index.SegmentInfo; -import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.*; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.ScoreDoc; @@ -24,6 +21,7 @@ import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.generate.IndexingType; import org.opensearch.knn.generate.SearchTestHelper; @@ -318,11 +316,16 @@ private void doSearchTest( // Build FAISS index final BuildInfo buildInfo = buildFaissIndex(testingSpec, TOTAL_NUM_DOCS_IN_SEGMENT, indexingType, spaceType); + KNNVectorValues knnVectorValues = testingSpec.dataType == VectorDataType.FLOAT + ? createKNNFloatVectorValues(buildInfo.documentIds, buildInfo.vectors.floatVectors) + : createKNNByteVectorValues(buildInfo.documentIds, buildInfo.vectors.byteVectors); + // Load FAISS index via JNI long indexPointer = -1; try (final Directory directory = newFSDirectory(buildInfo.tempDirPath)) { try (final IndexInput input = directory.openInput(buildInfo.faissIndexFile, IOContext.READONCE)) { final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(input); + indexInputWithBuffer.setKnnVectorValues(knnVectorValues); if (testingSpec.isAdcEnabled) { buildInfo.parameters.put("data_type", VectorDataType.FLOAT.getValue()); buildInfo.parameters.put(ADC_ENABLED_FAISS_INDEX_INTERNAL_PARAMETER, true); @@ -512,12 +515,28 @@ private static KNNQueryResult[] doSearchViaVectorReader( // Make SegmentReadState and do search try (final Directory directory = new MMapDirectory(buildInfo.tempDirPath)) { final SegmentReadState readState = new SegmentReadState(directory, segmentInfo, fieldInfos, IOContext.DEFAULT); - try ( - NativeEngines990KnnVectorsReader vectorsReader = new NativeEngines990KnnVectorsReader( - readState, - mock(FlatVectorsReader.class) - ) - ) { + + FlatVectorsReader flatVectorsReader = mock(FlatVectorsReader.class); + if (vectorDataType == VectorDataType.FLOAT) { + FloatVectorValues floatVectorValues = TestUtils.createInMemoryFloatVectorValuesForList( + buildInfo.vectors.floatVectors, + DIMENSIONS, + buildInfo.documentIds + ); + when(flatVectorsReader.getFloatVectorValues(TARGET_FIELD)).thenReturn(floatVectorValues); + } + if (vectorDataType == VectorDataType.BYTE) { + ByteVectorValues byteVectorValues = TestUtils.createInMemoryByteVectorValuesForList( + buildInfo.vectors.byteVectors, + DIMENSIONS, + buildInfo.documentIds + ); + when(flatVectorsReader.getByteVectorValues(TARGET_FIELD)).thenReturn(byteVectorValues); + } + + try (NativeEngines990KnnVectorsReader vectorsReader = new NativeEngines990KnnVectorsReader(readState, flatVectorsReader + // mock(FlatVectorsReader.class) + )) { if (vectorDataType == VectorDataType.FLOAT) { vectorsReader.search(TARGET_FIELD, (float[]) query, knnCollector, acceptDocs); } else if (vectorDataType == VectorDataType.BYTE || vectorDataType == VectorDataType.BINARY) { diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index fe64336bfb..eefa369ec7 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -9,6 +9,9 @@ import lombok.AllArgsConstructor; import lombok.Getter; import lombok.Setter; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexOutput; @@ -21,12 +24,19 @@ import java.io.FileReader; import java.io.IOException; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.store.IndexOutputWithBuffer; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesIterator; +import org.opensearch.knn.index.vectorvalues.VectorValueExtractorStrategy; import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.script.KNNScoringUtil; +import java.io.UnsupportedEncodingException; +import java.lang.reflect.Constructor; import java.util.Base64; import java.util.Collections; import java.util.Comparator; @@ -470,4 +480,240 @@ public static void createIndex( } } } + + public static FloatVectorValues createInMemoryFloatVectorValuesForList( + List vectors, + int dimension, + final List docids + ) { + return new FloatVectorValues() { + + @Override + public int size() { + return vectors.size(); + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public float[] vectorValue(int ord) { + int docId = docids.get(ord); + return vectors.get(docId); + } + + @Override + public FloatVectorValues copy() { + return this; + } + + @Override + public DocIndexIterator iterator() { + return new DocIndexIterator() { + int doc = -1; + + @Override + public int index() { + return doc; + } + + @Override + public int nextDoc() { + return ++doc < vectors.size() ? doc : NO_MORE_DOCS; + } + + @Override + public int advance(int target) throws IOException { + doc = target; + return doc < vectors.size() ? doc : NO_MORE_DOCS; + } + + @Override + public long cost() { + return vectors.size(); + } + + @Override + public int docID() { + return doc; + } + }; + } + }; + } + + public static ByteVectorValues createInMemoryByteVectorValuesForList(List vectors, int dimension, final List docids) { + return new ByteVectorValues() { + + @Override + public int size() { + return vectors.size(); + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public byte[] vectorValue(int ord) { + int docId = docids.get(ord); + return vectors.get(docId); + } + + @Override + public ByteVectorValues copy() { + return this; + } + + @Override + public DocIndexIterator iterator() { + return new DocIndexIterator() { + int doc = -1; + + @Override + public int index() { + return doc; + } + + @Override + public int nextDoc() { + return ++doc < vectors.size() ? doc : NO_MORE_DOCS; + } + + @Override + public int advance(int target) throws IOException { + doc = target; + return doc < vectors.size() ? doc : NO_MORE_DOCS; + } + + @Override + public long cost() { + return vectors.size(); + } + + @Override + public int docID() { + return doc; + } + }; + } + }; + } + + public static KNNFloatVectorValues createKNNFloatVectorValues(final int[] documentIds, final float[][] vectors) + throws ReflectiveOperationException { + + final KNNVectorValuesIterator iterator = new KNNVectorValuesIterator() { + private int index = -1; + + @Override + public int docId() { + if (index == -1) { + return -1; + } else if (index == DocIdSetIterator.NO_MORE_DOCS) { + return DocIdSetIterator.NO_MORE_DOCS; + } + return documentIds[index]; + } + + @Override + public int advance(int docId) throws IOException { + throw new UnsupportedEncodingException(); + } + + @Override + public int nextDoc() { + if ((index + 1) >= documentIds.length) { + index = DocIdSetIterator.NO_MORE_DOCS; + return DocIdSetIterator.NO_MORE_DOCS; + } + return documentIds[++index]; + } + + @Override + public DocIdSetIterator getDocIdSetIterator() { + return null; + } + + @Override + public long liveDocs() { + return documentIds.length; + } + + @Override + public VectorValueExtractorStrategy getVectorExtractorStrategy() { + return new VectorValueExtractorStrategy() { + @Override + public float[] extract(VectorDataType vectorDataType, KNNVectorValuesIterator vectorValuesIterator) { + return vectors[index]; + } + }; + } + }; + + // Instantiate KNNFloatVectorValues + Constructor constructor = KNNFloatVectorValues.class.getDeclaredConstructor(KNNVectorValuesIterator.class); + constructor.setAccessible(true); + return constructor.newInstance(iterator); + } + + public static KNNByteVectorValues createKNNByteVectorValues(final int[] documentIds, final byte[][] vectors) + throws ReflectiveOperationException { + + final KNNVectorValuesIterator iterator = new KNNVectorValuesIterator() { + private int index = -1; + + @Override + public int docId() { + if (index == -1) { + return -1; + } else if (index == DocIdSetIterator.NO_MORE_DOCS) { + return DocIdSetIterator.NO_MORE_DOCS; + } + return documentIds[index]; + } + + @Override + public int advance(int docId) throws IOException { + throw new UnsupportedEncodingException(); + } + + @Override + public int nextDoc() { + if ((index + 1) >= documentIds.length) { + index = DocIdSetIterator.NO_MORE_DOCS; + return DocIdSetIterator.NO_MORE_DOCS; + } + return documentIds[++index]; + } + + @Override + public DocIdSetIterator getDocIdSetIterator() { + return null; + } + + @Override + public long liveDocs() { + return documentIds.length; + } + + @Override + public VectorValueExtractorStrategy getVectorExtractorStrategy() { + return new VectorValueExtractorStrategy() { + @Override + public byte[] extract(VectorDataType vectorDataType, KNNVectorValuesIterator vectorValuesIterator) { + return vectors[index]; + } + }; + } + }; + + // Instantiate KNNByteVectorValues + Constructor constructor = KNNByteVectorValues.class.getDeclaredConstructor(KNNVectorValuesIterator.class); + constructor.setAccessible(true); + return constructor.newInstance(iterator); + } }