Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jni/cmake/init-faiss.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ if(NOT DEFINED APPLY_LIB_PATCHES OR "${APPLY_LIB_PATCHES}" STREQUAL true)
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0004-Custom-patch-to-support-binary-vector.patch")
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0005-Custom-patch-to-support-multi-vector-IndexHNSW-search_level_0.patch")
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0006-Add-nested-search-support-for-IndexBinaryHNSWCagra.patch")
list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0007-Custom-patch-to-support-dedup-flatvector-faiss-file.patch")

# Get patch id of the last commit
execute_process(COMMAND sh -c "git --no-pager show HEAD | git patch-id --stable" OUTPUT_VARIABLE PATCH_ID_OUTPUT_FROM_COMMIT WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss)
Expand Down
109 changes: 91 additions & 18 deletions jni/include/faiss_stream_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,101 @@ namespace stream {
* This will then indirectly call the mediator component and eventually read required bytes from Lucene's IndexInput.
*/
class FaissOpenSearchIOReader final : public faiss::IOReader {
public:
explicit FaissOpenSearchIOReader(NativeEngineIndexInputMediator *_mediator)
: faiss::IOReader(),
mediator(knn_jni::util::ParameterCheck::require_non_null(_mediator, "mediator")) {
name = "FaissOpenSearchIOReader";
}
public:
explicit FaissOpenSearchIOReader(NativeEngineIndexInputMediator* _mediator)
: faiss::IOReader(),
mediator(knn_jni::util::ParameterCheck::require_non_null(_mediator, "mediator")) {

size_t operator()(void *ptr, size_t size, size_t nitems) final {
const auto readBytes = size * nitems;
if (readBytes > 0) {
// Mediator calls IndexInput, then copy read bytes to `ptr`.
mediator->copyBytes(readBytes, (uint8_t *) ptr);
name = "FaissOpenSearchIOReader";
}
return nitems;
}

int filedescriptor() final {
throw std::runtime_error("filedescriptor() is not supported in FaissOpenSearchIOReader.");
}
~FaissOpenSearchIOReader() override {
JNIEnv* env = mediator->getEnv();
if (vectorReaderGlobalRef && env) {
env->DeleteGlobalRef(vectorReaderGlobalRef);
}
}

private:
NativeEngineIndexInputMediator *mediator;
size_t operator()(void* ptr, size_t size, size_t nitems) override {
const auto bytes = size * nitems;
mediator->copyBytes(bytes, static_cast<uint8_t*>(ptr));
return nitems;
}

bool copy(void* dest, int expectedByteSize, bool isFloat) override {
JNIEnv* env = mediator->getEnv();
if (env == nullptr) return false;

jobject readStream = mediator->getJavaObject();
if (!vectorReaderGlobalRef) {
jclass streamClass = env->GetObjectClass(readStream);
if (env->ExceptionCheck() || streamClass == nullptr) return false;

jmethodID getVectorsMid = env->GetMethodID(
streamClass,
"getFullPrecisionVectors",
"()Lorg/opensearch/knn/index/store/VectorReader;"
);
if (env->ExceptionCheck() || !getVectorsMid) return false;

jobject vectorReader = env->CallObjectMethod(readStream, getVectorsMid);
if (env->ExceptionCheck() || vectorReader == nullptr) return false;

vectorReaderGlobalRef = env->NewGlobalRef(vectorReader);
if (env->ExceptionCheck() || vectorReaderGlobalRef == nullptr) return false;
}

if (isFloat) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add two helper functions to reduce branching? It's hard to tell where float and byte cases differ in this code, so helper would improve readability and maintainability.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it make sense, thank you.

jclass vectorReaderClass = env->GetObjectClass(vectorReaderGlobalRef);
jmethodID nextFloatMid = env->GetMethodID(vectorReaderClass, "nextFloatVector", "()[F");
if (env->ExceptionCheck() || !nextFloatMid) return false;

jfloatArray vector = (jfloatArray) env->CallObjectMethod(vectorReaderGlobalRef, nextFloatMid);
if (env->ExceptionCheck() || vector == nullptr) return false;

jsize length = env->GetArrayLength(vector);
jfloat* elems = env->GetFloatArrayElements(vector, nullptr);

JNIReleaseElements release_elems([=]() {
env->ReleaseFloatArrayElements(vector, elems, JNI_ABORT);
});

int vectorByteSize = sizeof(float) * length;
if (vectorByteSize != expectedByteSize) return false;

std::memcpy(dest, elems, vectorByteSize);
return true;

} else {
jclass vectorReaderClass = env->GetObjectClass(vectorReaderGlobalRef);
jmethodID nextByteMid = env->GetMethodID(vectorReaderClass, "nextByteVector", "()[B");
if (env->ExceptionCheck() || !nextByteMid) return false;

jbyteArray vector = (jbyteArray) env->CallObjectMethod(vectorReaderGlobalRef, nextByteMid);
if (env->ExceptionCheck() || vector == nullptr) return false;

jsize length = env->GetArrayLength(vector);
jbyte* elems = env->GetByteArrayElements(vector, nullptr);

JNIReleaseElements release_elems([=]() {
env->ReleaseByteArrayElements(vector, elems, JNI_ABORT);
});

int vectorByteSize = sizeof(float) * length;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be sizeof(byte) if it's the byte vector case?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a problem with that, for byte and other quantized(IxSQ) cases faiss expects float vectors and it internally performs quantization. I tried using the byte vectors array but that failed. So what we are doing here is getting byte vectors from VectorReader and then typecasting those vectors to float, then in read_index for IxSQ case what we are doing is using faiss sq.compute_codes so that faiss can get vectors in proper format.

if (vectorByteSize != expectedByteSize) return false;

float* floatDest = static_cast<float*>(dest);
for (int i = 0; i < length; ++i) {
floatDest[i] = static_cast<float>(elems[i]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use memcpy

Copy link
Author

@sobhu17 sobhu17 Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, Got it!

}

return true;
}
}

private:
NativeEngineIndexInputMediator* mediator;
jobject vectorReaderGlobalRef = nullptr;
}; // class FaissOpenSearchIOReader


Expand Down
12 changes: 12 additions & 0 deletions jni/include/native_engines_stream_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ class NativeEngineIndexInputMediator {
remainingBytesMethod(getRemainingBytesMethod(_jni_interface, _env)) {
}

JNIEnv* getEnv() const {
return env;
}

jobject getJavaObject() const {
return indexInput;
}

JNIUtilInterface* getJNIUtil() const {
return jni_interface;
}

void copyBytes(int64_t nbytes, uint8_t * RESTRICT destination) {
auto jclazz = getIndexInputWithBufferClass(jni_interface, env);

Expand Down
Loading
Loading