diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index d96c3e7557..15aeb6c7cf 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -17,6 +17,7 @@ #include #include "faiss/MetricType.h" #include "faiss/impl/io.h" +#include "faiss/impl/AuxIndexStructures.h" #include "jni_util.h" #include "faiss_methods.h" #include "faiss_stream_support.h" @@ -195,6 +196,19 @@ class ByteIndexService final : public IndexService { void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) final; }; // class ByteIndexService +struct OpenSearchMergeInterruptCallback : faiss::InterruptCallback { + + OpenSearchMergeInterruptCallback(JNIUtil *jniUtil, JNIEnv *env) { + mergeHelperClass = jniUtil->FindClass(env,"org/apache/lucene/index/KNNMergeHelper"); + isAbortedMethod = jniUtil->FindMethod(env, "org/apache/lucene/index/KNNMergeHelper", "isMergeAborted"); + } + bool want_interrupt () override { + return (bool) jenv->CallStaticBooleanMethod(mergeHelperClass, isAbortedMethod); + } + JNIEnv *jenv; + jclass mergeHelperClass; + jmethodID isAbortedMethod; +}; } } diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index dce5801383..dce65dfcb7 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -268,6 +268,20 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSea JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex (JNIEnv *, jclass, jlong, jfloatArray, jfloat, jobject, jint, jintArray); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: setMergeInterruptCallback + * Signature: ()V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setMergeInterruptCallback(JNIEnv * env, jclass cls); + +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: setMergeInterruptCallback + * Signature: ()V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_unsetMergeInterruptCallback(JNIEnv * env, jclass cls); + #ifdef __cplusplus } #endif diff --git a/jni/src/jni_util.cpp b/jni/src/jni_util.cpp index 3ff79752a6..9420d50a13 100644 --- a/jni/src/jni_util.cpp +++ b/jni/src/jni_util.cpp @@ -65,6 +65,11 @@ void knn_jni::JNIUtil::Initialize(JNIEnv *env) { this->cachedClasses["org/opensearch/knn/index/query/KNNQueryResult"] = (jclass) env->NewGlobalRef(tempLocalClassRef); this->cachedMethods["org/opensearch/knn/index/query/KNNQueryResult:"] = env->GetMethodID(tempLocalClassRef, "", "(IF)V"); env->DeleteLocalRef(tempLocalClassRef); + + tempLocalClassRef = env->FindClass("org/apache/lucene/index/KNNMergeHelper"); + this->cachedClasses["org/apache/lucene/index/KNNMergeHelper"] = (jclass) env->NewGlobalRef(tempLocalClassRef); + this->cachedMethods["org/apache/lucene/index/KNNMergeHelper:isMergeAborted"] = env->GetStaticMethodID(tempLocalClassRef, "isMergeAborted", "()Z"); + env->DeleteLocalRef(tempLocalClassRef); } void knn_jni::JNIUtil::Uninitialize(JNIEnv* env) { diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 836774402f..746fdfd408 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -481,3 +481,23 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSea } return nullptr; } + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setMergeInterruptCallback(JNIEnv * env, jclass cls) +{ + try { + faiss::InterruptCallback::instance.reset( + new knn_jni::faiss_wrapper::OpenSearchMergeInterruptCallback(&jniUtil, env) + ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_unsetMergeInterruptCallback(JNIEnv * env, jclass cls) +{ + try { + faiss::InterruptCallback::instance.get()->clear_instance(); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} \ No newline at end of file diff --git a/src/main/java/org/apache/lucene/index/KNNMergeHelper.java b/src/main/java/org/apache/lucene/index/KNNMergeHelper.java new file mode 100644 index 0000000000..b87c55b821 --- /dev/null +++ b/src/main/java/org/apache/lucene/index/KNNMergeHelper.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.apache.lucene.index; + +public class KNNMergeHelper { + + private KNNMergeHelper() {} + public static boolean isMergeAborted() { + Thread mergeThread = Thread.currentThread(); + if (mergeThread instanceof ConcurrentMergeScheduler.MergeThread) { + return ((ConcurrentMergeScheduler.MergeThread) mergeThread).merge.isAborted(); + } + return false; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 443b12b9c4..bcfd38fee9 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -6,6 +6,8 @@ package org.opensearch.knn.index.codec.KNN80Codec; import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.ConcurrentMergeScheduler; +import org.apache.lucene.index.KNNMergeHelper; import org.opensearch.common.StopWatch; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index de535c39e8..dcc22754ed 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -9,6 +9,7 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.KNNMergeHelper; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.store.IndexOutput; import org.opensearch.common.Nullable; @@ -46,6 +47,8 @@ import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; +import static org.opensearch.knn.jni.JNIService.setMergeInterruptCallback; +import static org.opensearch.knn.jni.JNIService.unsetMergeInterruptCallback; /** * Writes KNN Index for a field in a segment. This is intended to be used for native engines @@ -119,9 +122,18 @@ public void mergeIndex(final KNNVectorValues knnVectorValues, int totalLiveDo } long bytesPerVector = knnVectorValues.bytesPerVector(); - startMergeStats(totalLiveDocs, bytesPerVector); - buildAndWriteIndex(knnVectorValues, totalLiveDocs); - endMergeStats(totalLiveDocs, bytesPerVector); + final KNNEngine knnEngine = extractKNNEngine(fieldInfo); + setMergeInterruptCallback(knnEngine); + try { + startMergeStats(totalLiveDocs, bytesPerVector); + buildAndWriteIndex(knnVectorValues, totalLiveDocs); + endMergeStats(totalLiveDocs, bytesPerVector); + } catch (Exception ex) { + //TODO handle + log.debug("Merge may abort {}",KNNMergeHelper.isMergeAborted()); + } finally { + unsetMergeInterruptCallback(knnEngine); + } } private void buildAndWriteIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 3ae8bbb926..cb97b71897 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -463,4 +463,8 @@ public static native KNNQueryResult[] rangeSearchIndex( int indexMaxResultWindow, int[] parentIds ); + + public static native void setMergeInterruptCallback(); + + public static native void unsetMergeInterruptCallback(); } diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index b490476eb1..a4659a8584 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -476,4 +476,17 @@ public static KNNQueryResult[] radiusQueryIndex( } throw new IllegalArgumentException(String.format(Locale.ROOT, "RadiusQueryIndex not supported for provided engine")); } + + public static void setMergeInterruptCallback(KNNEngine knnEngine) { + + if (KNNEngine.FAISS == knnEngine) { + FaissService.setMergeInterruptCallback(); + } + } + + public static void unsetMergeInterruptCallback(KNNEngine knnEngine) { + if (KNNEngine.FAISS == knnEngine) { + FaissService.unsetMergeInterruptCallback(); + } + } }