-
Notifications
You must be signed in to change notification settings - Fork 169
Deduplicating Vectors for Space Optimization - Patch #2840
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dedup-vectors-optimization
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,28 +32,101 @@ namespace stream { | |
* This will then indirectly call the mediator component and eventually read required bytes from Lucene's IndexInput. | ||
*/ | ||
class FaissOpenSearchIOReader final : public faiss::IOReader { | ||
public: | ||
explicit FaissOpenSearchIOReader(NativeEngineIndexInputMediator *_mediator) | ||
: faiss::IOReader(), | ||
mediator(knn_jni::util::ParameterCheck::require_non_null(_mediator, "mediator")) { | ||
name = "FaissOpenSearchIOReader"; | ||
} | ||
public: | ||
explicit FaissOpenSearchIOReader(NativeEngineIndexInputMediator* _mediator) | ||
: faiss::IOReader(), | ||
mediator(knn_jni::util::ParameterCheck::require_non_null(_mediator, "mediator")) { | ||
|
||
size_t operator()(void *ptr, size_t size, size_t nitems) final { | ||
const auto readBytes = size * nitems; | ||
if (readBytes > 0) { | ||
// Mediator calls IndexInput, then copy read bytes to `ptr`. | ||
mediator->copyBytes(readBytes, (uint8_t *) ptr); | ||
name = "FaissOpenSearchIOReader"; | ||
} | ||
return nitems; | ||
} | ||
|
||
int filedescriptor() final { | ||
throw std::runtime_error("filedescriptor() is not supported in FaissOpenSearchIOReader."); | ||
} | ||
~FaissOpenSearchIOReader() override { | ||
JNIEnv* env = mediator->getEnv(); | ||
if (vectorReaderGlobalRef && env) { | ||
env->DeleteGlobalRef(vectorReaderGlobalRef); | ||
} | ||
} | ||
|
||
private: | ||
NativeEngineIndexInputMediator *mediator; | ||
size_t operator()(void* ptr, size_t size, size_t nitems) override { | ||
const auto bytes = size * nitems; | ||
mediator->copyBytes(bytes, static_cast<uint8_t*>(ptr)); | ||
return nitems; | ||
} | ||
|
||
bool copy(void* dest, int expectedByteSize, bool isFloat) override { | ||
JNIEnv* env = mediator->getEnv(); | ||
if (env == nullptr) return false; | ||
|
||
jobject readStream = mediator->getJavaObject(); | ||
if (!vectorReaderGlobalRef) { | ||
jclass streamClass = env->GetObjectClass(readStream); | ||
if (env->ExceptionCheck() || streamClass == nullptr) return false; | ||
|
||
jmethodID getVectorsMid = env->GetMethodID( | ||
streamClass, | ||
"getFullPrecisionVectors", | ||
"()Lorg/opensearch/knn/index/store/VectorReader;" | ||
); | ||
if (env->ExceptionCheck() || !getVectorsMid) return false; | ||
|
||
jobject vectorReader = env->CallObjectMethod(readStream, getVectorsMid); | ||
if (env->ExceptionCheck() || vectorReader == nullptr) return false; | ||
|
||
vectorReaderGlobalRef = env->NewGlobalRef(vectorReader); | ||
if (env->ExceptionCheck() || vectorReaderGlobalRef == nullptr) return false; | ||
} | ||
|
||
if (isFloat) { | ||
jclass vectorReaderClass = env->GetObjectClass(vectorReaderGlobalRef); | ||
jmethodID nextFloatMid = env->GetMethodID(vectorReaderClass, "nextFloatVector", "()[F"); | ||
if (env->ExceptionCheck() || !nextFloatMid) return false; | ||
|
||
jfloatArray vector = (jfloatArray) env->CallObjectMethod(vectorReaderGlobalRef, nextFloatMid); | ||
if (env->ExceptionCheck() || vector == nullptr) return false; | ||
|
||
jsize length = env->GetArrayLength(vector); | ||
jfloat* elems = env->GetFloatArrayElements(vector, nullptr); | ||
|
||
JNIReleaseElements release_elems([=]() { | ||
env->ReleaseFloatArrayElements(vector, elems, JNI_ABORT); | ||
}); | ||
|
||
int vectorByteSize = sizeof(float) * length; | ||
if (vectorByteSize != expectedByteSize) return false; | ||
|
||
std::memcpy(dest, elems, vectorByteSize); | ||
return true; | ||
|
||
} else { | ||
jclass vectorReaderClass = env->GetObjectClass(vectorReaderGlobalRef); | ||
jmethodID nextByteMid = env->GetMethodID(vectorReaderClass, "nextByteVector", "()[B"); | ||
if (env->ExceptionCheck() || !nextByteMid) return false; | ||
|
||
jbyteArray vector = (jbyteArray) env->CallObjectMethod(vectorReaderGlobalRef, nextByteMid); | ||
if (env->ExceptionCheck() || vector == nullptr) return false; | ||
|
||
jsize length = env->GetArrayLength(vector); | ||
jbyte* elems = env->GetByteArrayElements(vector, nullptr); | ||
|
||
JNIReleaseElements release_elems([=]() { | ||
env->ReleaseByteArrayElements(vector, elems, JNI_ABORT); | ||
}); | ||
|
||
int vectorByteSize = sizeof(float) * length; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't this be [inline code lost in extraction; from context, the byte element size rather than `sizeof(float)`]? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is a problem with that: for byte and other quantized (IxSQ) cases, Faiss expects float vectors and performs the quantization internally. I tried using the byte vector array, but that failed. So what we do here is get the byte vectors from VectorReader and cast them to float; then, in read_index for the IxSQ case, we use Faiss's sq.compute_codes so that Faiss gets the vectors in the proper format. |
||
if (vectorByteSize != expectedByteSize) return false; | ||
|
||
float* floatDest = static_cast<float*>(dest); | ||
for (int i = 0; i < length; ++i) { | ||
floatDest[i] = static_cast<float>(elems[i]); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's use [inline code lost in extraction]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, got it! |
||
} | ||
|
||
return true; | ||
} | ||
} | ||
|
||
private: | ||
NativeEngineIndexInputMediator* mediator; | ||
jobject vectorReaderGlobalRef = nullptr; | ||
}; // class FaissOpenSearchIOReader | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add two helper functions to reduce branching? It's hard to tell where the float and byte cases differ in this code, so helpers would improve readability and maintainability.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it makes sense, thank you.