Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Make Merge in nativeEngine can Abort #2529

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <jni.h>
#include "faiss/MetricType.h"
#include "faiss/impl/io.h"
#include "faiss/impl/AuxIndexStructures.h"
#include "jni_util.h"
#include "faiss_methods.h"
#include "faiss_stream_support.h"
Expand Down Expand Up @@ -195,6 +196,19 @@ class ByteIndexService final : public IndexService {
void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) final;
}; // class ByteIndexService

struct OpenSearchMergeInterruptCallback : faiss::InterruptCallback {

OpenSearchMergeInterruptCallback(JNIUtil *jniUtil, JNIEnv *env) {
mergeHelperClass = jniUtil->FindClass(env,"org/apache/lucene/index/KNNMergeHelper");
isAbortedMethod = jniUtil->FindMethod(env, "org/apache/lucene/index/KNNMergeHelper", "isMergeAborted");
}
bool want_interrupt () override {
return (bool) jenv->CallStaticBooleanMethod(mergeHelperClass, isAbortedMethod);
}
JNIEnv *jenv;
jclass mergeHelperClass;
jmethodID isAbortedMethod;
};
}
}

Expand Down
14 changes: 14 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,20 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSea
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jobject, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: setMergeInterruptCallback
* Signature: ()V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setMergeInterruptCallback(JNIEnv * env, jclass cls);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: setMergeInterruptCallback
* Signature: ()V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_unsetMergeInterruptCallback(JNIEnv * env, jclass cls);

#ifdef __cplusplus
}
#endif
Expand Down
5 changes: 5 additions & 0 deletions jni/src/jni_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ void knn_jni::JNIUtil::Initialize(JNIEnv *env) {
this->cachedClasses["org/opensearch/knn/index/query/KNNQueryResult"] = (jclass) env->NewGlobalRef(tempLocalClassRef);
this->cachedMethods["org/opensearch/knn/index/query/KNNQueryResult:<init>"] = env->GetMethodID(tempLocalClassRef, "<init>", "(IF)V");
env->DeleteLocalRef(tempLocalClassRef);

tempLocalClassRef = env->FindClass("org/apache/lucene/index/KNNMergeHelper");
this->cachedClasses["org/apache/lucene/index/KNNMergeHelper"] = (jclass) env->NewGlobalRef(tempLocalClassRef);
this->cachedMethods["org/apache/lucene/index/KNNMergeHelper:isMergeAborted"] = env->GetStaticMethodID(tempLocalClassRef, "isMergeAborted", "()Z");
env->DeleteLocalRef(tempLocalClassRef);
}

void knn_jni::JNIUtil::Uninitialize(JNIEnv* env) {
Expand Down
20 changes: 20 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,23 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSea
}
return nullptr;
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setMergeInterruptCallback(JNIEnv * env, jclass cls)
{
try {
faiss::InterruptCallback::instance.reset(
new knn_jni::faiss_wrapper::OpenSearchMergeInterruptCallback(&jniUtil, env)
);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_unsetMergeInterruptCallback(JNIEnv * env, jclass cls)
{
try {
faiss::InterruptCallback::instance.get()->clear_instance();
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
}
18 changes: 18 additions & 0 deletions src/main/java/org/apache/lucene/index/KNNMergeHelper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.lucene.index;

public class KNNMergeHelper {

private KNNMergeHelper() {}
public static boolean isMergeAborted() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we generalize this to include shard closed events? See this comment: opensearch-project/OpenSearch#8590 (comment). Im not sure if its possible or not to look up that kind of information. @shwetathareja do you have any ideas around here?

Thread mergeThread = Thread.currentThread();
if (mergeThread instanceof ConcurrentMergeScheduler.MergeThread) {
return ((ConcurrentMergeScheduler.MergeThread) mergeThread).merge.isAborted();
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.knn.index.codec.KNN80Codec;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.ConcurrentMergeScheduler;
import org.apache.lucene.index.KNNMergeHelper;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.KNNMergeHelper;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.IndexOutput;
import org.opensearch.common.Nullable;
Expand Down Expand Up @@ -46,6 +47,8 @@
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName;
import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
import static org.opensearch.knn.jni.JNIService.setMergeInterruptCallback;
import static org.opensearch.knn.jni.JNIService.unsetMergeInterruptCallback;

/**
* Writes KNN Index for a field in a segment. This is intended to be used for native engines
Expand Down Expand Up @@ -119,9 +122,18 @@ public void mergeIndex(final KNNVectorValues<?> knnVectorValues, int totalLiveDo
}

long bytesPerVector = knnVectorValues.bytesPerVector();
startMergeStats(totalLiveDocs, bytesPerVector);
buildAndWriteIndex(knnVectorValues, totalLiveDocs);
endMergeStats(totalLiveDocs, bytesPerVector);
final KNNEngine knnEngine = extractKNNEngine(fieldInfo);
setMergeInterruptCallback(knnEngine);
try {
startMergeStats(totalLiveDocs, bytesPerVector);
buildAndWriteIndex(knnVectorValues, totalLiveDocs);
endMergeStats(totalLiveDocs, bytesPerVector);
} catch (Exception ex) {
//TODO handle
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to catch this?

log.debug("Merge may abort {}",KNNMergeHelper.isMergeAborted());
} finally {
unsetMergeInterruptCallback(knnEngine);
}
}

private void buildAndWriteIndex(final KNNVectorValues<?> knnVectorValues, int totalLiveDocs) throws IOException {
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/jni/FaissService.java
Original file line number Diff line number Diff line change
Expand Up @@ -463,4 +463,8 @@ public static native KNNQueryResult[] rangeSearchIndex(
int indexMaxResultWindow,
int[] parentIds
);

public static native void setMergeInterruptCallback();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove these and just have one interrupt callback registered at the time of library initialization? I dont think these need to be in the interface.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just prefer not to do this when its a global, static callback.


public static native void unsetMergeInterruptCallback();
}
13 changes: 13 additions & 0 deletions src/main/java/org/opensearch/knn/jni/JNIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -476,4 +476,17 @@ public static KNNQueryResult[] radiusQueryIndex(
}
throw new IllegalArgumentException(String.format(Locale.ROOT, "RadiusQueryIndex not supported for provided engine"));
}

public static void setMergeInterruptCallback(KNNEngine knnEngine) {

if (KNNEngine.FAISS == knnEngine) {
FaissService.setMergeInterruptCallback();
}
}

public static void unsetMergeInterruptCallback(KNNEngine knnEngine) {
if (KNNEngine.FAISS == knnEngine) {
FaissService.unsetMergeInterruptCallback();
}
}
}
Loading