Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ class IndexService {
*/
virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector<int64_t> &ids, jlong idMapAddress);

/**
 * Build a flat (IndexFlat) FAISS index from a set of vectors.
 *
 * An IndexFlatIP is created when metricType is faiss::METRIC_INNER_PRODUCT; every other
 * metric type results in an IndexFlatL2. All vectors are added to the index before the
 * address is returned.
 *
 * @param numVectors Number of vectors to add to the index
 * @param dim Dimension of each vector
 * @param vectors Vector data as a flat array (size must be numVectors * dim)
 * @param metricType Metric type for distance calculations (inner product, otherwise L2)
 * @return Memory address of the heap-allocated native IndexFlat object. This method does not
 *         free it. NOTE(review): presumably the caller is responsible for releasing it - confirm.
 * @throws std::runtime_error if vectors is empty or its size does not equal numVectors * dim
 */
jlong buildFlatIndexFromVectors(int numVectors, int dim, const std::vector<float> &vectors, faiss::MetricType metricType);

/**
* Write index to disk
*
Expand Down
4 changes: 4 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ namespace knn_jni {

void InsertToIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jlong indexAddr, jint threadCount, IndexService *indexService);

// Build a flat (IndexFlat) FAISS index from a Java float array of vectors.
// vectorsJ is expected to hold numVectors * dimJ floats. metricTypeJ selects the inner product
// metric when it equals "IP"; any other value selects L2.
// Throws std::runtime_error when vectorsJ is null, when numVectors or dimJ is not positive, or
// when the array length does not equal numVectors * dimJ.
// Returns the memory address of the created index.
jlong BuildFlatIndexFromVectors(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jfloatArray vectorsJ, jint numVectors, jint dimJ, jstring metricTypeJ, knn_jni::faiss_wrapper::IndexService *indexService);

void WriteIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jobject output, jlong indexAddr, IndexService *indexService);

// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
Expand Down
8 changes: 8 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndex(J
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ,
jlong indexAddress, jint threadCount);

/*
 * Class: org_opensearch_knn_jni_FaissService
 * Method: buildFlatIndexFromVectors
 * Signature: ([FIILjava/lang/String;)J
 *
 * Arguments after the class are, in order: the flat float array of vectors, the number of
 * vectors, the dimension of each vector, and the metric type string.
 * Returns the address of the native index, or -1 when a C++ exception was converted into a
 * Java exception.
 */
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_buildFlatIndexFromVectors(JNIEnv *, jclass, jfloatArray, jint, jint, jstring);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: insertToBinaryIndex
Expand Down
31 changes: 31 additions & 0 deletions jni/src/faiss_index_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@
#include "faiss/IndexIVFFlat.h"
#include "faiss/IndexBinaryIVF.h"
#include "faiss/IndexIDMap.h"
#include "faiss/IndexFlat.h"

#include <string>
#include <vector>
#include <memory>
#include <type_traits>

#include <fstream>
#include <iomanip>

namespace knn_jni {
namespace faiss_wrapper {

Expand Down Expand Up @@ -138,6 +142,33 @@ void IndexService::insertToIndex(
idMap->add_with_ids(numVectors, inputVectors->data(), ids.data());
}

/**
 * Builds a flat FAISS index that stores the supplied vectors.
 *
 * An IndexFlatIP is created for faiss::METRIC_INNER_PRODUCT; any other metric type falls back
 * to an IndexFlatL2.
 *
 * @param numVectors Number of vectors contained in `vectors`; must be positive
 * @param dim Dimension of each vector; must be positive
 * @param vectors Row-major vector data; its size must be exactly numVectors * dim
 * @param metricType Metric used to pick the concrete flat index type
 * @return Address of the heap-allocated faiss::IndexFlat. Ownership is handed to the caller;
 *         nothing in this method frees it.
 * @throws std::runtime_error if `vectors` is empty or its size does not match numVectors * dim
 */
jlong IndexService::buildFlatIndexFromVectors(
        int numVectors,
        int dim,
        const std::vector<float> &vectors,
        faiss::MetricType metricType) {

    if (vectors.empty()) {
        throw std::runtime_error("Input vectors cannot be empty");
    }

    // Compare in 64-bit unsigned arithmetic: numVectors * dim can overflow int, and narrowing
    // vectors.size() to int could make a wrong size look correct. Non-positive inputs are
    // rejected explicitly so that two negative values can never multiply into a "valid" size.
    const bool shapeMatches = numVectors > 0 && dim > 0
            && static_cast<uint64_t>(vectors.size())
                    == static_cast<uint64_t>(numVectors) * static_cast<uint64_t>(dim);
    if (!shapeMatches) {
        throw std::runtime_error("Mismatch between vector count/dimension and actual data length");
    }

    // Hold the index in a unique_ptr while it is being populated so it is destroyed
    // if add() throws instead of leaking.
    std::unique_ptr<faiss::IndexFlat> index;
    if (metricType == faiss::METRIC_INNER_PRODUCT) {
        index.reset(new faiss::IndexFlatIP(dim));
    } else {
        index.reset(new faiss::IndexFlatL2(dim));
    }

    // NOTE(review): reviewers asked whether add_with_ids could be used here instead, reusing the
    // existing mechanisms, since the id mapping might be needed during reconstruction.
    index->add(numVectors, vectors.data());

    // Transfer ownership to the caller as an opaque address.
    return reinterpret_cast<jlong>(index.release());
}

void IndexService::writeIndex(
faiss::IOWriter* writer,
jlong idMapAddress
Expand Down
36 changes: 34 additions & 2 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,38 @@ void knn_jni::faiss_wrapper::InsertToIndex(knn_jni::JNIUtilInterface * jniUtil,
indexService->insertToIndex(dim, numIds, threadCount, vectorsAddress, ids, index_ptr);
}

/**
 * Validates the JNI arguments, copies the Java float array into native memory and delegates to
 * IndexService::buildFlatIndexFromVectors.
 *
 * @param jniUtil JNI utility interface (currently unused by this function)
 * @param env JNI environment used to read the Java array and string
 * @param vectorsJ Flat Java float array expected to hold numVectors * dimJ values
 * @param numVectors Number of vectors; must be positive
 * @param dimJ Dimension of each vector; must be positive
 * @param metricTypeJ Metric name; "IP" selects inner product, any other value selects L2
 * @param indexService Service that creates the index
 * @return Address of the created native index
 * @throws std::runtime_error if an argument is null or non-positive, if the array length does
 *         not equal numVectors * dimJ, or if the metric string cannot be read
 */
jlong knn_jni::faiss_wrapper::BuildFlatIndexFromVectors(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jfloatArray vectorsJ, jint numVectors,
                                                         jint dimJ, jstring metricTypeJ, IndexService *indexService) {

    if (vectorsJ == nullptr) {
        throw std::runtime_error("Vector data cannot be null");
    }

    // GetStringUTFChars must not be called with a null reference.
    if (metricTypeJ == nullptr) {
        throw std::runtime_error("Metric type cannot be null");
    }

    if (dimJ <= 0 || numVectors <= 0) {
        throw std::runtime_error("Invalid dimensions or number of vectors");
    }

    // Widen before multiplying: numVectors * dimJ can overflow a 32-bit jint.
    const jsize totalLength = env->GetArrayLength(vectorsJ);
    const int64_t expectedLength = static_cast<int64_t>(numVectors) * static_cast<int64_t>(dimJ);
    if (static_cast<int64_t>(totalLength) != expectedLength) {
        throw std::runtime_error("Vector data length does not match numVectors * dimension");
    }

    // Resolve the metric and release the UTF chars immediately, so no JNI resource is still
    // held when later code throws.
    // NOTE(review): a reviewer asked whether a helper in jni_util could be used for this conversion.
    const char *metricTypeC = env->GetStringUTFChars(metricTypeJ, nullptr);
    if (metricTypeC == nullptr) {
        throw std::runtime_error("Unable to read metric type");
    }
    const faiss::MetricType metric = (strcmp(metricTypeC, "IP") == 0) ? faiss::METRIC_INNER_PRODUCT : faiss::METRIC_L2;
    env->ReleaseStringUTFChars(metricTypeJ, metricTypeC);

    // Copy the Java array straight into native storage. A region copy needs no matching release
    // call, so nothing is left pinned if the index service throws.
    std::vector<float> cppVectors(static_cast<std::size_t>(totalLength));
    env->GetFloatArrayRegion(vectorsJ, 0, totalLength, cppVectors.data());

    return indexService->buildFlatIndexFromVectors(numVectors, dimJ, cppVectors, metric);
}

void knn_jni::faiss_wrapper::WriteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env,
jobject output, jlong index_ptr, IndexService* indexService) {

Expand Down Expand Up @@ -597,7 +629,7 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
} else {
auto ivfReader = dynamic_cast<const faiss::IndexIVF*>(indexReader->index);
auto ivfFlatReader = dynamic_cast<const faiss::IndexIVFFlat*>(indexReader->index);

if(ivfReader || ivfFlatReader) {
int indexNprobe = ivfReader == nullptr ? ivfFlatReader->nprobe : ivfReader->nprobe;
ivfParams.nprobe = commons::getIntegerMethodParameter(env, jniUtil, methodParams, NPROBES, indexNprobe);
Expand Down Expand Up @@ -1228,4 +1260,4 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter
}

return results;
}
}
15 changes: 15 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include "jni_util.h"
#include "faiss_stream_support.h"



static knn_jni::JNIUtil jniUtil;
static const jint KNN_FAISS_JNI_VERSION = JNI_VERSION_1_1;

Expand Down Expand Up @@ -96,6 +98,19 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JN
}
}

// JNI entry point for FaissService.buildFlatIndexFromVectors.
// Creates an IndexService backed by FaissMethods and forwards the arguments to the wrapper.
// Any C++ exception is handed to jniUtil to be rethrown on the Java side; -1 is returned in that case.
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_buildFlatIndexFromVectors(
        JNIEnv *env, jclass cls, jfloatArray vectorsJ, jint numVectors, jint dimJ, jstring metricTypeJ) {
    // Stays at -1 unless the wrapper returns normally.
    jlong indexAddress = -1;

    try {
        knn_jni::faiss_wrapper::IndexService service(
                std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods>(new knn_jni::faiss_wrapper::FaissMethods()));
        indexAddress = knn_jni::faiss_wrapper::BuildFlatIndexFromVectors(
                &jniUtil, env, vectorsJ, numVectors, dimJ, metricTypeJ, &service);
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }

    return indexAddress;
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ,
jlong indexAddress, jint threadCount)
Expand Down
119 changes: 119 additions & 0 deletions jni/tests/faiss_index_service_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "commons.h"
#include <faiss/IndexFlat.h>

using ::testing::NiceMock;
using ::testing::Return;
Expand Down Expand Up @@ -160,4 +161,122 @@ TEST(CreateByteIndexTest, BasicAssertions) {
long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap);
indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress);
indexService.writeIndex(&fileIOWriter, indexAddress);
}

// Tests for IndexService::buildFlatIndexFromVectors

// Helper: returns a flat buffer of num * dim floats, every element set to val (1.0f by default).
std::vector<float> makeVectors(int num, int dim, float val = 1.0f) {
    return std::vector<float>(num * dim, val);
}

/**
 * Test that a Flat L2 index is successfully created from vectors.
 * Checks that the returned pointer is not null, that the index reports the expected
 * vector count and dimension, and that it really uses the L2 metric.
 */
TEST(BuildFlatIndexFromVectorsTest, BuildsL2Index) {
    const int numVectors = 5;
    const int dim = 3;
    const std::vector<float> data = makeVectors(numVectors, dim, 2.0f);

    std::unique_ptr<MockFaissMethods> mockFaissMethods(new MockFaissMethods());
    knn_jni::faiss_wrapper::IndexService indexService(std::move(mockFaissMethods));

    const jlong indexPtr = indexService.buildFlatIndexFromVectors(numVectors, dim, data, faiss::METRIC_L2);
    ASSERT_NE(indexPtr, 0);

    // Take ownership immediately so the native index is freed even when an assertion below
    // fails and returns early from the test body.
    std::unique_ptr<faiss::IndexFlat> index(reinterpret_cast<faiss::IndexFlat*>(indexPtr));
    ASSERT_EQ(index->ntotal, numVectors);
    EXPECT_EQ(index->d, dim);
    EXPECT_EQ(index->metric_type, faiss::METRIC_L2);
}

/**
 * Test that a Flat Inner Product index is successfully created from vectors.
 * Checks that the returned pointer is not null, that the index reports the expected
 * vector count and dimension, and that it really uses the inner product metric.
 */
TEST(BuildFlatIndexFromVectorsTest, BuildsIPIndex) {
    const int numVectors = 4;
    const int dim = 2;
    const std::vector<float> data = makeVectors(numVectors, dim, 3.0f);

    std::unique_ptr<MockFaissMethods> mockFaissMethods(new MockFaissMethods());
    knn_jni::faiss_wrapper::IndexService indexService(std::move(mockFaissMethods));

    const jlong indexPtr = indexService.buildFlatIndexFromVectors(numVectors, dim, data, faiss::METRIC_INNER_PRODUCT);
    ASSERT_NE(indexPtr, 0);

    // Take ownership immediately so the native index is freed even when an assertion below
    // fails and returns early from the test body.
    std::unique_ptr<faiss::IndexFlat> index(reinterpret_cast<faiss::IndexFlat*>(indexPtr));
    ASSERT_EQ(index->ntotal, numVectors);
    EXPECT_EQ(index->d, dim);
    // Without this check the test would also pass if an L2 index had been built.
    EXPECT_EQ(index->metric_type, faiss::METRIC_INNER_PRODUCT);
}

/**
 * An empty vector buffer must be rejected with std::runtime_error,
 * regardless of the requested vector count and dimension.
 */
TEST(BuildFlatIndexFromVectorsTest, ThrowsOnEmptyVectors) {
    std::vector<float> noVectors;
    int requestedCount = 10;
    int requestedDim = 4;

    knn_jni::faiss_wrapper::IndexService service(
            std::unique_ptr<MockFaissMethods>(new MockFaissMethods()));

    EXPECT_THROW(
            service.buildFlatIndexFromVectors(requestedCount, requestedDim, noVectors, faiss::METRIC_L2),
            std::runtime_error);
}

/**
 * A buffer whose length differs from numVectors * dim must be rejected with std::runtime_error.
 * Here 3 * 5 = 15 values are announced but only 7 are supplied.
 */
TEST(BuildFlatIndexFromVectorsTest, ThrowsOnMismatchedSize) {
    std::vector<float> shortBuffer(7, 1.0f);
    int requestedCount = 3;
    int requestedDim = 5;

    knn_jni::faiss_wrapper::IndexService service(
            std::unique_ptr<MockFaissMethods>(new MockFaissMethods()));

    EXPECT_THROW(
            service.buildFlatIndexFromVectors(requestedCount, requestedDim, shortBuffer, faiss::METRIC_L2),
            std::runtime_error);
}

/**
 * Test that the vectors inserted into the flat index are preserved in order and value.
 * This reconstructs each vector from the index and compares to the original input.
 */
TEST(BuildFlatIndexFromVectorsTest, IndexContainsInsertedVectorsInOrder) {
    const int numVectors = 5;
    const int dim = 3;
    // Prepare 5 unique vectors
    const std::vector<float> data = {
        1.0f, 2.0f, 3.0f,    // vector 0
        4.0f, 5.0f, 6.0f,    // vector 1
        7.0f, 8.0f, 9.0f,    // vector 2
        10.0f, 11.0f, 12.0f, // vector 3
        13.0f, 14.0f, 15.0f  // vector 4
    };

    // Use L2 metric for this test
    std::unique_ptr<MockFaissMethods> mockFaissMethods(new MockFaissMethods());
    knn_jni::faiss_wrapper::IndexService indexService(std::move(mockFaissMethods));

    const jlong indexPtr = indexService.buildFlatIndexFromVectors(numVectors, dim, data, faiss::METRIC_L2);
    ASSERT_NE(indexPtr, 0);

    // Take ownership immediately: every ASSERT_* below returns from the test on failure, which
    // would skip a trailing manual delete and leak the native index.
    std::unique_ptr<faiss::IndexFlat> index(reinterpret_cast<faiss::IndexFlat*>(indexPtr));
    ASSERT_EQ(index->ntotal, numVectors);

    // Check each vector in the index matches input and order
    std::vector<float> reconstructed(dim);
    for (int i = 0; i < numVectors; ++i) {
        index->reconstruct(i, reconstructed.data());
        for (int j = 0; j < dim; ++j) {
            const float expected = data[i * dim + j];
            ASSERT_FLOAT_EQ(reconstructed[j], expected)
                << "Vector " << i << " element " << j << " mismatch";
        }
    }
}
12 changes: 6 additions & 6 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ const float rangeSearchRandomDataMax = 50;
const float rangeSearchRadius = 20000;

void createIndexIteratively(
knn_jni::JNIUtilInterface * JNIUtil,
JNIEnv *jniEnv,
knn_jni::JNIUtilInterface * JNIUtil,
JNIEnv *jniEnv,
std::vector<faiss::idx_t> & ids,
std::vector<float> & vectors,
int dim,
Expand Down Expand Up @@ -73,13 +73,13 @@ void createIndexIteratively(
}

void createBinaryIndexIteratively(
knn_jni::JNIUtilInterface * JNIUtil,
JNIEnv *jniEnv,
knn_jni::JNIUtilInterface * JNIUtil,
JNIEnv *jniEnv,
std::vector<faiss::idx_t> & ids,
std::vector<uint8_t> & vectors,
int dim,
jobject javaFileOutputMock,
std::unordered_map<string, jobject> parametersMap,
std::unordered_map<string, jobject> parametersMap,
IndexService * indexService,
int insertions = 10
) {
Expand Down Expand Up @@ -1336,4 +1336,4 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) {
delete it;
}
}
}
}
Loading