diff --git a/.github/workflows/remote_index_build.yml b/.github/workflows/remote_index_build.yml new file mode 100644 index 0000000000..fece6a4c11 --- /dev/null +++ b/.github/workflows/remote_index_build.yml @@ -0,0 +1,144 @@ +name: Build and Test k-NN using Remote Index Builder +on: + schedule: + - cron: '0 0 * * *' # every night + push: + branches: + - "*" + - "feature/**" + paths: + - 'build.gradle' + - 'settings.gradle' + - 'src/**' + - 'build-tools/**' + - 'buildSrc/**' + - 'gradle/**' + - 'jni/**' + - '.github/workflows/remote_index_build.yml' + pull_request: + branches: + - "*" + - "feature/**" + paths: + - 'build.gradle' + - 'settings.gradle' + - 'src/**' + - 'build-tools/**' + - 'buildSrc/**' + - 'gradle/**' + - 'jni/**' + - '.github/workflows/remote_index_build.yml' + +jobs: + Remote-Index-Build-IT-Tests: + strategy: + matrix: + java: [21] + + env: + AWS_ACCESS_KEY_ID: test + AWS_SECRET_ACCESS_KEY: test + AWS_SESSION_TOKEN: test + + name: Remote-Index-Build-IT-Tests on Linux + runs-on: + group: selfhosted-gpu-runners + labels: g6xlarge + + steps: + - name: Checkout k-NN + uses: actions/checkout@v4 + + # Setup git user so that patches for native libraries can be applied and committed + - name: Setup git user + run: | + git config --global user.name "github-actions[bot]" + git config --global user.email "github-actions[bot]@users.noreply.github.com" + + - name: Setup Java ${{ matrix.java }} + uses: actions/setup-java@v4 + with: + java-version: ${{ matrix.java }} + distribution: 'temurin' + + - name: Install dependencies on linux + run: | + sudo yum install gcc g++ -y + sudo yum install openblas openblas-devel -y + sudo yum install -y zlib + sudo yum install -y zlib-devel + sudo yum install -y cmake + sudo yum install gcc-gfortran -y + + - name: Initial cleanup + run: | + docker ps -aq | xargs -r docker rm -f + docker system prune -af --volumes + + - name: Pull Remote Index Build Docker Image from Docker Hub + run: | + docker pull 
rchitale7/remote-index-build-service:api + + - name: Pull LocalStack Docker image + run: | + docker pull localstack/localstack:latest + + - name: Run LocalStack + run: | + docker run --rm -d -p 4566:4566 localstack/localstack:latest + + - name: Verify Localstack is ready + run: | + if ! timeout 3 bash -c 'until curl --silent --fail http://localhost:4566/_localstack/health; do sleep 1; done'; then + echo "Localstack health check failed after 3 seconds" + exit 1 + fi + + - name: Create S3 Bucket in LocalStack + run: | + aws --endpoint-url=http://localhost:4566 s3 mb s3://remote-index-build-bucket + + - name: Run Docker container + run: | + docker run --rm -d --name remote-index-builder-container --gpus all -p 80:80 -e INTEGRATION_TESTS=TRUE -e AWS_ACCESS_KEY_ID=${{ env.AWS_ACCESS_KEY_ID }} -e AWS_SECRET_ACCESS_KEY=${{ env.AWS_SECRET_ACCESS_KEY }} -e AWS_SESSION_TOKEN=${{ env.AWS_SESSION_TOKEN}} rchitale7/remote-index-build-service:api + sleep 5 + + - name: Run tests + run: | + if lscpu | grep -i avx512f | grep -i avx512cd | grep -i avx512vl | grep -i avx512dq | grep -i avx512bw + then + if lscpu | grep -q "GenuineIntel" && lscpu | grep -i avx512_fp16 | grep -i avx512_bf16 | grep -i avx512_vpopcntdq + then + echo "the system is an Intel(R) Sapphire Rapids or a newer-generation processor" + ./gradlew :integTestRemoteIndexBuild -Ds3.enabled=true -Dtest.remoteBuild=s3.localStack -Dtest.bucket=remote-index-build-bucket -Dtest.base_path=vectors -Daccess_key=${{ env.AWS_ACCESS_KEY_ID }} -Dsecret_key=${{ env.AWS_SECRET_ACCESS_KEY }} -Dsession_token=${{ env.AWS_SESSION_TOKEN}} -Dtests.class=org.opensearch.knn.index.RemoteBuildIT -Davx512_spr.enabled=true -Dnproc.count=`nproc` + else + echo "avx512 available on system" + ./gradlew :integTestRemoteIndexBuild -Ds3.enabled=true -Dtest.remoteBuild=s3.localStack -Dtest.bucket=remote-index-build-bucket -Dtest.base_path=vectors -Daccess_key=${{ env.AWS_ACCESS_KEY_ID }} -Dsecret_key=${{ env.AWS_SECRET_ACCESS_KEY }} -Dsession_token=${{ 
env.AWS_SESSION_TOKEN}} -Dtests.class=org.opensearch.knn.index.RemoteBuildIT -Davx512_spr.enabled=false -Dnproc.count=`nproc` + fi + elif lscpu | grep -i avx2 + then + echo "avx2 available on system" + ./gradlew :integTestRemoteIndexBuild -Ds3.enabled=true -Dtest.remoteBuild=s3.localStack -Dtest.bucket=remote-index-build-bucket -Dtest.base_path=vectors -Daccess_key=${{ env.AWS_ACCESS_KEY_ID }} -Dsecret_key=${{ env.AWS_SECRET_ACCESS_KEY }} -Dsession_token=${{ env.AWS_SESSION_TOKEN}} -Dtests.class=org.opensearch.knn.index.RemoteBuildIT -Davx512.enabled=false -Davx512_spr.enabled=false -Dnproc.count=`nproc` + else + echo "avx512 and avx2 not available on system" + ./gradlew :integTestRemoteIndexBuild -Ds3.enabled=true -Dtest.remoteBuild=s3.localStack -Dtest.bucket=remote-index-build-bucket -Dtest.base_path=vectors -Daccess_key=${{ env.AWS_ACCESS_KEY_ID }} -Dsecret_key=${{ env.AWS_SECRET_ACCESS_KEY }} -Dsession_token=${{ env.AWS_SESSION_TOKEN}} -Dtests.class=org.opensearch.knn.index.RemoteBuildIT -Davx2.enabled=false -Davx512.enabled=false -Davx512_spr.enabled=false -Dnproc.count=`nproc` + fi + + - name: Verify Remote Index Builder logs + run: | + if docker logs remote-index-builder-container 2>&1 | grep -q "INFO - Index built successfully!"; then + echo "Success logs found in Remote Index Builder container" + else + echo "No success logs found. Full logs:" + docker logs remote-index-builder-container + exit 1 + fi + + - name: Final cleanup + if: always() + run: | + docker ps -aq | xargs -r docker rm -f + docker system prune -af --volumes + docker logout + rm -rf ${{ github.workspace }}/* + diff --git a/CHANGELOG.md b/CHANGELOG.md index c9c74dba14..e26caebc83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,28 +5,18 @@ All notable changes to this project are documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
See the [CONTRIBUTING guide](./CONTRIBUTING.md#Changelog) for instructions on how to add changelog entries. ## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD) -### Features -* [Remote Vector Index Build] Client polling mechanism, encoder check, method parameter retrieval [#2576](https://github.com/opensearch-project/k-NN/pull/2576) -* [Remote Vector Index Build] Move client to separate module [#2603](https://github.com/opensearch-project/k-NN/pull/2603) -* Add filter function to KNNQueryBuilder with unit tests and integration tests [#2599](https://github.com/opensearch-project/k-NN/pull/2599) -* [Lucene On Faiss] Add a new mode, memory-optimized-search enable user to run vector search on FAISS index under memory constrained environment. [#2630](https://github.com/opensearch-project/k-NN/pull/2630) ### Enhancements +* Removing redundant type conversions for script scoring for hamming space with binary vectors [#2351](https://github.com/opensearch-project/k-NN/pull/2351) ### Bug Fixes -* Fixing bug to prevent NullPointerException while doing PUT mappings [#2556](https://github.com/opensearch-project/k-NN/issues/2556) -* Add index operation listener to update translog source [#2629](https://github.com/opensearch-project/k-NN/pull/2629) -* [Remote Vector Index Build] Fix bug to support `COSINESIMIL` space type [#2627](https://github.com/opensearch-project/k-NN/pull/2627) -### Infrastructure -### Documentation -### Maintenance -* Update minimum required CMAKE version in NMSLIB [#2635](https://github.com/opensearch-project/k-NN/pull/2635) -### Refactoring -* Switch derived source from field attributes to segment attribute [#2606](https://github.com/opensearch-project/k-NN/pull/2606) -* Migrate derived source from filter to mask [#2612](https://github.com/opensearch-project/k-NN/pull/2612) +* [BUGFIX] Fix KNN Quantization state cache have an invalid weight threshold [#2666](https://github.com/opensearch-project/k-NN/pull/2666) ## 
[Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.19...2.x) ### Features +* [Vector Profiler] Adding basic generic vector profiler implementation and tests. [#2624](https://github.com/opensearch-project/k-NN/pull/2624) +* [Vector Profiler] Adding main segment implementation for API and indexing. [#2653](https://github.com/opensearch-project/k-NN/pull/2653) ### Enhancements ### Bug Fixes +* [BUGFIX] FIX nested vector query at efficient filter scenarios [#2641](https://github.com/opensearch-project/k-NN/pull/2641) ### Infrastructure ### Documentation ### Maintenance diff --git a/build.gradle b/build.gradle index d1624402c8..7f7ddbeb79 100644 --- a/build.gradle +++ b/build.gradle @@ -45,6 +45,18 @@ buildscript { opensearch_build += "-SNAPSHOT" } opensearch_no_snapshot = opensearch_build.replace("-SNAPSHOT","") + + os_platform = "linux" + artifact_type = "tar" + + if (Os.isFamily(Os.FAMILY_WINDOWS)) { + os_platform = "windows" + artifact_type = "zip" + } + + opensearch_version_no_snapshot = opensearch_version.replace("-SNAPSHOT","") + repo_s3_resource_folder = "build/resource/repository-s3" + repo_s3_download_url = "https://ci.opensearch.org/ci/dbc/distribution-build-opensearch/${opensearch_version_no_snapshot}/latest/${os_platform}/x64/${artifact_type}/builds/opensearch/core-plugins/repository-s3-${opensearch_version_no_snapshot}.zip" } // This isn't applying from repositories.gradle so repeating git diff it here @@ -71,6 +83,7 @@ buildscript { //****************************************************************************/ plugins { + id 'eclipse' id 'java-library' id 'java-test-fixtures' id 'idea' @@ -84,7 +97,7 @@ apply plugin: 'opensearch.opensearchplugin' apply plugin: 'opensearch.rest-test' apply plugin: 'opensearch.pluginzip' apply plugin: 'opensearch.repositories' - +apply plugin: 'opensearch.java-agent' def opensearch_tmp_dir = rootProject.file('build/private/opensearch_tmp').absoluteFile opensearch_tmp_dir.mkdirs() @@ -170,6 +183,28 @@ ext 
{ ) cluster.setSecure(true) } + + configureS3Plugin = { OpenSearchCluster cluster -> + cluster.plugin(provider(new Callable(){ + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + if (new File("$project.rootDir/$repo_s3_resource_folder").exists()) { + project.delete(files("$project.rootDir/$repo_s3_resource_folder")) + } + project.mkdir repo_s3_resource_folder + ant.get(src: repo_s3_download_url, + dest: repo_s3_resource_folder, + httpusecaches: false) + return fileTree(repo_s3_resource_folder).getSingleFile() + } + } + } + })) + } + propertyKeys = [ breaker: [ useRealMemory: 'tests.opensearch.indices.breaker.total.use_real_memory' @@ -321,6 +356,7 @@ dependencies { testFixturesImplementation group: 'net.minidev', name: 'json-smart', version: "${versions.json_smart}" testFixturesImplementation "org.opensearch:common-utils:${version}" implementation 'com.github.oshi:oshi-core:6.4.13' + implementation 'org.apache.commons:commons-math3:3.6.1' api "net.java.dev.jna:jna:${versions.jna}" api "net.java.dev.jna:jna-platform:${versions.jna}" // OpenSearch core is using slf4j 1.7.36. Therefore, we cannot change the version here. 
@@ -343,6 +379,7 @@ tasks.register('cmakeJniLib', Exec) { args.add("-DAVX512_SPR_ENABLED=${avx512_spr_enabled}") args.add("-DCOMMIT_LIB_PATCHES=${commit_lib_patches}") args.add("-DAPPLY_LIB_PATCHES=${apply_lib_patches}") + args.add("-DCMAKE_POLICY_VERSION_MINIMUM=3.5") def javaHome = Jvm.current().getJavaHome() logger.lifecycle("Java home directory used by gradle: $javaHome") if (Os.isFamily(Os.FAMILY_WINDOWS)) { @@ -397,16 +434,16 @@ test { } def _numNodes = findProperty('numNodes') as Integer ?: 1 -integTest { +def commonIntegTest(RestIntegTestTask task, project, integTestDependOnJniLib, opensearch_tmp_dir, _numNodes){ if (integTestDependOnJniLib) { - dependsOn buildJniLib + task.dependsOn buildJniLib } - systemProperty 'tests.security.manager', 'false' - systemProperty 'java.io.tmpdir', opensearch_tmp_dir.absolutePath - systemProperty "java.library.path", "$rootDir/jni/build/release" - systemProperty "tests.path.repo", "${buildDir}/testSnapshotFolder" + task.systemProperty 'tests.security.manager', 'false' + task.systemProperty 'java.io.tmpdir', opensearch_tmp_dir.absolutePath + task.systemProperty "java.library.path", "$rootDir/jni/build/release" + task.systemProperty "tests.path.repo", "${buildDir}/testSnapshotFolder" // allows integration test classes to access test resource from project root path - systemProperty('project.root', project.rootDir.absolutePath) + task.systemProperty('project.root', project.rootDir.absolutePath) var is_https = System.getProperty("https") var user = System.getProperty("user") @@ -419,67 +456,97 @@ integTest { password = password == null ? 
"admin" : password } - systemProperty("https", is_https) - systemProperty("user", user) - systemProperty("password", password) - systemProperty("test.exhaustive", System.getProperty("test.exhaustive")) + task.systemProperty("https", is_https) + task.systemProperty("user", user) + task.systemProperty("password", password) + task.systemProperty("test.exhaustive", System.getProperty("test.exhaustive")) - doFirst { + task.doFirst { // Tell the test JVM if the cluster JVM is running under a debugger so that tests can // use longer timeouts for requests. def isDebuggingCluster = getDebug() || System.getProperty("test.debug") != null - systemProperty 'cluster.debug', isDebuggingCluster + task.systemProperty 'cluster.debug', isDebuggingCluster // Set number of nodes system property to be used in tests - systemProperty 'cluster.number_of_nodes', "${_numNodes}" + task.systemProperty 'cluster.number_of_nodes', "${_numNodes}" // There seems to be an issue when running multi node run or integ tasks with unicast_hosts // not being written, the waitForAllConditions ensures it's written - getClusters().forEach { cluster -> + task.getClusters().forEach { cluster -> cluster.waitForAllConditions() } } // The -Ddebug.es option makes the cluster debuggable; this makes the tests debuggable if (System.getProperty("test.debug") != null) { - jvmArgs '-agentlib:jdwp=transport=dt_socket,server=n,suspend=y,address=8000' + task.jvmArgs '-agentlib:jdwp=transport=dt_socket,server=n,suspend=y,address=8000' } - systemProperty propertyKeys.breaker.useRealMemory, getBreakerSetting() + task.systemProperty propertyKeys.breaker.useRealMemory, getBreakerSetting() } -testClusters.integTest { - testDistribution = "ARCHIVE" +integTest { + commonIntegTest(it, project, integTestDependOnJniLib, opensearch_tmp_dir, _numNodes) + filter { + excludeTestsMatching "org.opensearch.knn.index.RemoteBuildIT" + } +} + +task integTestRemoteIndexBuild(type: RestIntegTestTask) { + commonIntegTest(it, project, 
integTestDependOnJniLib, opensearch_tmp_dir, _numNodes) + systemProperty("test.remoteBuild", System.getProperty("test.remoteBuild")) + systemProperty("test.bucket", System.getProperty("test.bucket")) + systemProperty("test.base_path", System.getProperty("test.base_path")) +} + +def commonIntegTestClusters(OpenSearchCluster cluster, _numNodes){ + cluster.testDistribution = "ARCHIVE" //Used for circuit breaker integration tests - setting 'node.attr.knn_cb_tier', 'integ' + cluster.setting 'node.attr.knn_cb_tier', 'integ' // Optionally install security if (System.getProperty("security.enabled") != null) { - configureSecurityPlugin(testClusters.integTest) + configureSecurityPlugin(cluster) } - plugin(project.tasks.bundlePlugin.archiveFile) + cluster.plugin(project.tasks.bundlePlugin.archiveFile) if (Os.isFamily(Os.FAMILY_WINDOWS)) { // Add the paths of built JNI libraries and its dependent libraries to PATH variable in System variables - environment('PATH', System.getenv('PATH') + ";$rootDir/jni/build/release" + ";$rootDir/src/main/resources/windowsDependencies") + cluster.environment('PATH', System.getenv('PATH') + ";$rootDir/jni/build/release" + ";$rootDir/src/main/resources/windowsDependencies") } // Cluster shrink exception thrown if we try to set numberOfNodes to 1, so only apply if > 1 - if (_numNodes > 1) numberOfNodes = _numNodes + if (_numNodes > 1) cluster.numberOfNodes = _numNodes // When running integration tests it doesn't forward the --debug-jvm to the cluster anymore // i.e. 
we have to use a custom property to flag when we want to debug opensearch JVM // since we also support multi node integration tests we increase debugPort per node if (System.getProperty("cluster.debug") != null) { def debugPort = 5005 - nodes.forEach { node -> + cluster.nodes.forEach { node -> node.jvmArgs("-agentlib:jdwp=transport=dt_socket,server=n,suspend=y,address=${debugPort}") debugPort += 1 } } - systemProperty("java.library.path", "$rootDir/jni/build/release") + cluster.systemProperty("java.library.path", "$rootDir/jni/build/release") final testSnapshotFolder = file("${buildDir}/testSnapshotFolder") testSnapshotFolder.mkdirs() - setting 'path.repo', "${buildDir}/testSnapshotFolder" - systemProperty propertyKeys.breaker.useRealMemory, getBreakerSetting() + cluster.setting 'path.repo', "${buildDir}/testSnapshotFolder" + cluster.systemProperty propertyKeys.breaker.useRealMemory, getBreakerSetting() +} + +testClusters.integTest { + commonIntegTestClusters(it, _numNodes) +} + +testClusters.integTestRemoteIndexBuild { + commonIntegTestClusters(it, _numNodes) + // Optionally install S3 + if (System.getProperty("s3.enabled") != null) { + configureS3Plugin(testClusters.integTestRemoteIndexBuild) + } + + keystore 's3.client.default.access_key', "${System.getProperty("access_key")}" + keystore 's3.client.default.secret_key', "${System.getProperty("secret_key")}" + keystore 's3.client.default.session_token', "${System.getProperty("session_token")}" } task integTestRemote(type: RestIntegTestTask) { @@ -494,6 +561,7 @@ task integTestRemote(type: RestIntegTestTask) { systemProperty 'tests.security.manager', 'false' systemProperty("test.exhaustive", System.getProperty("test.exhaustive")) + systemProperty "tests.path.repo", "${layout.buildDirectory.toString()}/testSnapshotFolder" // Run tests with remote cluster only if rest case is defined if (System.getProperty("tests.rest.cluster") != null) { diff --git a/jni/cmake/init-faiss.cmake b/jni/cmake/init-faiss.cmake index 
523b3c17d5..a11b2d771a 100644 --- a/jni/cmake/init-faiss.cmake +++ b/jni/cmake/init-faiss.cmake @@ -20,6 +20,7 @@ if(NOT DEFINED APPLY_LIB_PATCHES OR "${APPLY_LIB_PATCHES}" STREQUAL true) list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch") list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch") list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0004-Custom-patch-to-support-binary-vector.patch") + list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0005-Custom-patch-to-support-multi-vector-IndexHNSW-search_level_0.patch") # Get patch id of the last commit execute_process(COMMAND sh -c "git --no-pager show HEAD | git patch-id --stable" OUTPUT_VARIABLE PATCH_ID_OUTPUT_FROM_COMMIT WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss) diff --git a/jni/cmake/init-nmslib.cmake b/jni/cmake/init-nmslib.cmake index b28e9d4016..44fc6e1b24 100644 --- a/jni/cmake/init-nmslib.cmake +++ b/jni/cmake/init-nmslib.cmake @@ -21,7 +21,6 @@ if(NOT DEFINED APPLY_LIB_PATCHES OR "${APPLY_LIB_PATCHES}" STREQUAL true) list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0003-Added-streaming-apis-for-vector-index-loading-in-Hnsw.patch") list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0004-Added-a-new-save-apis-in-Hnsw-with-streaming-interfa.patch") list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0005-Add-util-include-to-fix-pragma-error.patch") - list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0006-Bump-cmake-version-nmslib.patch") # Get patch id of the last commit execute_process(COMMAND sh -c "git --no-pager show HEAD | git patch-id --stable" OUTPUT_VARIABLE PATCH_ID_OUTPUT_FROM_COMMIT WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/nmslib) diff --git 
a/jni/patches/faiss/0005-Custom-patch-to-support-multi-vector-IndexHNSW-search_level_0.patch b/jni/patches/faiss/0005-Custom-patch-to-support-multi-vector-IndexHNSW-search_level_0.patch new file mode 100644 index 0000000000..2c2d365bee --- /dev/null +++ b/jni/patches/faiss/0005-Custom-patch-to-support-multi-vector-IndexHNSW-search_level_0.patch @@ -0,0 +1,333 @@ +From 9ef5e349ca5893da07898d7f1d22b0a81f17fddc Mon Sep 17 00:00:00 2001 +From: AnnTian Shao +Date: Thu, 3 Apr 2025 21:21:11 +0000 +Subject: [PATCH] Add multi-vector-support faiss patch to + IndexHNSW::search_level_0 + +Signed-off-by: AnnTian Shao +--- + faiss/IndexHNSW.cpp | 123 +++++++++++++++++++++++++----------- + faiss/index_factory.cpp | 7 ++- + tests/test_id_grouper.cpp | 128 ++++++++++++++++++++++++++++++++++++++ + 3 files changed, 222 insertions(+), 36 deletions(-) + +diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp +index eee3e99c6..7c5dfe020 100644 +--- a/faiss/IndexHNSW.cpp ++++ b/faiss/IndexHNSW.cpp +@@ -286,6 +286,61 @@ void hnsw_search( + hnsw_stats.combine({n1, n2, ndis, nhops}); + } + ++template ++void hnsw_search_level_0( ++ const IndexHNSW* index, ++ idx_t n, ++ const float* x, ++ idx_t k, ++ const storage_idx_t* nearest, ++ const float* nearest_d, ++ float* distances, ++ idx_t* labels, ++ int nprobe, ++ int search_type, ++ const SearchParameters* params_in, ++ BlockResultHandler& bres) { ++ ++ const HNSW& hnsw = index->hnsw; ++ const SearchParametersHNSW* params = nullptr; ++ ++ if (params_in) { ++ params = dynamic_cast(params_in); ++ FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); ++ } ++ ++#pragma omp parallel ++ { ++ std::unique_ptr qdis( ++ storage_distance_computer(index->storage)); ++ HNSWStats search_stats; ++ VisitedTable vt(index->ntotal); ++ typename BlockResultHandler::SingleResultHandler res(bres); ++ ++#pragma omp for ++ for (idx_t i = 0; i < n; i++) { ++ res.begin(i); ++ qdis->set_query(x + i * index->d); ++ ++ hnsw.search_level_0( ++ *qdis.get(), ++ res, ++ 
nprobe, ++ nearest + i * nprobe, ++ nearest_d + i * nprobe, ++ search_type, ++ search_stats, ++ vt, ++ params); ++ res.end(); ++ vt.advance(); ++ } ++#pragma omp critical ++ { hnsw_stats.combine(search_stats); } ++ } ++ ++} ++ + } // anonymous namespace + + void IndexHNSW::search( +@@ -419,46 +474,44 @@ void IndexHNSW::search_level_0( + FAISS_THROW_IF_NOT(k > 0); + FAISS_THROW_IF_NOT(nprobe > 0); + +- const SearchParametersHNSW* params = nullptr; +- +- if (params_in) { +- params = dynamic_cast(params_in); +- FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); +- } +- + storage_idx_t ntotal = hnsw.levels.size(); + +- using RH = HeapBlockResultHandler; +- RH bres(n, distances, labels, k); + +-#pragma omp parallel +- { +- std::unique_ptr qdis( +- storage_distance_computer(storage)); +- HNSWStats search_stats; +- VisitedTable vt(ntotal); +- RH::SingleResultHandler res(bres); ++ if (params_in && params_in->grp) { ++ using RH = GroupedHeapBlockResultHandler; ++ RH bres(n, distances, labels, k, params_in->grp); + +-#pragma omp for +- for (idx_t i = 0; i < n; i++) { +- res.begin(i); +- qdis->set_query(x + i * d); + +- hnsw.search_level_0( +- *qdis.get(), +- res, +- nprobe, +- nearest + i * nprobe, +- nearest_d + i * nprobe, +- search_type, +- search_stats, +- vt, +- params); +- res.end(); +- vt.advance(); +- } +-#pragma omp critical +- { hnsw_stats.combine(search_stats); } ++ hnsw_search_level_0( ++ this, ++ n, ++ x, ++ k, ++ nearest, ++ nearest_d, ++ distances, ++ labels, ++ nprobe, // n_probes ++ search_type, // search_type ++ params_in, ++ bres); ++ } else { ++ using RH = HeapBlockResultHandler; ++ RH bres(n, distances, labels, k); ++ ++ hnsw_search_level_0( ++ this, ++ n, ++ x, ++ k, ++ nearest, ++ nearest_d, ++ distances, ++ labels, ++ nprobe, // n_probes ++ search_type, // search_type ++ params_in, ++ bres); + } + if (is_similarity_metric(this->metric_type)) { + // we need to revert the negated distances +diff --git a/faiss/index_factory.cpp 
b/faiss/index_factory.cpp +index 8ff4bfec7..24e65b632 100644 +--- a/faiss/index_factory.cpp ++++ b/faiss/index_factory.cpp +@@ -453,6 +453,11 @@ IndexHNSW* parse_IndexHNSW( + return re_match(code_string, pattern, sm); + }; + ++ if (match("Cagra")) { ++ IndexHNSWCagra* cagra = new IndexHNSWCagra(d, hnsw_M, mt); ++ return cagra; ++ } ++ + if (match("Flat|")) { + return new IndexHNSWFlat(d, hnsw_M, mt); + } +@@ -781,7 +786,7 @@ std::unique_ptr index_factory_sub( + + // HNSW variants (it was unclear in the old version that the separator was a + // "," so we support both "_" and ",") +- if (re_match(description, "HNSW([0-9]*)([,_].*)?", sm)) { ++ if (re_match(description, "HNSW([0-9]*)([,_].*)?(Cagra)?", sm)) { + int hnsw_M = mres_to_int(sm[1], 32); + // We also accept empty code string (synonym of Flat) + std::string code_string = +diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp +index bd8ab5f9d..ebe16a364 100644 +--- a/tests/test_id_grouper.cpp ++++ b/tests/test_id_grouper.cpp +@@ -172,6 +172,65 @@ TEST(IdGrouper, bitmap_with_hnsw) { + delete[] xb; + } + ++TEST(IdGrouper, bitmap_with_hnsw_cagra) { ++ int d = 1; // dimension ++ int nb = 10; // database size ++ ++ std::mt19937 rng; ++ std::uniform_real_distribution<> distrib; ++ ++ float* xb = new float[d * nb]; ++ ++ for (int i = 0; i < nb; i++) { ++ for (int j = 0; j < d; j++) ++ xb[d * i + j] = distrib(rng); ++ xb[d * i] += i / 1000.; ++ } ++ ++ uint64_t bitmap[1] = {}; ++ faiss::IDGrouperBitmap id_grouper(1, bitmap); ++ for (int i = 0; i < nb; i++) { ++ if (i % 2 == 1) { ++ id_grouper.set_group(i); ++ } ++ } ++ ++ int k = 10; ++ int m = 8; ++ faiss::Index* index = ++ new faiss::IndexHNSWCagra(d, m, faiss::MetricType::METRIC_L2); ++ index->add(nb, xb); // add vectors to the index ++ dynamic_cast(index)->base_level_only=true; ++ ++ // search ++ idx_t* I = new idx_t[k]; ++ float* D = new float[k]; ++ ++ auto pSearchParameters = new faiss::SearchParametersHNSW(); ++ pSearchParameters->grp = 
&id_grouper; ++ ++ index->search(1, xb, k, D, I, pSearchParameters); ++ ++ std::unordered_set group_ids; ++ ASSERT_EQ(0, I[0]); ++ ASSERT_EQ(0, D[0]); ++ group_ids.insert(id_grouper.get_group(I[0])); ++ for (int j = 1; j < 5; j++) { ++ ASSERT_NE(-1, I[j]); ++ ASSERT_NE(std::numeric_limits::max(), D[j]); ++ group_ids.insert(id_grouper.get_group(I[j])); ++ } ++ for (int j = 5; j < k; j++) { ++ ASSERT_EQ(-1, I[j]); ++ ASSERT_EQ(std::numeric_limits::max(), D[j]); ++ } ++ ASSERT_EQ(5, group_ids.size()); ++ ++ delete[] I; ++ delete[] D; ++ delete[] xb; ++} ++ + TEST(IdGrouper, bitmap_with_binary_hnsw) { + int d = 16; // dimension + int nb = 10; // database size +@@ -291,6 +350,75 @@ TEST(IdGrouper, bitmap_with_hnsw_idmap) { + delete[] xb; + } + ++TEST(IdGrouper, bitmap_with_hnsw_cagra_idmap) { ++ int d = 1; // dimension ++ int nb = 10; // database size ++ ++ std::mt19937 rng; ++ std::uniform_real_distribution<> distrib; ++ ++ float* xb = new float[d * nb]; ++ idx_t* xids = new idx_t[d * nb]; ++ ++ for (int i = 0; i < nb; i++) { ++ for (int j = 0; j < d; j++) ++ xb[d * i + j] = distrib(rng); ++ xb[d * i] += i / 1000.; ++ } ++ ++ uint64_t bitmap[1] = {}; ++ faiss::IDGrouperBitmap id_grouper(1, bitmap); ++ int num_grp = 0; ++ int grp_size = 2; ++ int id_in_grp = 0; ++ for (int i = 0; i < nb; i++) { ++ xids[i] = i + num_grp; ++ id_in_grp++; ++ if (id_in_grp == grp_size) { ++ id_grouper.set_group(i + num_grp + 1); ++ num_grp++; ++ id_in_grp = 0; ++ } ++ } ++ ++ int k = 10; ++ int m = 8; ++ faiss::Index* index = ++ new faiss::IndexHNSWCagra(d, m, faiss::MetricType::METRIC_L2); ++ faiss::IndexIDMap id_map = ++ faiss::IndexIDMap(index); // add vectors to the index ++ id_map.add_with_ids(nb, xb, xids); ++ dynamic_cast(id_map.index)->base_level_only=true; ++ ++ // search ++ idx_t* I = new idx_t[k]; ++ float* D = new float[k]; ++ ++ auto pSearchParameters = new faiss::SearchParametersHNSW(); ++ pSearchParameters->grp = &id_grouper; ++ ++ id_map.search(1, xb, k, D, I, 
pSearchParameters); ++ ++ std::unordered_set group_ids; ++ ASSERT_EQ(0, I[0]); ++ ASSERT_EQ(0, D[0]); ++ group_ids.insert(id_grouper.get_group(I[0])); ++ for (int j = 1; j < 5; j++) { ++ ASSERT_NE(-1, I[j]); ++ ASSERT_NE(std::numeric_limits::max(), D[j]); ++ group_ids.insert(id_grouper.get_group(I[j])); ++ } ++ for (int j = 5; j < k; j++) { ++ ASSERT_EQ(-1, I[j]); ++ ASSERT_EQ(std::numeric_limits::max(), D[j]); ++ } ++ ASSERT_EQ(5, group_ids.size()); ++ ++ delete[] I; ++ delete[] D; ++ delete[] xb; ++} ++ + TEST(IdGrouper, bitmap_with_binary_hnsw_idmap) { + int d = 16; // dimension + int nb = 10; // database size +-- +2.47.1 + diff --git a/jni/patches/nmslib/0006-Bump-cmake-version-nmslib.patch b/jni/patches/nmslib/0006-Bump-cmake-version-nmslib.patch deleted file mode 100644 index f5a4ca58ce..0000000000 --- a/jni/patches/nmslib/0006-Bump-cmake-version-nmslib.patch +++ /dev/null @@ -1,54 +0,0 @@ -From 38d32faf4183ad400831624a3d6874af6128f9a8 Mon Sep 17 00:00:00 2001 -From: Owen Halpert -Date: Mon, 31 Mar 2025 17:00:48 -0700 -Subject: [PATCH] Bump cmake version nmslib - ---- - similarity_search/CMakeLists.txt | 16 +++++++++------- - 1 file changed, 9 insertions(+), 7 deletions(-) - -diff --git a/similarity_search/CMakeLists.txt b/similarity_search/CMakeLists.txt -index bc6ef3c..c555115 100644 ---- a/similarity_search/CMakeLists.txt -+++ b/similarity_search/CMakeLists.txt -@@ -8,7 +8,11 @@ - # - # - --cmake_minimum_required (VERSION 2.8) -+cmake_minimum_required (VERSION 3.5...4.0) -+ -+if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0") -+ cmake_policy(SET CMP0153 OLD) -+endif() - - project (NonMetricSpaceLib) - -@@ -20,12 +24,10 @@ project (NonMetricSpaceLib) - # - function(CXX_COMPILER_DUMPVERSION _OUTPUT_VERSION) - -- exec_program(${CMAKE_CXX_COMPILER} -- ARGS ${CMAKE_CXX_COMPILER_ARG1} -dumpversion -- OUTPUT_VARIABLE COMPILER_VERSION -+ execute_process( -+ COMMAND ${CMAKE_CXX_COMPILER} ${CMAKE_CXX_COMPILER_ARG1} -dumpversion -+ OUTPUT_VARIABLE COMPILER_VERSION - ) -- 
#string(REGEX REPLACE "([0-9])\\.([0-9])(\\.[0-9])?" "\\1\\2" -- # COMPILER_VERSION ${COMPILER_VERSION}) - - set(${_OUTPUT_VERSION} ${COMPILER_VERSION} PARENT_SCOPE) - endfunction() -@@ -55,7 +57,7 @@ elseif(${CMAKE_CXX_COMPILER_ID} STREQUAL "Intel") - endif() - set (CMAKE_CXX_FLAGS_RELEASE "-Wall -Wunreachable-code -Ofast -DNDEBUG -std=c++11 -DHAVE_CXX0X -pthread ${SIMD_FLAGS} -fpic") - set (CMAKE_CXX_FLAGS_DEBUG "-Wall -Wunreachable-code -ggdb -DNDEBUG -std=c++11 -DHAVE_CXX0X -pthread ${SIMD_FLAGS} -fpic") --elseif(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang") -+elseif(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang" OR ${CMAKE_CXX_COMPILER_ID} STREQUAL "AppleClang") - if (CMAKE_SYSTEM_NAME MATCHES Darwin) - # MACOSX - set (CMAKE_CXX_FLAGS_RELEASE "${WARN_FLAGS} -O3 -DNDEBUG -std=c++11 -DHAVE_CXX0X -pthread -fpic ${SIMD_FLAGS}") --- -2.47.1 - diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 6512019642..0dd9ac8366 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -748,6 +748,84 @@ TEST(FaissQueryIndexWithParentFilterTest, BasicAssertions) { } } +TEST(FaissQueryIndexHNSWCagraWithParentFilterTest, BasicAssertions) { + // Define the index data + faiss::idx_t numIds = 100; + std::vector ids; + std::vector vectors; + std::vector parentIds; + int dim = 16; + for (int64_t i = 1; i < numIds + 1; i++) { + if (i % 10 == 0) { + parentIds.push_back(i); + continue; + } + ids.push_back(i); + for (int j = 0; j < dim; j++) { + vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); + } + } + + faiss::MetricType metricType = faiss::METRIC_L2; + std::string method = "HNSW32,Cagra"; + + // Define query data + int k = 20; + int numQueries = 100; + std::vector> queries; + + for (int i = 0; i < numQueries; i++) { + std::vector query; + query.reserve(dim); + for (int j = 0; j < dim; j++) { + query.push_back(test_util::RandomFloat(-500.0, 500.0)); + } + queries.push_back(query); + } + + // Create the index + 
std::unique_ptr createdIndex( + test_util::FaissCreateIndex(dim, method, metricType)); + auto createdIndexWithData = + test_util::FaissAddData(createdIndex.get(), ids, vectors); + dynamic_cast(createdIndexWithData.index)->base_level_only=true; + + int efSearch = 100; + std::unordered_map methodParams; + methodParams[knn_jni::EF_SEARCH] = reinterpret_cast(&efSearch); + + // Setup jni + NiceMock jniEnv; + NiceMock mockJNIUtil; + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength( + &jniEnv, reinterpret_cast(&parentIds))) + .WillRepeatedly(Return(parentIds.size())); + for (auto query : queries) { + std::unique_ptr *>> results( + reinterpret_cast *> *>( + knn_jni::faiss_wrapper::QueryIndex( + &mockJNIUtil, &jniEnv, + reinterpret_cast(&createdIndexWithData), + reinterpret_cast(&query), k, reinterpret_cast(&methodParams), + reinterpret_cast(&parentIds)))); + + // Even with k 20, result should have only 10 which is total number of groups + ASSERT_EQ(10, results->size()); + // Result should be one for each group + std::set idSet; + for (const auto& pairPtr : *results) { + idSet.insert(pairPtr->first / 10); + } + ASSERT_EQ(10, idSet.size()); + + // Need to free up each result + for (auto it : *results.get()) { + delete it; + } + } +} + TEST(FaissFreeTest, BasicAssertions) { // Define the data int dim = 2; diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/DerivedSourceBWCRestartIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/DerivedSourceBWCRestartIT.java index a1f96b0ece..bf7bc309c9 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/DerivedSourceBWCRestartIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/DerivedSourceBWCRestartIT.java @@ -6,6 +6,8 @@ package org.opensearch.knn.bwc; import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.DerivedSourceTestCase; import 
org.opensearch.knn.DerivedSourceUtils; import org.opensearch.test.rest.OpenSearchRestTestCase; @@ -22,12 +24,12 @@ public class DerivedSourceBWCRestartIT extends DerivedSourceTestCase { public void testFlat_indexAndForceMergeOnOld_injectOnNew() throws IOException { - List indexConfigContexts = getFlatIndexContexts("knn-bwc", false); + List indexConfigContexts = getFlatIndexContexts("knn-bwc", false, false); testIndexAndForceMergeOnOld_injectOnNew(indexConfigContexts); } public void testFlat_indexOnOld_forceMergeAndInjectOnNew() throws IOException { - List indexConfigContexts = getFlatIndexContexts("knn-bwc", false); + List indexConfigContexts = getFlatIndexContexts("knn-bwc", false, false); testIndexOnOld_forceMergeAndInjectOnNew(indexConfigContexts); } @@ -67,6 +69,29 @@ private void testIndexOnOld_forceMergeAndInjectOnNew(List profile(String fieldName) { + List shardVectorProfile = new ArrayList<>(); + + try (Engine.Searcher searcher = indexShard.acquireSearcher("knn-profile")) { + List segmentLevelProfilerStates = new ArrayList<>(); + + log.info("[KNN] Beginning profiling for field: {} in shard: {}", fieldName, indexShard.shardId()); + + // For each leaf, collect the profile + searcher.getIndexReader().leaves().forEach(leaf -> { + try { + log.info("[KNN] Processing leaf reader for segment: {}", leaf.reader()); + segmentLevelProfilerStates.add(SegmentProfilerUtil.getSegmentProfileState(leaf.reader(), fieldName)); + log.info("[KNN] Successfully obtained segment profile state"); + } catch (Exception e) { + log.error("[KNN] Error profiling segment: {}", e.getMessage(), e); + } + }); + + if (segmentLevelProfilerStates.isEmpty()) { + log.info("[KNN] No segment profiles were collected for field: {} in shard: {}", fieldName, indexShard.shardId()); + return shardVectorProfile; // Return empty list + } + + log.info("[KNN] Collected {} segment profiles", segmentLevelProfilerStates.size()); + + // Get dimension + int dimension = 
segmentLevelProfilerStates.get(0).getDimension(); + log.info("[KNN] Vector dimension: {}", dimension); + + // Transpose our list to aggregate per dimension + for (int i = 0; i < dimension; i++) { + final int dimensionId = i; + List transposed = segmentLevelProfilerStates.stream() + .map(state -> state.getStatistics().get(dimensionId)) + .collect(Collectors.toList()); + + shardVectorProfile.add(AggregateSummaryStatistics.aggregate(transposed)); + } + + // Log the results for each dimension + for (int i = 0; i < shardVectorProfile.size(); i++) { + StatisticalSummaryValues stats = shardVectorProfile.get(i); + log.info( + "[KNN] Dimension {}: count={}, min={}, max={}, mean={}, sum={}, variance={}, std_deviation={}", + i, + stats.getN(), + stats.getMin(), + stats.getMax(), + stats.getMean(), + stats.getSum(), + stats.getVariance(), + Math.sqrt(stats.getVariance()) + ); + } + + log.info("[KNN] Profiling completed for field: {} in shard: {}", fieldName, indexShard.shardId()); + } catch (Exception e) { + log.error( + "[KNN] Critical error during profiling for field: {} in shard: {}: {}", + fieldName, + indexShard.shardId(), + e.getMessage(), + e + ); + } + + return shardVectorProfile; + } + /** * Load all of the k-NN segments for this shard into the cache. * @@ -235,4 +317,13 @@ static class EngineFileContext { private final VectorDataType vectorDataType; private final SegmentInfo segmentInfo; } + + /** + * Profile the vector fields in this shard with default field name. 
+ * + * @return List of statistical summaries for each dimension + */ + public List profile() { + return profile("my_vector_field"); + } } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 092c609dc4..9cb9ae6294 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -259,19 +259,19 @@ public class KNNSettings { Setting.Property.Dynamic ); - public static final Setting KNN_DERIVED_SOURCE_ENABLED_SETTING = Setting.boolSetting( - KNN_DERIVED_SOURCE_ENABLED, + /** + * This setting identifies KNN index. + */ + public static final Setting IS_KNN_INDEX_SETTING = Setting.boolSetting( + KNN_INDEX, false, IndexScope, Final, UnmodifiableOnRestore ); - /** - * This setting identifies KNN index. - */ - public static final Setting IS_KNN_INDEX_SETTING = Setting.boolSetting( - KNN_INDEX, + public static final Setting KNN_DERIVED_SOURCE_ENABLED_SETTING = Setting.boolSetting( + KNN_DERIVED_SOURCE_ENABLED, false, IndexScope, Final, diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index eee2808a6b..4df768b7fc 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -40,7 +40,7 @@ public long ramBytesUsed() { } @Override - public ScriptDocValues getScriptValues() { + public ScriptDocValues getScriptValues() { try { FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, fieldName); if (fieldInfo == null) { diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 08ba7d498a..112cdc5a5e 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ 
b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -19,7 +19,7 @@ import org.opensearch.index.fielddata.ScriptDocValues; @RequiredArgsConstructor(access = AccessLevel.PRIVATE) -public abstract class KNNVectorScriptDocValues extends ScriptDocValues { +public abstract class KNNVectorScriptDocValues extends ScriptDocValues { private final DocIdSetIterator vectorValues; private final String fieldName; @@ -41,7 +41,7 @@ public void setNextDocId(int docId) throws IOException { docExists = lastDocID == curDocID; } - public float[] getValue() { + public T getValue() { if (!docExists) { String errorMessage = String.format( "One of the document doesn't have a value for field '%s'. " @@ -59,7 +59,7 @@ public float[] getValue() { } } - protected abstract float[] doGetValue() throws IOException; + protected abstract T doGetValue() throws IOException; @Override public int size() { @@ -67,7 +67,7 @@ public int size() { } @Override - public float[] get(int i) { + public T get(int i) { throw new UnsupportedOperationException("knn vector does not support this operation"); } @@ -80,7 +80,7 @@ public float[] get(int i) { * @return A KNNVectorScriptDocValues object based on the type of the values. * @throws IllegalArgumentException If the type of values is unsupported. 
*/ - public static KNNVectorScriptDocValues create(KnnVectorValues knnVectorValues, String fieldName, VectorDataType vectorDataType) { + public static KNNVectorScriptDocValues create(KnnVectorValues knnVectorValues, String fieldName, VectorDataType vectorDataType) { Objects.requireNonNull(knnVectorValues, "values must not be null"); if (knnVectorValues instanceof FloatVectorValues) { return new KNNFloatVectorScriptDocValues((FloatVectorValues) knnVectorValues, fieldName, vectorDataType); @@ -91,17 +91,17 @@ public static KNNVectorScriptDocValues create(KnnVectorValues knnVectorValues, S } } - public static KNNVectorScriptDocValues create(DocIdSetIterator docIdSetIterator, String fieldName, VectorDataType vectorDataType) { + public static KNNVectorScriptDocValues create(DocIdSetIterator docIdSetIterator, String fieldName, VectorDataType vectorDataType) { Objects.requireNonNull(docIdSetIterator, "values must not be null"); if (docIdSetIterator instanceof BinaryDocValues) { - return new KNNNativeVectorScriptDocValues((BinaryDocValues) docIdSetIterator, fieldName, vectorDataType); + return new KNNNativeVectorScriptDocValues<>((BinaryDocValues) docIdSetIterator, fieldName, vectorDataType); } else { throw new IllegalArgumentException("Unsupported values type: " + docIdSetIterator.getClass()); } } - private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues { + private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues { private final ByteVectorValues values; private final KnnVectorValues.DocIndexIterator iterator; @@ -114,24 +114,22 @@ private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptD } @Override - protected float[] doGetValue() throws IOException { + protected byte[] doGetValue() throws IOException { int docId = this.iterator.index(); if (docId == KnnVectorValues.DocIndexIterator.NO_MORE_DOCS) { throw new IllegalStateException("No more ordinals to retrieve vector values."); } - // 
Use the correct method to retrieve the byte vector for the current ordinal - byte[] bytes = values.vectorValue(docId); - float[] value = new float[bytes.length]; - for (int i = 0; i < bytes.length; i++) { - value[i] = (float) bytes[i]; + try { + return values.vectorValue(docId); + } catch (IOException e) { + throw ExceptionsHelper.convertToOpenSearchException(e); } - return value; } } - private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues { + private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues { private final FloatVectorValues values; private final KnnVectorValues.DocIndexIterator iterator; @@ -153,7 +151,7 @@ protected float[] doGetValue() throws IOException { } } - private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues { + private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues { private final BinaryDocValues values; KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) { @@ -162,7 +160,7 @@ private static final class KNNNativeVectorScriptDocValues extends KNNVectorScrip } @Override - protected float[] doGetValue() throws IOException { + protected T doGetValue() throws IOException { return getVectorDataType().getVectorFromBytesRef(values.binaryValue()); } } @@ -174,10 +172,18 @@ protected float[] doGetValue() throws IOException { * @param type The data type of the vector. * @return An empty KNNVectorScriptDocValues object. 
*/ - public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) { - return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) { + public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) { + if (type == VectorDataType.FLOAT) { + return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) { + @Override + protected float[] doGetValue() throws IOException { + throw new UnsupportedOperationException("empty values"); + } + }; + } + return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) { @Override - protected float[] doGetValue() throws IOException { + protected byte[] doGetValue() throws IOException { throw new UnsupportedOperationException("empty values"); } }; diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index 5d90071e84..4c7609cfe3 100644 --- a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -40,7 +40,7 @@ public void validateVectorDataType(VectorDataType vectorDataType) { throw new IllegalStateException("Unsupported method"); } }, - L2("l2") { + L2("l2", SpaceType.GENERIC_SCORE_TRANSLATION) { @Override public float scoreTranslation(float rawScore) { return 1 / (1 + rawScore); @@ -59,7 +59,7 @@ public float scoreToDistanceTranslation(float score) { return 1 / score - 1; } }, - COSINESIMIL("cosinesimil") { + COSINESIMIL("cosinesimil", "`Math.max((2.0F - rawScore) / 2.0F, 0.0F)`") { /** * Cosine similarity has range of [-1, 1] where -1 represents vectors are at diametrically opposite, and 1 is where * they are identical in direction and perfectly similar. In Lucene, scores have to be in the range of [0, Float.MAX_VALUE]. 
@@ -100,13 +100,13 @@ public void validateVector(float[] vector) { } } }, - L1("l1") { + L1("l1", SpaceType.GENERIC_SCORE_TRANSLATION) { @Override public float scoreTranslation(float rawScore) { return 1 / (1 + rawScore); } }, - LINF("linf") { + LINF("linf", SpaceType.GENERIC_SCORE_TRANSLATION) { @Override public float scoreTranslation(float rawScore) { return 1 / (1 + rawScore); @@ -129,12 +129,17 @@ public float scoreTranslation(float rawScore) { return -rawScore + 1; } + @Override + public String explainScoreTranslation(float rawScore) { + return rawScore >= 0 ? GENERIC_SCORE_TRANSLATION : "`-rawScore + 1`"; + } + @Override public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { return KNNVectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; } }, - HAMMING("hamming") { + HAMMING("hamming", SpaceType.GENERIC_SCORE_TRANSLATION) { @Override public float scoreTranslation(float rawScore) { return 1 / (1 + rawScore); @@ -169,14 +174,29 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() { .collect(Collectors.toList()) .toArray(new String[0]); + private static final String GENERIC_SCORE_TRANSLATION = "`1 / (1 + rawScore)`"; private final String value; + private final String explanationFormula; SpaceType(String value) { this.value = value; + this.explanationFormula = null; + } + + SpaceType(String value, String explanationFormula) { + this.value = value; + this.explanationFormula = explanationFormula; } public abstract float scoreTranslation(float rawScore); + public String explainScoreTranslation(float rawScore) { + if (explanationFormula != null) { + return explanationFormula; + } + throw new UnsupportedOperationException("explainScoreTranslation is not defined for this space type."); + } + /** * Get KNNVectorSimilarityFunction that maps to this SpaceType * diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index d0b743887d..aa8c558b86 100644 --- 
a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -46,15 +46,8 @@ public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunc } @Override - public float[] getVectorFromBytesRef(BytesRef binaryValue) { - float[] vector = new float[binaryValue.length]; - int i = 0; - int j = binaryValue.offset; - - while (i < binaryValue.length) { - vector[i++] = binaryValue.bytes[j++]; - } - return vector; + public byte[] getVectorFromBytesRef(BytesRef binaryValue) { + return binaryValue.bytes; } @Override @@ -75,15 +68,8 @@ public FieldType createKnnVectorFieldType(int dimension, KNNVectorSimilarityFunc } @Override - public float[] getVectorFromBytesRef(BytesRef binaryValue) { - float[] vector = new float[binaryValue.length]; - int i = 0; - int j = binaryValue.offset; - - while (i < binaryValue.length) { - vector[i++] = binaryValue.bytes[j++]; - } - return vector; + public byte[] getVectorFromBytesRef(BytesRef binaryValue) { + return binaryValue.bytes; } @Override @@ -143,7 +129,7 @@ public void freeNativeMemory(long memoryAddress) { * @param binaryValue Binary Value * @return float vector deserialized from binary value */ - public abstract float[] getVectorFromBytesRef(BytesRef binaryValue); + public abstract T getVectorFromBytesRef(BytesRef binaryValue); /** * @param trainingDataAllocation training data that has been allocated in native memory diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010DerivedSourceStoredFieldsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010DerivedSourceStoredFieldsFormat.java index d49e1c9ec6..47406a1779 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010DerivedSourceStoredFieldsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010DerivedSourceStoredFieldsFormat.java @@ -18,6 +18,7 @@ import org.opensearch.index.mapper.MappedFieldType; 
import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.codec.derivedsource.DerivedFieldInfo; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReaders; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceSegmentAttributeParser; import org.opensearch.knn.index.mapper.KNNVectorFieldType; @@ -55,11 +56,13 @@ public StoredFieldsReader fieldsReader(Directory directory, SegmentInfo segmentI if (derivedVectorFields.isEmpty()) { return delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext); } + SegmentReadState segmentReadState = new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext); + DerivedSourceReaders derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); return new KNN10010DerivedSourceStoredFieldsReader( delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext), derivedVectorFields, - derivedSourceReadersSupplier, - new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext) + derivedSourceReaders, + segmentReadState ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010DerivedSourceStoredFieldsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010DerivedSourceStoredFieldsReader.java index 86aef37aed..42c135cdaf 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010DerivedSourceStoredFieldsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN10010Codec/KNN10010DerivedSourceStoredFieldsReader.java @@ -11,7 +11,7 @@ import org.apache.lucene.util.IOUtils; import org.opensearch.index.fieldvisitor.FieldsVisitor; import org.opensearch.knn.index.codec.derivedsource.DerivedFieldInfo; -import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReaders; import 
org.opensearch.knn.index.codec.derivedsource.DerivedSourceStoredFieldVisitor; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorTransformer; @@ -21,7 +21,7 @@ public class KNN10010DerivedSourceStoredFieldsReader extends StoredFieldsReader { private final StoredFieldsReader delegate; private final List derivedVectorFields; - private final DerivedSourceReadersSupplier derivedSourceReadersSupplier; + private final DerivedSourceReaders derivedSourceReaders; private final SegmentReadState segmentReadState; private final boolean shouldInject; @@ -31,36 +31,36 @@ public class KNN10010DerivedSourceStoredFieldsReader extends StoredFieldsReader * * @param delegate delegate StoredFieldsReader * @param derivedVectorFields List of fields that are derived source fields - * @param derivedSourceReadersSupplier Supplier for the derived source readers + * @param derivedSourceReaders derived source readers * @param segmentReadState SegmentReadState for the segment * @throws IOException in case of I/O error */ public KNN10010DerivedSourceStoredFieldsReader( StoredFieldsReader delegate, List derivedVectorFields, - DerivedSourceReadersSupplier derivedSourceReadersSupplier, + DerivedSourceReaders derivedSourceReaders, SegmentReadState segmentReadState ) throws IOException { - this(delegate, derivedVectorFields, derivedSourceReadersSupplier, segmentReadState, true); + this(delegate, derivedVectorFields, derivedSourceReaders, segmentReadState, true); } private KNN10010DerivedSourceStoredFieldsReader( StoredFieldsReader delegate, List derivedVectorFields, - DerivedSourceReadersSupplier derivedSourceReadersSupplier, + DerivedSourceReaders derivedSourceReaders, SegmentReadState segmentReadState, boolean shouldInject ) throws IOException { this.delegate = delegate; this.derivedVectorFields = derivedVectorFields; - this.derivedSourceReadersSupplier = derivedSourceReadersSupplier; + this.derivedSourceReaders = derivedSourceReaders; this.segmentReadState = segmentReadState; 
this.shouldInject = shouldInject; this.derivedSourceVectorTransformer = createDerivedSourceVectorTransformer(); } - private DerivedSourceVectorTransformer createDerivedSourceVectorTransformer() throws IOException { - return new DerivedSourceVectorTransformer(derivedSourceReadersSupplier, segmentReadState, derivedVectorFields); + private DerivedSourceVectorTransformer createDerivedSourceVectorTransformer() { + return new DerivedSourceVectorTransformer(derivedSourceReaders, segmentReadState, derivedVectorFields); } @Override @@ -89,7 +89,7 @@ public StoredFieldsReader clone() { return new KNN10010DerivedSourceStoredFieldsReader( delegate.clone(), derivedVectorFields, - derivedSourceReadersSupplier, + derivedSourceReaders.cloneWithMerge(), segmentReadState, shouldInject ); @@ -105,7 +105,7 @@ public void checkIntegrity() throws IOException { @Override public void close() throws IOException { - IOUtils.close(delegate, derivedSourceVectorTransformer); + IOUtils.close(delegate, derivedSourceReaders); } /** @@ -120,7 +120,7 @@ private StoredFieldsReader cloneForMerge() { return new KNN10010DerivedSourceStoredFieldsReader( delegate.getMergeInstance(), derivedVectorFields, - derivedSourceReadersSupplier, + derivedSourceReaders.cloneWithMerge(), segmentReadState, false ); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java index 49b1819c10..54d125627f 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriter.java @@ -12,6 +12,7 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.store.IndexOutput; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.profiler.SegmentProfilerState; import 
org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -85,6 +86,20 @@ public void writeState(int fieldNumber, QuantizationState quantizationState) thr fieldQuantizationStates.add(new FieldQuantizationState(fieldNumber, stateBytes, position)); } + /** + * Writes a segment profile state as bytes + * + * @param fieldNumber field number + * @param segmentProfilerState segment profiler state + * @throws IOException could be thrown while writing + */ + public void writeState(int fieldNumber, SegmentProfilerState segmentProfilerState) throws IOException { + byte[] stateBytes = segmentProfilerState.toByteArray(); + long position = output.getFilePointer(); + output.writeBytes(stateBytes, stateBytes.length); + fieldQuantizationStates.add(new FieldQuantizationState(fieldNumber, stateBytes, position)); + } + /** * Writes index footer and other index information for parsing later * @throws IOException could be thrown while writing diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index 1b0e2a8397..57bb8f8e28 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -35,6 +35,10 @@ import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.memoryoptsearch.VectorSearcher; import org.opensearch.knn.memoryoptsearch.VectorSearcherFactory; +import org.opensearch.knn.profiler.KNN990ProfileStateReader; +import org.opensearch.knn.profiler.SegmentProfileKNNCollector; +import org.opensearch.knn.profiler.SegmentProfileStateReadConfig; +import org.opensearch.knn.profiler.SegmentProfilerState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import 
org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; @@ -163,6 +167,14 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits return; } + if (knnCollector instanceof SegmentProfileKNNCollector) { + SegmentProfilerState segmentProfileState = KNN990ProfileStateReader.read( + new SegmentProfileStateReadConfig(segmentReadState, field) + ); + ((SegmentProfileKNNCollector) knnCollector).setSegmentProfilerState(segmentProfileState); + return; + } + if (trySearchWithMemoryOptimizedSearch(field, target, knnCollector, acceptDocs, true)) { return; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 0379696897..d6b976d60c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -29,6 +29,7 @@ import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.profiler.SegmentProfilerState; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; @@ -107,6 +108,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { field.getVectors() ); final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); + profile(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); // should skip graph building only for non quantization use case and if threshold is met if (quantizationState == 
null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { log.info( @@ -150,6 +152,10 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState } final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs); + + // Write the segment profile state to the directory + profile(fieldInfo, knnVectorValuesSupplier, totalLiveDocs); + // should skip graph building only for non quantization use case and if threshold is met if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { log.info( @@ -188,6 +194,7 @@ public void finish() throws IOException { if (quantizationStateWriter != null) { quantizationStateWriter.writeFooter(); } + flatVectorsWriter.finish(); } @@ -241,6 +248,23 @@ private QuantizationState train( return quantizationState; } + private SegmentProfilerState profile( + final FieldInfo fieldInfo, + final Supplier> knnVectorValuesSupplier, + final int totalLiveDocs + ) throws IOException { + + SegmentProfilerState segmentProfilerState = null; + if (totalLiveDocs > 0) { + // TODO:Refactor to another init + initQuantizationStateWriterIfNecessary(); + SegmentProfilerState profileResultForSegment = SegmentProfilerState.profileVectors(knnVectorValuesSupplier); + quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), profileResultForSegment); + } + + return segmentProfilerState; + } + /** * The {@link KNNVectorValues} will be exhausted after this function run. 
So make sure that you are not sending the * vectorsValues object which you plan to use later diff --git a/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/DerivedSourceVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/DerivedSourceVectorInjector.java index 32561b0c4f..beee41ff19 100644 --- a/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/DerivedSourceVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/DerivedSourceVectorInjector.java @@ -8,7 +8,6 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.SegmentReadState; -import org.apache.lucene.util.IOUtils; import org.opensearch.common.collect.Tuple; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentHelper; @@ -18,7 +17,6 @@ import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.Closeable; import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -33,7 +31,7 @@ * format readers and information about the fields to inject vectors into the source. */ @Log4j2 -public class DerivedSourceVectorInjector implements Closeable { +public class DerivedSourceVectorInjector { private final KNN9120DerivedSourceReaders derivedSourceReaders; private final List perFieldDerivedVectorInjectors; @@ -42,16 +40,16 @@ public class DerivedSourceVectorInjector implements Closeable { /** * Constructor for DerivedSourceVectorInjector. * - * @param derivedSourceReadersSupplier Supplier for the derived source readers. + * @param derivedSourceReaders Derived source readers. 
* @param segmentReadState Segment read state * @param fieldsToInjectVector List of fields to inject vectors into */ public DerivedSourceVectorInjector( - KNN9120DerivedSourceReadersSupplier derivedSourceReadersSupplier, + KNN9120DerivedSourceReaders derivedSourceReaders, SegmentReadState segmentReadState, List fieldsToInjectVector - ) throws IOException { - this.derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); + ) { + this.derivedSourceReaders = derivedSourceReaders; this.perFieldDerivedVectorInjectors = new ArrayList<>(); this.fieldNames = new HashSet<>(); for (FieldInfo fieldInfo : fieldsToInjectVector) { @@ -128,9 +126,4 @@ public boolean shouldInject(String[] includes, String[] excludes) { } return true; } - - @Override - public void close() throws IOException { - IOUtils.close(derivedSourceReaders); - } } diff --git a/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/KNN9120DerivedSourceReaders.java b/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/KNN9120DerivedSourceReaders.java index 0277dd428e..1c0e3e3bbc 100644 --- a/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/KNN9120DerivedSourceReaders.java +++ b/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/KNN9120DerivedSourceReaders.java @@ -11,11 +11,13 @@ import org.apache.lucene.codecs.FieldsProducer; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.NormsProducer; +import org.apache.lucene.store.AlreadyClosedException; import org.apache.lucene.util.IOUtils; import org.opensearch.common.Nullable; import java.io.Closeable; import java.io.IOException; +import java.util.concurrent.atomic.AtomicInteger; /** * Class holds the readers necessary to implement derived source. 
Important to note that if a segment does not have @@ -33,8 +35,48 @@ public class KNN9120DerivedSourceReaders implements Closeable { @Nullable private final NormsProducer normsProducer; + // Copied from lucene (https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/index/SegmentCoreReaders.java): + // We need to reference count these readers because they may be shared amongst different instances. + // "Counts how many other readers share the core objects + // (freqStream, proxStream, tis, etc.) of this reader; + // when coreRef drops to 0, these core objects may be + // closed. A given instance of SegmentReader may be + // closed, even though it shares core objects with other + // SegmentReaders": + private final AtomicInteger ref = new AtomicInteger(1); + + /** + * This method is used to clone the KNN9120DerivedSourceReaders object. + * This is used when the object is passed to multiple threads. + * + * @return KNN9120DerivedSourceReaders object + */ + public KNN9120DerivedSourceReaders cloneWithMerge() { + // For cloning, we dont need to reference count. In Lucene, the merging will actually not close any of the + // readers, so it should only be handled by the original code. 
See + // https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/index/IndexWriter.java#L3372 + // for more details + return this; + } + @Override public void close() throws IOException { - IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer); + decRef(); + } + + private void incRef() { + int count; + while ((count = ref.get()) > 0) { + if (ref.compareAndSet(count, count + 1)) { + return; + } + } + throw new AlreadyClosedException("DerivedSourceReaders is already closed"); + } + + private void decRef() throws IOException { + if (ref.decrementAndGet() == 0) { + IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer, normsProducer); + } } } diff --git a/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/KNN9120DerivedSourceStoredFieldsFormat.java b/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/KNN9120DerivedSourceStoredFieldsFormat.java index e86d10bbc0..007618e9dc 100644 --- a/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/KNN9120DerivedSourceStoredFieldsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/KNN9120DerivedSourceStoredFieldsFormat.java @@ -51,10 +51,12 @@ public StoredFieldsReader fieldsReader(Directory directory, SegmentInfo segmentI if (derivedVectorFields == null || derivedVectorFields.isEmpty()) { return delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext); } + SegmentReadState segmentReadState = new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext); + KNN9120DerivedSourceReaders derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); return new KNN9120DerivedSourceStoredFieldsReader( delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext), derivedVectorFields, - derivedSourceReadersSupplier, + derivedSourceReaders, new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext) ); } diff --git 
a/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/KNN9120DerivedSourceStoredFieldsReader.java b/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/KNN9120DerivedSourceStoredFieldsReader.java index 2e76ea263f..80fa35b407 100644 --- a/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/KNN9120DerivedSourceStoredFieldsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/KNN9120DerivedSourceStoredFieldsReader.java @@ -23,7 +23,7 @@ public class KNN9120DerivedSourceStoredFieldsReader extends StoredFieldsReader { private final StoredFieldsReader delegate; private final List derivedVectorFields; - private final KNN9120DerivedSourceReadersSupplier derivedSourceReadersSupplier; + private final KNN9120DerivedSourceReaders derivedSourceReaders; private final SegmentReadState segmentReadState; private final boolean shouldInject; @@ -33,36 +33,36 @@ public class KNN9120DerivedSourceStoredFieldsReader extends StoredFieldsReader { * * @param delegate delegate StoredFieldsReader * @param derivedVectorFields List of fields that are derived source fields - * @param derivedSourceReadersSupplier Supplier for the derived source readers + * @param derivedSourceReaders Derived source readers * @param segmentReadState SegmentReadState for the segment * @throws IOException in case of I/O error */ public KNN9120DerivedSourceStoredFieldsReader( StoredFieldsReader delegate, List derivedVectorFields, - KNN9120DerivedSourceReadersSupplier derivedSourceReadersSupplier, + KNN9120DerivedSourceReaders derivedSourceReaders, SegmentReadState segmentReadState ) throws IOException { - this(delegate, derivedVectorFields, derivedSourceReadersSupplier, segmentReadState, true); + this(delegate, derivedVectorFields, derivedSourceReaders, segmentReadState, true); } private KNN9120DerivedSourceStoredFieldsReader( StoredFieldsReader delegate, List derivedVectorFields, - 
KNN9120DerivedSourceReadersSupplier derivedSourceReadersSupplier, + KNN9120DerivedSourceReaders derivedSourceReaders, SegmentReadState segmentReadState, boolean shouldInject ) throws IOException { this.delegate = delegate; this.derivedVectorFields = derivedVectorFields; - this.derivedSourceReadersSupplier = derivedSourceReadersSupplier; + this.derivedSourceReaders = derivedSourceReaders; this.segmentReadState = segmentReadState; this.shouldInject = shouldInject; this.derivedSourceVectorInjector = createDerivedSourceVectorInjector(); } private DerivedSourceVectorInjector createDerivedSourceVectorInjector() throws IOException { - return new DerivedSourceVectorInjector(derivedSourceReadersSupplier, segmentReadState, derivedVectorFields); + return new DerivedSourceVectorInjector(derivedSourceReaders, segmentReadState, derivedVectorFields); } @Override @@ -88,7 +88,7 @@ public StoredFieldsReader clone() { return new KNN9120DerivedSourceStoredFieldsReader( delegate.clone(), derivedVectorFields, - derivedSourceReadersSupplier, + derivedSourceReaders.cloneWithMerge(), segmentReadState, shouldInject ); @@ -104,7 +104,7 @@ public void checkIntegrity() throws IOException { @Override public void close() throws IOException { - IOUtils.close(delegate, derivedSourceVectorInjector); + IOUtils.close(delegate, derivedSourceReaders); } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceLuceneHelper.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceLuceneHelper.java index 06890b1266..de7f339d45 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceLuceneHelper.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceLuceneHelper.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.codec.derivedsource; +import com.google.common.annotations.VisibleForTesting; import lombok.RequiredArgsConstructor; import org.apache.lucene.index.FieldInfo; import 
org.apache.lucene.index.NumericDocValues; @@ -21,13 +22,26 @@ @RequiredArgsConstructor public class DerivedSourceLuceneHelper { + // Offsets from the parent docId to start from in order to determine where to start the search from for first child. + // In other words, we're guessing the upperbound on the number of nested documents a single parent will have. + // This is an optimization to avoid starting from the first doc to find the previous parent to the parent docid. + // The values are just back of the napkin calculations, but heres how I got these numbers: Assuming there are + // ~12 bytes an entry for NumericDocValues (8 for long value, 4 for int id). On a single + // 4kb page, it should be possible to fit 340 values. If first offset is 150, we can be confident it will be only 1 + // page fetched typically. Then 10 pages and then 40 pages. + private static final int[] NESTED_OFFSET_STARTING_POINTS = new int[] { 150, 1500, 6000 }; + private final DerivedSourceReaders derivedSourceReaders; private final SegmentReadState segmentReadState; + @VisibleForTesting + static final int NO_CHILDREN_INDICATOR = -1; + /** * Get the first child of the given parentDoc. This can be used to determine if the document contains any nested * fields. * + * @param parentDocId Parent doc id to find children for * @return doc id of last matching doc. {@link DocIdSetIterator#NO_MORE_DOCS} if no children exist. 
* @throws IOException */ @@ -36,16 +50,38 @@ public int getFirstChild(int parentDocId) throws IOException { if (parentDocId == 0) { return NO_MORE_DOCS; } + int lastStartingPoint = -1; + for (int offset : NESTED_OFFSET_STARTING_POINTS) { + int currentStartingPoint = Math.max(0, parentDocId - offset); + // If we've already checked this starting point, no need to continue + if (currentStartingPoint <= lastStartingPoint) { + break; + } + int firstChild = getFirstChild(parentDocId, currentStartingPoint); + // If the returned value is NO_CHILDREN_INDICATOR, we know for sure that there are no children. No need to + // keep checking + if (firstChild == NO_CHILDREN_INDICATOR) { + return NO_MORE_DOCS; + } + // If the first child is in between currentStartingPoint and parentDocId, we can return + if (firstChild != NO_MORE_DOCS) { + return firstChild; + } + lastStartingPoint = currentStartingPoint; + } + // If none of the shortcuts worked, we'll try from the start + return getFirstChild(parentDocId, 0); + } + @VisibleForTesting + int getFirstChild(int parentDocId, int startingPoint) throws IOException { // Only root level documents have the "_primary_term" field. So, we iterate through all of the documents in // order to find out if any have this term. - // TODO: This is expensive and should be optimized. We should start at doc parentDocId - 10000 and work back - // (can we fetch the setting? 
Maybe) FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo("_primary_term"); assert derivedSourceReaders.getDocValuesProducer() != null; NumericDocValues numericDocValues = derivedSourceReaders.getDocValuesProducer().getNumeric(fieldInfo); int previousParentDocId = NO_MORE_DOCS; - numericDocValues.advance(0); + numericDocValues.advance(startingPoint); while (numericDocValues.docID() != NO_MORE_DOCS) { if (numericDocValues.docID() >= parentDocId) { break; @@ -59,9 +95,9 @@ public int getFirstChild(int parentDocId) throws IOException { if (previousParentDocId == NO_MORE_DOCS) { return 0; } - // If the document right before is the previous parent, then there are no children. + // If the document right before is the previous parent, then there are no children. Return if (parentDocId - previousParentDocId <= 1) { - return NO_MORE_DOCS; + return NO_CHILDREN_INDICATOR; } return previousParentDocId + 1; } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java index 2d8357072d..e8a18f9462 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java @@ -9,11 +9,13 @@ import lombok.RequiredArgsConstructor; import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.store.AlreadyClosedException; import org.apache.lucene.util.IOUtils; import org.opensearch.common.Nullable; import java.io.Closeable; import java.io.IOException; +import java.util.concurrent.atomic.AtomicInteger; /** * Class holds the readers necessary to implement derived source. 
Important to note that if a segment does not have @@ -21,14 +23,53 @@ */ @RequiredArgsConstructor @Getter -public class DerivedSourceReaders implements Closeable { +public final class DerivedSourceReaders implements Closeable { @Nullable private final KnnVectorsReader knnVectorsReader; @Nullable private final DocValuesProducer docValuesProducer; + // Copied from lucene (https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/index/SegmentCoreReaders.java): + // We need to reference count these readers because they may be shared amongst different instances. + // "Counts how many other readers share the core objects + // (freqStream, proxStream, tis, etc.) of this reader; + // when coreRef drops to 0, these core objects may be + // closed. A given instance of SegmentReader may be + // closed, even though it shares core objects with other + // SegmentReaders": + private final AtomicInteger ref = new AtomicInteger(1); + + /** + * Returns this DerivedSourceReaders object with incremented reference count + * + * @return DerivedSourceReaders object with incremented reference count + */ + public DerivedSourceReaders cloneWithMerge() { + // For cloning, we dont need to reference count. In Lucene, the merging will actually not close any of the + // readers, so it should only be handled by the original code. 
See + // https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/index/IndexWriter.java#L3372 + // for more details + return this; + } + @Override public void close() throws IOException { - IOUtils.close(knnVectorsReader, docValuesProducer); + decRef(); + } + + private void incRef() { + int count; + while ((count = ref.get()) > 0) { + if (ref.compareAndSet(count, count + 1)) { + return; + } + } + throw new AlreadyClosedException("DerivedSourceReaders is already closed"); + } + + private void decRef() throws IOException { + if (ref.decrementAndGet() == 0) { + IOUtils.close(knnVectorsReader, docValuesProducer); + } } } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorTransformer.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorTransformer.java index 33f5837b13..e7196c0197 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorTransformer.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorTransformer.java @@ -7,7 +7,6 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.SegmentReadState; -import org.apache.lucene.util.IOUtils; import org.opensearch.common.collect.Tuple; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentHelper; @@ -18,7 +17,6 @@ import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.Closeable; import java.io.IOException; import java.nio.ByteBuffer; import java.util.HashMap; @@ -27,7 +25,7 @@ import java.util.function.Function; @Log4j2 -public class DerivedSourceVectorTransformer implements Closeable { +public class DerivedSourceVectorTransformer { private final DerivedSourceReaders derivedSourceReaders; Function, Map> derivedSourceVectorTransformer; @@ -37,16 +35,16 @@ public class DerivedSourceVectorTransformer implements Closeable { /** * 
- * @param derivedSourceReadersSupplier Supplier for the derived source readers. + * @param derivedSourceReaders derived source readers. * @param segmentReadState Segment read state * @param fieldsToInjectVector List of fields to inject vectors into */ public DerivedSourceVectorTransformer( - DerivedSourceReadersSupplier derivedSourceReadersSupplier, + DerivedSourceReaders derivedSourceReaders, SegmentReadState segmentReadState, List fieldsToInjectVector - ) throws IOException { - this.derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); + ) { + this.derivedSourceReaders = derivedSourceReaders; perFieldDerivedVectorTransformers = new HashMap<>(); Map> perFieldDerivedVectorTransformersFunctionValues = new HashMap<>(); for (DerivedFieldInfo derivedFieldInfo : fieldsToInjectVector) { @@ -137,9 +135,4 @@ public boolean shouldInject(String[] includes, String[] excludes) { } return true; } - - @Override - public void close() throws IOException { - IOUtils.close(derivedSourceReaders); - } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategyFactory.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategyFactory.java index c9804ed9e2..20a1fb31cf 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategyFactory.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategyFactory.java @@ -9,6 +9,7 @@ import org.apache.lucene.index.FieldInfo; import org.opensearch.index.IndexSettings; import org.opensearch.knn.common.featureflags.KNNFeatureFlags; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; @@ -42,17 +43,19 @@ public NativeIndexBuildStrategyFactory(Supplier repositorie } /** - * 
@param fieldInfo Field related attributes/info - * @param totalLiveDocs Number of documents with the vector field. This values comes from {@link org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsWriter#flush} - * and {@link org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsWriter#mergeOneField} - * @param knnVectorValues An instance of {@link KNNVectorValues} which is used to evaluate the size threshold KNN_REMOTE_VECTOR_BUILD_THRESHOLD - * @return The {@link NativeIndexBuildStrategy} to be used. Intended to be used by {@link NativeIndexWriter} + * @param fieldInfo Field related attributes/info + * @param totalLiveDocs Number of documents with the vector field. This values comes from {@link org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsWriter#flush} + * and {@link org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsWriter#mergeOneField} + * @param knnVectorValues An instance of {@link KNNVectorValues} which is used to evaluate the size threshold KNN_REMOTE_VECTOR_BUILD_THRESHOLD + * @param indexInfo An instance of {@link BuildIndexParams} containing relevant index info + * @return The {@link NativeIndexBuildStrategy} to be used. 
Intended to be used by {@link NativeIndexWriter} * @throws IOException */ public NativeIndexBuildStrategy getBuildStrategy( final FieldInfo fieldInfo, final int totalLiveDocs, - final KNNVectorValues knnVectorValues + final KNNVectorValues knnVectorValues, + BuildIndexParams indexInfo ) throws IOException { final KNNEngine knnEngine = extractKNNEngine(fieldInfo); boolean isTemplate = fieldInfo.attributes().containsKey(MODEL_ID); diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index e0ff86e849..e858a07328 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -102,7 +102,7 @@ public static NativeIndexWriter getWriter( * @throws IOException */ public void flushIndex(final Supplier> knnVectorValuesSupplier, int totalLiveDocs) throws IOException { - buildAndWriteIndex(knnVectorValuesSupplier, totalLiveDocs); + buildAndWriteIndex(knnVectorValuesSupplier, totalLiveDocs, true); recordRefreshStats(); } @@ -122,11 +122,12 @@ public void mergeIndex(final Supplier> knnVectorValuesSupplie long bytesPerVector = knnVectorValues.bytesPerVector(); startMergeStats(totalLiveDocs, bytesPerVector); - buildAndWriteIndex(knnVectorValuesSupplier, totalLiveDocs); + buildAndWriteIndex(knnVectorValuesSupplier, totalLiveDocs, false); endMergeStats(totalLiveDocs, bytesPerVector); } - private void buildAndWriteIndex(final Supplier> knnVectorValuesSupplier, int totalLiveDocs) throws IOException { + private void buildAndWriteIndex(final Supplier> knnVectorValuesSupplier, int totalLiveDocs, boolean isFlush) + throws IOException { if (totalLiveDocs == 0) { log.debug("No live docs for field {}", fieldInfo.name); return; @@ -146,12 +147,14 @@ private void buildAndWriteIndex(final Supplier> knnVectorValu indexOutputWithBuffer, knnEngine, 
knnVectorValuesSupplier, - totalLiveDocs + totalLiveDocs, + isFlush ); NativeIndexBuildStrategy indexBuilder = indexBuilderFactory.getBuildStrategy( fieldInfo, totalLiveDocs, - knnVectorValuesSupplier.get() + knnVectorValuesSupplier.get(), + nativeIndexParams ); indexBuilder.buildAndWriteIndex(nativeIndexParams); CodecUtil.writeFooter(output); @@ -166,7 +169,8 @@ private BuildIndexParams indexParams( IndexOutputWithBuffer indexOutputWithBuffer, KNNEngine knnEngine, Supplier> knnVectorValuesSupplier, - int totalLiveDocs + int totalLiveDocs, + boolean isFlush ) throws IOException { final Map parameters; VectorDataType vectorDataType; @@ -192,6 +196,7 @@ private BuildIndexParams indexParams( .knnVectorValuesSupplier(knnVectorValuesSupplier) .totalLiveDocs(totalLiveDocs) .segmentWriteState(state) + .isFlush(isFlush) .build(); } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java index cf5d5be70b..466fff601d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java @@ -36,4 +36,5 @@ public class BuildIndexParams { Supplier> knnVectorValuesSupplier; int totalLiveDocs; SegmentWriteState segmentWriteState; + boolean isFlush; } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildMetrics.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildMetrics.java new file mode 100644 index 0000000000..38acad3461 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildMetrics.java @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex.remote; + +import lombok.extern.log4j.Log4j2; +import 
org.opensearch.common.StopWatch; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; + +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.BUILD_REQUEST_FAILURE_COUNT; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.BUILD_REQUEST_SUCCESS_COUNT; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.INDEX_BUILD_FAILURE_COUNT; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.INDEX_BUILD_SUCCESS_COUNT; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.READ_FAILURE_COUNT; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.READ_SUCCESS_COUNT; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.READ_TIME; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_CURRENT_FLUSH_OPERATIONS; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_CURRENT_FLUSH_SIZE; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_CURRENT_MERGE_OPERATIONS; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_CURRENT_MERGE_SIZE; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_FLUSH_TIME; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_MERGE_TIME; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.WAITING_TIME; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.WRITE_FAILURE_COUNT; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.WRITE_SUCCESS_COUNT; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.WRITE_TIME; + +/** + * Class to 
handle all metric collection for the remote index build. + * Each phase has its own StopWatch and `start` and `end` methods. + */ +@Log4j2 +public class RemoteIndexBuildMetrics { + private final StopWatch overallStopWatch; + private final StopWatch writeStopWatch; + private final StopWatch buildRequestStopWatch; + private final StopWatch waiterStopWatch; + private final StopWatch readStopWatch; + private long size; + private boolean isFlush; + private String fieldName; + + public RemoteIndexBuildMetrics() { + this.overallStopWatch = new StopWatch(); + this.writeStopWatch = new StopWatch(); + this.buildRequestStopWatch = new StopWatch(); + this.waiterStopWatch = new StopWatch(); + this.readStopWatch = new StopWatch(); + } + + /** + * Helper method to collect remote index build metrics on start + */ + public void startRemoteIndexBuildMetrics(BuildIndexParams indexInfo) throws IOException { + KNNVectorValues knnVectorValues = indexInfo.getKnnVectorValuesSupplier().get(); + initializeVectorValues(knnVectorValues); + this.size = (long) indexInfo.getTotalLiveDocs() * knnVectorValues.bytesPerVector(); + this.isFlush = indexInfo.isFlush(); + this.fieldName = indexInfo.getFieldName(); + overallStopWatch.start(); + if (isFlush) { + REMOTE_INDEX_BUILD_CURRENT_FLUSH_OPERATIONS.increment(); + REMOTE_INDEX_BUILD_CURRENT_FLUSH_SIZE.incrementBy(size); + } else { + REMOTE_INDEX_BUILD_CURRENT_MERGE_OPERATIONS.increment(); + REMOTE_INDEX_BUILD_CURRENT_MERGE_SIZE.incrementBy(size); + } + } + + // Repository read phase metric helpers + public void startRepositoryWriteMetrics() { + writeStopWatch.start(); + } + + public void endRepositoryWriteMetrics(boolean success) { + long time_in_millis = writeStopWatch.stop().totalTime().millis(); + if (success) { + WRITE_SUCCESS_COUNT.increment(); + WRITE_TIME.incrementBy(time_in_millis); + log.debug("Repository write took {} ms for vector field [{}]", time_in_millis, fieldName); + } else { + WRITE_FAILURE_COUNT.increment(); + } + } + + // Build 
request phase metric helpers + public void startBuildRequestMetrics() { + buildRequestStopWatch.start(); + } + + public void endBuildRequestMetrics(boolean success) { + long time_in_millis = buildRequestStopWatch.stop().totalTime().millis(); + if (success) { + BUILD_REQUEST_SUCCESS_COUNT.increment(); + log.debug("Submit vector build took {} ms for vector field [{}]", time_in_millis, fieldName); + } else { + BUILD_REQUEST_FAILURE_COUNT.increment(); + } + } + + // Await index build phase metric helpers + public void startWaitingMetrics() { + waiterStopWatch.start(); + } + + public void endWaitingMetrics() { + long time_in_millis = waiterStopWatch.stop().totalTime().millis(); + WAITING_TIME.incrementBy(time_in_millis); + log.debug("Await vector build took {} ms for vector field [{}]", time_in_millis, fieldName); + } + + // Repository read phase metric helpers + public void startRepositoryReadMetrics() { + readStopWatch.start(); + } + + public void endRepositoryReadMetrics(boolean success) { + long time_in_millis = readStopWatch.stop().totalTime().millis(); + if (success) { + READ_SUCCESS_COUNT.increment(); + READ_TIME.incrementBy(time_in_millis); + log.debug("Repository read took {} ms for vector field [{}]", time_in_millis, fieldName); + } else { + READ_FAILURE_COUNT.increment(); + } + } + + /** + * Helper method to collect overall remote index build metrics + */ + public void endRemoteIndexBuildMetrics(boolean wasSuccessful) { + long time_in_millis = overallStopWatch.stop().totalTime().millis(); + if (wasSuccessful) { + INDEX_BUILD_SUCCESS_COUNT.increment(); + log.debug("Remote index build succeeded after {} ms for vector field [{}]", time_in_millis, fieldName); + } else { + INDEX_BUILD_FAILURE_COUNT.increment(); + log.warn("Remote index build failed after {} ms for vector field [{}]", time_in_millis, fieldName); + } + if (isFlush) { + REMOTE_INDEX_BUILD_CURRENT_FLUSH_OPERATIONS.decrement(); + REMOTE_INDEX_BUILD_CURRENT_FLUSH_SIZE.decrementBy(size); + 
REMOTE_INDEX_BUILD_FLUSH_TIME.incrementBy(time_in_millis); + } else { + REMOTE_INDEX_BUILD_CURRENT_MERGE_OPERATIONS.decrement(); + REMOTE_INDEX_BUILD_CURRENT_MERGE_SIZE.decrementBy(size); + REMOTE_INDEX_BUILD_MERGE_TIME.incrementBy(time_in_millis); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java index aab5de32e0..9ef50e4e8a 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java @@ -8,7 +8,6 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.metadata.RepositoryMetadata; -import org.opensearch.common.StopWatch; import org.opensearch.common.UUIDs; import org.opensearch.common.annotation.ExperimentalApi; import org.opensearch.common.blobstore.BlobPath; @@ -16,7 +15,6 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; -import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; import org.opensearch.knn.index.remote.RemoteIndexWaiter; import org.opensearch.knn.index.remote.RemoteIndexWaiterFactory; @@ -44,6 +42,7 @@ import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING; import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING; import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_VECTOR_REPO_SETTING; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; /** * This class orchestrates building vector indices. 
It handles uploading data to a repository, submitting a remote @@ -57,14 +56,16 @@ public class RemoteIndexBuildStrategy implements NativeIndexBuildStrategy { private final NativeIndexBuildStrategy fallbackStrategy; private final IndexSettings indexSettings; private final KNNLibraryIndexingContext knnLibraryIndexingContext; + private final RemoteIndexBuildMetrics metrics; /** * Public constructor, intended to be called by {@link org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategyFactory} based in * part on the return value from {@link RemoteIndexBuildStrategy#shouldBuildIndexRemotely} - * @param repositoriesServiceSupplier A supplier for {@link RepositoriesService} used to interact with a repository - * @param fallbackStrategy Delegate {@link NativeIndexBuildStrategy} used to fall back to local build - * @param indexSettings {@link IndexSettings} used to retrieve information about the index - * @param knnLibraryIndexingContext {@link KNNLibraryIndexingContext} used to retrieve method specific params for the remote build request + * + * @param repositoriesServiceSupplier A supplier for {@link RepositoriesService} used to interact with a repository + * @param fallbackStrategy Delegate {@link NativeIndexBuildStrategy} used to fall back to local build + * @param indexSettings {@link IndexSettings} used to retrieve information about the index + * @param knnLibraryIndexingContext {@link KNNLibraryIndexingContext} used to retrieve method specific params for the remote build request */ public RemoteIndexBuildStrategy( Supplier repositoriesServiceSupplier, @@ -76,6 +77,7 @@ public RemoteIndexBuildStrategy( this.fallbackStrategy = fallbackStrategy; this.indexSettings = indexSettings; this.knnLibraryIndexingContext = knnLibraryIndexingContext; + this.metrics = new RemoteIndexBuildMetrics(); } /** @@ -122,62 +124,136 @@ public static boolean shouldBuildIndexRemotely(IndexSettings indexSettings, long * 3. Awaits on vector build to complete * 4. 
Downloads index file and writes to indexOutput * - * @param indexInfo - * @throws IOException + * @param indexInfo {@link BuildIndexParams} containing information about the index to be built + * @throws IOException if an error occurs during the build process */ @Override public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException { - StopWatch stopWatch; - long time_in_millis; + metrics.startRemoteIndexBuildMetrics(indexInfo); + boolean success = false; + try { + RepositoryContext repositoryContext = getRepositoryContext(indexInfo); + + // 1. Write required data to repository + writeToRepository(repositoryContext, indexInfo); + + // 2. Trigger remote index build + RemoteIndexClient client = RemoteIndexClientFactory.getRemoteIndexClient(KNNSettings.getRemoteBuildServiceEndpoint()); + RemoteBuildResponse remoteBuildResponse = submitBuild(repositoryContext, indexInfo, client); + + // 3. Await vector build completion + RemoteBuildStatusResponse remoteBuildStatusResponse = awaitIndexBuild(remoteBuildResponse, indexInfo, client); + + // 4. 
Download index file and write to indexOutput + readFromRepository(indexInfo, repositoryContext, remoteBuildStatusResponse); + + success = true; + } catch (Exception e) { + fallbackStrategy.buildAndWriteIndex(indexInfo); + } finally { + metrics.endRemoteIndexBuildMetrics(success); + } + } + + /** + * Writes the required vector and doc ID data to the repository + */ + private void writeToRepository(RepositoryContext repositoryContext, BuildIndexParams indexInfo) throws IOException, + InterruptedException { + VectorRepositoryAccessor vectorRepositoryAccessor = repositoryContext.vectorRepositoryAccessor; + boolean success = false; + metrics.startRepositoryWriteMetrics(); try { - BlobStoreRepository repository = getRepository(); - BlobPath blobPath = repository.basePath().add(indexSettings.getUUID() + VECTORS_PATH); - VectorRepositoryAccessor vectorRepositoryAccessor = new DefaultVectorRepositoryAccessor( - repository.blobStore().blobContainer(blobPath) - ); - stopWatch = new StopWatch().start(); - // We create a new time based UUID per file in order to avoid conflicts across shards. It is also very difficult to get the - // shard id in this context. 
- String blobName = UUIDs.base64UUID() + "_" + indexInfo.getFieldName() + "_" + indexInfo.getSegmentWriteState().segmentInfo.name; vectorRepositoryAccessor.writeToRepository( - blobName, + repositoryContext.blobName, indexInfo.getTotalLiveDocs(), indexInfo.getVectorDataType(), indexInfo.getKnnVectorValuesSupplier() ); - time_in_millis = stopWatch.stop().totalTime().millis(); - log.debug("Repository write took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); + success = true; + } catch (InterruptedException | IOException e) { + log.debug("Repository write failed for vector field [{}]", indexInfo.getFieldName()); + throw e; + } finally { + metrics.endRepositoryWriteMetrics(success); + } + } - final RemoteIndexClient client = RemoteIndexClientFactory.getRemoteIndexClient(KNNSettings.getRemoteBuildServiceEndpoint()); + /** + * Submits a remote build request to the remote index build service + * @return RemoteBuildResponse containing the response from the remote service + */ + private RemoteBuildResponse submitBuild(RepositoryContext repositoryContext, BuildIndexParams indexInfo, RemoteIndexClient client) + throws IOException { + final RemoteBuildResponse remoteBuildResponse; + boolean success = false; + metrics.startBuildRequestMetrics(); + try { final RemoteBuildRequest buildRequest = buildRemoteBuildRequest( indexSettings, indexInfo, - repository.getMetadata(), - blobPath.buildAsString() + blobName, + repositoryContext.blobStoreRepository.getMetadata(), + repositoryContext.blobPath.buildAsString() + repositoryContext.blobName, knnLibraryIndexingContext.getLibraryParameters() ); - stopWatch = new StopWatch().start(); - final RemoteBuildResponse remoteBuildResponse = client.submitVectorBuild(buildRequest); - time_in_millis = stopWatch.stop().totalTime().millis(); - log.debug("Submit vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); + remoteBuildResponse = client.submitVectorBuild(buildRequest); + success 
= true; + return remoteBuildResponse; + } catch (IOException e) { + log.debug("Submit vector build failed for vector field [{}]", indexInfo.getFieldName()); + throw e; + } finally { + metrics.endBuildRequestMetrics(success); + } + } + /** + * Awaits the vector build to complete + * @return RemoteBuildStatusResponse containing the completed status response from the remote service. + * This will only be returned with a COMPLETED_INDEX_BUILD status, otherwise the method will throw an exception. + */ + private RemoteBuildStatusResponse awaitIndexBuild( + RemoteBuildResponse remoteBuildResponse, + BuildIndexParams indexInfo, + RemoteIndexClient client + ) throws IOException, InterruptedException { + RemoteBuildStatusResponse remoteBuildStatusResponse; + metrics.startWaitingMetrics(); + try { final RemoteBuildStatusRequest remoteBuildStatusRequest = RemoteBuildStatusRequest.builder() .jobId(remoteBuildResponse.getJobId()) .build(); RemoteIndexWaiter waiter = RemoteIndexWaiterFactory.getRemoteIndexWaiter(client); - stopWatch = new StopWatch().start(); - RemoteBuildStatusResponse remoteBuildStatusResponse = waiter.awaitVectorBuild(remoteBuildStatusRequest); - time_in_millis = stopWatch.stop().totalTime().millis(); - log.debug("Await vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); - - stopWatch = new StopWatch().start(); - vectorRepositoryAccessor.readFromRepository(remoteBuildStatusResponse.getFileName(), indexInfo.getIndexOutputWithBuffer()); - time_in_millis = stopWatch.stop().totalTime().millis(); - log.debug("Repository read took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); + remoteBuildStatusResponse = waiter.awaitVectorBuild(remoteBuildStatusRequest); + metrics.endWaitingMetrics(); + return remoteBuildStatusResponse; + } catch (InterruptedException | IOException e) { + log.debug("Await vector build failed for vector field [{}]", indexInfo.getFieldName(), e); + throw e; + } + } + + /** + * 
Downloads the index file from the repository and writes to the indexOutput + */ + private void readFromRepository( + BuildIndexParams indexInfo, + RepositoryContext repositoryContext, + RemoteBuildStatusResponse remoteBuildStatusResponse + ) throws IOException { + metrics.startRepositoryReadMetrics(); + boolean success = false; + try { + repositoryContext.vectorRepositoryAccessor.readFromRepository( + remoteBuildStatusResponse.getFileName(), + indexInfo.getIndexOutputWithBuffer() + ); + success = true; } catch (Exception e) { - // TODO: This needs more robust failure handling - log.warn("Failed to build index remotely", e); - fallbackStrategy.buildAndWriteIndex(indexInfo); + log.debug("Repository read failed for vector field [{}]", indexInfo.getFieldName()); + throw e; + } finally { + metrics.endRepositoryReadMetrics(success); } } @@ -197,6 +273,26 @@ private BlobStoreRepository getRepository() throws RepositoryMissingException { return (BlobStoreRepository) repository; } + /** + * Record to hold various repository related objects + */ + private record RepositoryContext(BlobStoreRepository blobStoreRepository, BlobPath blobPath, + VectorRepositoryAccessor vectorRepositoryAccessor, String blobName) { + } + + /** + * Helper method to get repository context. Generates a unique UUID for the blobName so should only be used once. + */ + private RepositoryContext getRepositoryContext(BuildIndexParams indexInfo) { + BlobStoreRepository repository = getRepository(); + BlobPath blobPath = repository.basePath().add(indexSettings.getUUID() + VECTORS_PATH); + String blobName = UUIDs.base64UUID() + "_" + indexInfo.getFieldName() + "_" + indexInfo.getSegmentWriteState().segmentInfo.name; + VectorRepositoryAccessor vectorRepositoryAccessor = new DefaultVectorRepositoryAccessor( + repository.blobStore().blobContainer(blobPath) + ); + return new RepositoryContext(repository, blobPath, vectorRepositoryAccessor, blobName); + } + /** * Constructor for RemoteBuildRequest. 
* @@ -225,7 +321,7 @@ static RemoteBuildRequest buildRemoteBuildRequest( String vectorDataType = indexInfo.getVectorDataType().getValue(); KNNVectorValues vectorValues = indexInfo.getKnnVectorValuesSupplier().get(); - KNNCodecUtil.initializeVectorValues(vectorValues); + initializeVectorValues(vectorValues); assert (vectorValues.dimension() > 0); return RemoteBuildRequest.builder() @@ -241,5 +337,4 @@ static RemoteBuildRequest buildRemoteBuildRequest( .indexParameters(indexInfo.getKnnEngine().createRemoteIndexingParameters(parameters)) .build(); } - } diff --git a/src/main/java/org/opensearch/knn/index/engine/EngineResolver.java b/src/main/java/org/opensearch/knn/index/engine/EngineResolver.java index 983adf28bf..662fd7dafb 100644 --- a/src/main/java/org/opensearch/knn/index/engine/EngineResolver.java +++ b/src/main/java/org/opensearch/knn/index/engine/EngineResolver.java @@ -5,8 +5,10 @@ package org.opensearch.knn.index.engine; +import com.google.common.annotations.VisibleForTesting; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.Version; import org.opensearch.knn.index.mapper.CompressionLevel; import org.opensearch.knn.index.mapper.Mode; @@ -22,47 +24,83 @@ public final class EngineResolver { private EngineResolver() {} + @VisibleForTesting + KNNEngine resolveEngine(KNNMethodConfigContext knnMethodConfigContext, KNNMethodContext knnMethodContext, boolean requiresTraining) { + return logAndReturnEngine(resolveKNNEngine(knnMethodConfigContext, knnMethodContext, requiresTraining, Version.CURRENT)); + } + /** * Based on the provided {@link Mode} and {@link CompressionLevel}, resolve to a {@link KNNEngine}. 
* * @param knnMethodConfigContext configuration context * @param knnMethodContext KNNMethodContext * @param requiresTraining whether config requires training + * @param version opensearch index version * @return {@link KNNEngine} */ public KNNEngine resolveEngine( KNNMethodConfigContext knnMethodConfigContext, KNNMethodContext knnMethodContext, - boolean requiresTraining + boolean requiresTraining, + Version version ) { - // User configuration gets precedence - if (knnMethodContext != null && knnMethodContext.isEngineConfigured()) { - return logAndReturnEngine(knnMethodContext.getKnnEngine()); + return logAndReturnEngine(resolveKNNEngine(knnMethodConfigContext, knnMethodContext, requiresTraining, version)); + } + + /** + * Based on the provided {@link Mode} and {@link CompressionLevel}, resolve to a {@link KNNEngine}. + * + * @param knnMethodConfigContext configuration context + * @param knnMethodContext KNNMethodContext + * @param requiresTraining whether config requires training + * @param version opensearch index version + * @return {@link KNNEngine} + */ + private KNNEngine resolveKNNEngine( + KNNMethodConfigContext knnMethodConfigContext, + KNNMethodContext knnMethodContext, + boolean requiresTraining, + Version version + ) { + // Check user configuration first + if (hasUserConfiguredEngine(knnMethodContext)) { + return knnMethodContext.getKnnEngine(); } - // Faiss is the only engine that supports training, so we default to faiss here for now + // Handle training case if (requiresTraining) { - return logAndReturnEngine(KNNEngine.FAISS); + // Faiss is the only engine that supports training, so we default to faiss here for now + return KNNEngine.FAISS; } Mode mode = knnMethodConfigContext.getMode(); CompressionLevel compressionLevel = knnMethodConfigContext.getCompressionLevel(); + // If both mode and compression are not specified, we can just default if (Mode.isConfigured(mode) == false && CompressionLevel.isConfigured(compressionLevel) == false) { - return 
logAndReturnEngine(KNNEngine.DEFAULT); + return KNNEngine.DEFAULT; } - // For 1x, we need to default to faiss if mode is provided and use nmslib otherwise + if (compressionLevel == CompressionLevel.x4) { + // Lucene is only engine that supports 4x - so we have to default to it here. + return KNNEngine.LUCENE; + } if (CompressionLevel.isConfigured(compressionLevel) == false || compressionLevel == CompressionLevel.x1) { - return logAndReturnEngine(mode == Mode.ON_DISK ? KNNEngine.FAISS : KNNEngine.NMSLIB); + // For 1x or no compression, we need to default to faiss if mode is provided and use nmslib otherwise based on version check + return resolveEngineForX1OrNoCompression(mode, version); } + return KNNEngine.FAISS; + } - // Lucene is only engine that supports 4x - so we have to default to it here. - if (compressionLevel == CompressionLevel.x4) { - return logAndReturnEngine(KNNEngine.LUCENE); - } + private boolean hasUserConfiguredEngine(KNNMethodContext knnMethodContext) { + return knnMethodContext != null && knnMethodContext.isEngineConfigured(); + } - return logAndReturnEngine(KNNEngine.FAISS); + private KNNEngine resolveEngineForX1OrNoCompression(Mode mode, Version version) { + if (version != null && version.onOrAfter(Version.V_2_19_0)) { + return KNNEngine.FAISS; + } + return mode == Mode.ON_DISK ? 
KNNEngine.FAISS : KNNEngine.NMSLIB; } private KNNEngine logAndReturnEngine(KNNEngine knnEngine) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/EngineFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/EngineFieldMapper.java new file mode 100644 index 0000000000..47ec08d8c9 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/EngineFieldMapper.java @@ -0,0 +1,303 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.apache.lucene.document.Field; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.opensearch.Version; +import org.opensearch.common.Explicit; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.knn.index.DerivedKnnByteVectorField; +import org.opensearch.knn.index.DerivedKnnFloatVectorField; +import org.opensearch.knn.index.KNNVectorSimilarityFunction; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.VectorField; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; +import org.opensearch.knn.index.engine.qframe.QuantizationConfigParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static 
org.opensearch.knn.common.KNNConstants.QFRAMEWORK_CONFIG; +import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; + +/** + * Field mapper for all supported engines. + */ +public class EngineFieldMapper extends KNNVectorFieldMapper { + + private final FieldType vectorFieldType; + private final PerDimensionProcessor perDimensionProcessor; + private final PerDimensionValidator perDimensionValidator; + private final VectorValidator vectorValidator; + private final VectorTransformer vectorTransformer; + private final boolean isLuceneEngine; + + public static EngineFieldMapper createFieldMapper( + String fullname, + String simpleName, + Map metaValue, + KNNMethodConfigContext knnMethodConfigContext, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + OriginalMappingParameters originalMappingParameters + ) { + KNNMethodContext methodContext = originalMappingParameters.getResolvedKnnMethodContext(); + KNNLibraryIndexingContext libraryContext = methodContext.getKnnEngine() + .getKNNLibraryIndexingContext(methodContext, knnMethodConfigContext); + boolean isLuceneEngine = KNNEngine.LUCENE.equals(methodContext.getKnnEngine()); + + KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( + fullname, + metaValue, + knnMethodConfigContext.getVectorDataType(), + new KNNMappingConfig() { + @Override + public Optional getKnnMethodContext() { + return Optional.of(methodContext); + } + + @Override + public int getDimension() { + return knnMethodConfigContext.getDimension(); + } + + @Override + public Mode getMode() { + return 
Mode.fromName(originalMappingParameters.getMode()); + } + + @Override + public CompressionLevel getCompressionLevel() { + return knnMethodConfigContext.getCompressionLevel(); + } + + @Override + public Version getIndexCreatedVersion() { + return knnMethodConfigContext.getVersionCreated(); + } + + @Override + public QuantizationConfig getQuantizationConfig() { + return Optional.ofNullable(libraryContext) + .map(KNNLibraryIndexingContext::getQuantizationConfig) + .orElse(QuantizationConfig.EMPTY); + } + + @Override + public KNNLibraryIndexingContext getKnnLibraryIndexingContext() { + return libraryContext; + } + } + ); + + return new EngineFieldMapper( + simpleName, + mappedFieldType, + multiFields, + copyTo, + ignoreMalformed, + stored, + hasDocValues, + knnMethodConfigContext, + originalMappingParameters, + isLuceneEngine + ); + } + + private EngineFieldMapper( + String name, + KNNVectorFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + KNNMethodConfigContext knnMethodConfigContext, + OriginalMappingParameters originalMappingParameters, + boolean isLuceneEngine + ) { + super( + name, + mappedFieldType, + multiFields, + copyTo, + ignoreMalformed, + stored, + hasDocValues, + knnMethodConfigContext.getVersionCreated(), + originalMappingParameters + ); + this.isLuceneEngine = isLuceneEngine; + updateEngineStats(); + KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); + VectorDataType vectorDataType = mappedFieldType.getVectorDataType(); + KNNMethodContext resolvedKnnMethodContext = originalMappingParameters.getResolvedKnnMethodContext(); + + final KNNVectorSimilarityFunction knnVectorSimilarityFunction = resolvedKnnMethodContext.getSpaceType() + .getKnnVectorSimilarityFunction(); + KNNLibraryIndexingContext knnLibraryIndexingContext = resolvedKnnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(resolvedKnnMethodContext, knnMethodConfigContext); 
+ + // LuceneFieldMapper attributes + if (this.isLuceneEngine) { + this.fieldType = vectorDataType.createKnnVectorFieldType(knnMappingConfig.getDimension(), knnVectorSimilarityFunction); + + if (this.hasDocValues) { + this.vectorFieldType = buildDocValuesFieldType(resolvedKnnMethodContext.getKnnEngine()); + } else { + this.vectorFieldType = null; + } + this.vectorTransformer = null; + } else { + // MethodFieldMapper attributes + this.vectorFieldType = null; + this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(indexCreatedVersion); + KNNEngine knnEngine = resolvedKnnMethodContext.getKnnEngine(); + QuantizationConfig quantizationConfig = knnLibraryIndexingContext.getQuantizationConfig(); + this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); + this.fieldType.putAttribute(DIMENSION, String.valueOf(knnMappingConfig.getDimension())); + this.fieldType.putAttribute(SPACE_TYPE, resolvedKnnMethodContext.getSpaceType().getValue()); + // Conditionally add quantization config + if (quantizationConfig != null && quantizationConfig != QuantizationConfig.EMPTY) { + this.fieldType.putAttribute(QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig)); + } + + this.fieldType.putAttribute(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName()); + try { + this.fieldType.putAttribute( + PARAMETERS, + XContentFactory.jsonBuilder().map(knnLibraryIndexingContext.getLibraryParameters()).toString() + ); + } catch (IOException ioe) { + throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe), ioe); + } + + if (useLuceneBasedVectorField) { + int adjustedDimension = mappedFieldType.vectorDataType == VectorDataType.BINARY + ? knnMappingConfig.getDimension() / 8 + : knnMappingConfig.getDimension(); + final VectorEncoding encoding = mappedFieldType.vectorDataType == VectorDataType.FLOAT + ? 
VectorEncoding.FLOAT32 + : VectorEncoding.BYTE; + final VectorSimilarityFunction similarityFunction = findBestMatchingVectorSimilarityFunction( + resolvedKnnMethodContext.getSpaceType() + ); + fieldType.setVectorAttributes(adjustedDimension, encoding, similarityFunction); + } else { + fieldType.setDocValuesType(DocValuesType.BINARY); + } + + this.fieldType.freeze(); + this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer(); + } + + // Common Attributes + this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); + this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator(); + this.vectorValidator = knnLibraryIndexingContext.getVectorValidator(); + } + + private VectorSimilarityFunction findBestMatchingVectorSimilarityFunction(final SpaceType spaceType) { + if (indexCreatedVersion.onOrAfter(Version.V_3_0_0)) { + // We need to find the best matching similarity function and not just save DEFAULT space type after 3.0. + // This is required for memory optimized search where utilizing .vec file to retrieve vectors. + // During the retrieval, it will locate similarity function from the meta info. Without this best effort, always the default + // similarity function will be used even when other space type is configured in a mapping. + // However, for keeping the backward compatibility, we only apply this to indices created after 3.0+. 
+ try { + return spaceType.getKnnVectorSimilarityFunction().getVectorSimilarityFunction(); + } catch (Exception e) { + // ignore + } + } + + return SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction(); + } + + @Override + protected List getFieldsForFloatVector(final float[] array, boolean isDerivedSourceEnabled) { + if (this.isLuceneEngine) { + final List fields = new ArrayList<>(); + fields.add(new DerivedKnnFloatVectorField(name(), array, fieldType, isDerivedSourceEnabled)); + if (hasDocValues && vectorFieldType != null) { + fields.add(new VectorField(name(), array, vectorFieldType)); + } + if (stored) { + fields.add(createStoredFieldForFloatVector(name(), array)); + } + return fields; + } + return super.getFieldsForFloatVector(array, isDerivedSourceEnabled); + } + + @Override + protected List getFieldsForByteVector(final byte[] array, boolean isDerivedSourceEnabled) { + if (this.isLuceneEngine) { + final List fields = new ArrayList<>(); + fields.add(new DerivedKnnByteVectorField(name(), array, fieldType, isDerivedSourceEnabled)); + if (hasDocValues && vectorFieldType != null) { + fields.add(new VectorField(name(), array, vectorFieldType)); + } + if (stored) { + fields.add(createStoredFieldForByteVector(name(), array)); + } + return fields; + } + return super.getFieldsForByteVector(array, isDerivedSourceEnabled); + } + + @Override + protected VectorValidator getVectorValidator() { + return vectorValidator; + } + + @Override + protected PerDimensionValidator getPerDimensionValidator() { + return perDimensionValidator; + } + + @Override + protected PerDimensionProcessor getPerDimensionProcessor() { + return perDimensionProcessor; + } + + @Override + protected VectorTransformer getVectorTransformer() { + if (isLuceneEngine) { + return super.getVectorTransformer(); + } + return vectorTransformer; + } + + @Override + void updateEngineStats() { + Optional.ofNullable(originalMappingParameters) + .ifPresent(params -> 
params.getResolvedKnnMethodContext().getKnnEngine().setInitialized(true)); + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 319596c7da..ab62d9527a 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -95,7 +95,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder { protected Boolean ignoreMalformed; protected final Parameter stored = Parameter.storeParam(m -> toType(m).stored, false); - protected final Parameter hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, true); + protected Parameter hasDocValues; protected final Parameter dimension = new Parameter<>( KNNConstants.DIMENSION, false, @@ -216,6 +216,22 @@ public Builder( this.indexCreatedVersion = indexCreatedVersion; this.knnMethodConfigContext = knnMethodConfigContext; this.originalParameters = originalParameters; + /* + * For indices created on or after OpenSearch 3.0.0, docValues + * defaults to false when not explicitly configured. This reduces storage + * overhead and improves indexing performance for k-NN vector fields. + * Changing the default value breaks BwC for existing indices on a cluster. 
+             *
+             * Behavior matrix:
+             * - Index < 3.0.0: Uses original default value
+             * - Index >= 3.0.0, docValues not configured: Sets to false
+             * - Any version, docValues explicitly configured: Respects configured value
+             */
+            if (indexCreatedVersion.before(Version.V_3_0_0)) {
+                hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, true);
+            } else {
+                hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, false);
+            }
         }
 
         @Override
@@ -273,10 +289,16 @@ public KNNVectorFieldMapper build(BuilderContext context) {
             );
         }
 
-        // return FlatVectorFieldMapper only for indices that are created on or after 2.17.0, for others, use either LuceneFieldMapper
-        // or
-        // MethodFieldMapper to maintain backwards compatibility
+        // return FlatVectorFieldMapper only for indices that are created on or after 2.17.0, for others, use
+        // EngineFieldMapper to maintain backwards compatibility
         if (originalParameters.getResolvedKnnMethodContext() == null && indexCreatedVersion.onOrAfter(Version.V_2_17_0)) {
+            // For indices created on or after 3.0.0, hasDocValues defaults to false. However,
+            // FlatVectorFieldMapper requires hasDocValues to be true to maintain proper functionality for
+            // vector search operations, so we automatically set hasDocValues to true here when it was not
+            // explicitly configured, to ensure consistent behavior.
+ if (indexCreatedVersion.onOrAfter(Version.V_3_0_0) && hasDocValues.isConfigured() == false) { + hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, true); + } return FlatVectorFieldMapper.createFieldMapper( buildFullName(context), name, @@ -295,28 +317,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { ); } - if (originalParameters.getResolvedKnnMethodContext().getKnnEngine() == KNNEngine.LUCENE) { - log.debug(String.format(Locale.ROOT, "Use [LuceneFieldMapper] mapper for field [%s]", name)); - LuceneFieldMapper.CreateLuceneFieldMapperInput createLuceneFieldMapperInput = LuceneFieldMapper.CreateLuceneFieldMapperInput - .builder() - .name(name) - .multiFields(multiFieldsBuilder) - .copyTo(copyToBuilder) - .ignoreMalformed(ignoreMalformed) - .stored(stored.getValue()) - .hasDocValues(hasDocValues.getValue()) - .originalKnnMethodContext(knnMethodContext.get()) - .build(); - return LuceneFieldMapper.createFieldMapper( - buildFullName(context), - metaValue, - knnMethodConfigContext, - createLuceneFieldMapperInput, - originalParameters - ); - } - - return MethodFieldMapper.createFieldMapper( + return EngineFieldMapper.createFieldMapper( buildFullName(context), name, metaValue, @@ -325,7 +326,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { copyToBuilder, ignoreMalformed, stored.getValue(), - hasDocValues.getValue(), + hasDocValues.get(), originalParameters ); } @@ -566,7 +567,8 @@ private void resolveKNNMethodComponents( KNNEngine resolvedKNNEngine = EngineResolver.INSTANCE.resolveEngine( builder.knnMethodConfigContext, builder.originalParameters.getResolvedKnnMethodContext(), - false + false, + builder.indexCreatedVersion ); setEngine(builder.originalParameters.getResolvedKnnMethodContext(), resolvedKNNEngine); diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java deleted file mode 100644 index 4a2561785d..0000000000 --- 
a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.mapper; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.NonNull; -import org.apache.lucene.document.Field; -import org.apache.lucene.document.FieldType; -import org.opensearch.Version; -import org.opensearch.common.Explicit; -import org.opensearch.knn.index.DerivedKnnByteVectorField; -import org.opensearch.knn.index.DerivedKnnFloatVectorField; -import org.opensearch.knn.index.KNNVectorSimilarityFunction; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.VectorField; -import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; -import org.opensearch.knn.index.engine.KNNMethodConfigContext; -import org.opensearch.knn.index.engine.KNNMethodContext; - -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType; - -/** - * Field mapper for case when Lucene has been set as an engine. - */ -public class LuceneFieldMapper extends KNNVectorFieldMapper { - - /** FieldType used for initializing VectorField, which is used for creating binary doc values. 
**/ - private final FieldType vectorFieldType; - - private final PerDimensionProcessor perDimensionProcessor; - private final PerDimensionValidator perDimensionValidator; - private final VectorValidator vectorValidator; - - static LuceneFieldMapper createFieldMapper( - String fullname, - Map metaValue, - KNNMethodConfigContext knnMethodConfigContext, - CreateLuceneFieldMapperInput createLuceneFieldMapperInput, - OriginalMappingParameters originalMappingParameters - ) { - final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( - fullname, - metaValue, - knnMethodConfigContext.getVectorDataType(), - new KNNMappingConfig() { - @Override - public Optional getKnnMethodContext() { - return Optional.of(originalMappingParameters.getResolvedKnnMethodContext()); - } - - @Override - public int getDimension() { - return knnMethodConfigContext.getDimension(); - } - - @Override - public Mode getMode() { - return knnMethodConfigContext.getMode(); - } - - @Override - public CompressionLevel getCompressionLevel() { - return knnMethodConfigContext.getCompressionLevel(); - } - - @Override - public Version getIndexCreatedVersion() { - return knnMethodConfigContext.getVersionCreated(); - } - } - ); - - return new LuceneFieldMapper(mappedFieldType, createLuceneFieldMapperInput, knnMethodConfigContext, originalMappingParameters); - } - - private LuceneFieldMapper( - final KNNVectorFieldType mappedFieldType, - final CreateLuceneFieldMapperInput input, - KNNMethodConfigContext knnMethodConfigContext, - OriginalMappingParameters originalMappingParameters - ) { - super( - input.getName(), - mappedFieldType, - input.getMultiFields(), - input.getCopyTo(), - input.getIgnoreMalformed(), - input.isStored(), - input.isHasDocValues(), - knnMethodConfigContext.getVersionCreated(), - originalMappingParameters - ); - KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); - KNNMethodContext resolvedKnnMethodContext = 
originalMappingParameters.getResolvedKnnMethodContext(); - VectorDataType vectorDataType = mappedFieldType.getVectorDataType(); - - final KNNVectorSimilarityFunction knnVectorSimilarityFunction = resolvedKnnMethodContext.getSpaceType() - .getKnnVectorSimilarityFunction(); - - this.fieldType = vectorDataType.createKnnVectorFieldType(knnMappingConfig.getDimension(), knnVectorSimilarityFunction); - - if (this.hasDocValues) { - this.vectorFieldType = buildDocValuesFieldType(resolvedKnnMethodContext.getKnnEngine()); - } else { - this.vectorFieldType = null; - } - KNNLibraryIndexingContext knnLibraryIndexingContext = resolvedKnnMethodContext.getKnnEngine() - .getKNNLibraryIndexingContext(resolvedKnnMethodContext, knnMethodConfigContext); - this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); - this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator(); - this.vectorValidator = knnLibraryIndexingContext.getVectorValidator(); - } - - @Override - protected List getFieldsForFloatVector(final float[] array, boolean isDerivedSourceEnabled) { - final List fieldsToBeAdded = new ArrayList<>(); - fieldsToBeAdded.add(new DerivedKnnFloatVectorField(name(), array, fieldType, isDerivedSourceEnabled)); - - if (hasDocValues && vectorFieldType != null) { - fieldsToBeAdded.add(new VectorField(name(), array, vectorFieldType)); - } - - if (this.stored) { - fieldsToBeAdded.add(createStoredFieldForFloatVector(name(), array)); - } - return fieldsToBeAdded; - } - - @Override - protected List getFieldsForByteVector(final byte[] array, boolean isDerivedSourceEnabled) { - final List fieldsToBeAdded = new ArrayList<>(); - fieldsToBeAdded.add(new DerivedKnnByteVectorField(name(), array, fieldType, isDerivedSourceEnabled)); - - if (hasDocValues && vectorFieldType != null) { - fieldsToBeAdded.add(new VectorField(name(), array, vectorFieldType)); - } - - if (this.stored) { - fieldsToBeAdded.add(createStoredFieldForByteVector(name(), array)); - } - 
return fieldsToBeAdded; - } - - @Override - protected VectorValidator getVectorValidator() { - return vectorValidator; - } - - @Override - protected PerDimensionValidator getPerDimensionValidator() { - return perDimensionValidator; - } - - @Override - protected PerDimensionProcessor getPerDimensionProcessor() { - return perDimensionProcessor; - } - - @Override - void updateEngineStats() { - KNNEngine.LUCENE.setInitialized(true); - } - - @AllArgsConstructor - @lombok.Builder - @Getter - static class CreateLuceneFieldMapperInput { - @NonNull - String name; - @NonNull - MultiFields multiFields; - @NonNull - CopyTo copyTo; - @NonNull - Explicit ignoreMalformed; - boolean stored; - boolean hasDocValues; - KNNMethodContext originalKnnMethodContext; - } -} diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java deleted file mode 100644 index 484fd64577..0000000000 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.mapper; - -import org.apache.lucene.document.FieldType; -import org.apache.lucene.index.DocValuesType; -import org.apache.lucene.index.VectorEncoding; -import org.opensearch.Version; -import org.opensearch.common.Explicit; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; -import org.opensearch.knn.index.engine.KNNMethodConfigContext; -import org.opensearch.knn.index.engine.KNNMethodContext; -import org.opensearch.knn.index.engine.qframe.QuantizationConfig; -import org.opensearch.knn.index.engine.qframe.QuantizationConfigParser; - -import java.io.IOException; -import 
java.util.Map; -import java.util.Optional; - -import static org.opensearch.knn.common.KNNConstants.DIMENSION; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.PARAMETERS; -import static org.opensearch.knn.common.KNNConstants.QFRAMEWORK_CONFIG; -import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; -import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; - -/** - * Field mapper for method definition in mapping - */ -public class MethodFieldMapper extends KNNVectorFieldMapper { - - private final PerDimensionProcessor perDimensionProcessor; - private final PerDimensionValidator perDimensionValidator; - private final VectorValidator vectorValidator; - private final VectorTransformer vectorTransformer; - - public static MethodFieldMapper createFieldMapper( - String fullname, - String simpleName, - Map metaValue, - KNNMethodConfigContext knnMethodConfigContext, - MultiFields multiFields, - CopyTo copyTo, - Explicit ignoreMalformed, - boolean stored, - boolean hasDocValues, - OriginalMappingParameters originalMappingParameters - ) { - - KNNMethodContext knnMethodContext = originalMappingParameters.getResolvedKnnMethodContext(); - QuantizationConfig quantizationConfig = knnMethodContext.getKnnEngine() - .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) - .getQuantizationConfig(); - KNNLibraryIndexingContext libraryIndexingContext = knnMethodContext.getKnnEngine() - .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); - - final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( - fullname, - metaValue, - knnMethodConfigContext.getVectorDataType(), - new KNNMappingConfig() { - @Override - public Optional getKnnMethodContext() { - return Optional.of(originalMappingParameters.getResolvedKnnMethodContext()); - } - - @Override - public int getDimension() { - return knnMethodConfigContext.getDimension(); - } - - @Override - 
public Mode getMode() { - return Mode.fromName(originalMappingParameters.getMode()); - } - - @Override - public CompressionLevel getCompressionLevel() { - return knnMethodConfigContext.getCompressionLevel(); - } - - @Override - public QuantizationConfig getQuantizationConfig() { - return quantizationConfig; - } - - @Override - public Version getIndexCreatedVersion() { - return knnMethodConfigContext.getVersionCreated(); - } - - @Override - public KNNLibraryIndexingContext getKnnLibraryIndexingContext() { - return libraryIndexingContext; - } - } - ); - return new MethodFieldMapper( - simpleName, - mappedFieldType, - multiFields, - copyTo, - ignoreMalformed, - stored, - hasDocValues, - knnMethodConfigContext, - originalMappingParameters - ); - } - - private MethodFieldMapper( - String simpleName, - KNNVectorFieldType mappedFieldType, - MultiFields multiFields, - CopyTo copyTo, - Explicit ignoreMalformed, - boolean stored, - boolean hasDocValues, - KNNMethodConfigContext knnMethodConfigContext, - OriginalMappingParameters originalMappingParameters - ) { - - super( - simpleName, - mappedFieldType, - multiFields, - copyTo, - ignoreMalformed, - stored, - hasDocValues, - knnMethodConfigContext.getVersionCreated(), - originalMappingParameters - ); - this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(indexCreatedVersion); - KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); - KNNMethodContext resolvedKnnMethodContext = originalMappingParameters.getResolvedKnnMethodContext(); - KNNEngine knnEngine = resolvedKnnMethodContext.getKnnEngine(); - KNNLibraryIndexingContext knnLibraryIndexingContext = knnEngine.getKNNLibraryIndexingContext( - resolvedKnnMethodContext, - knnMethodConfigContext - ); - QuantizationConfig quantizationConfig = knnLibraryIndexingContext.getQuantizationConfig(); - - this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); - this.fieldType.putAttribute(DIMENSION, 
String.valueOf(knnMappingConfig.getDimension())); - this.fieldType.putAttribute(SPACE_TYPE, resolvedKnnMethodContext.getSpaceType().getValue()); - // Conditionally add quantization config - if (quantizationConfig != null && quantizationConfig != QuantizationConfig.EMPTY) { - this.fieldType.putAttribute(QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig)); - } - - this.fieldType.putAttribute(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); - this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName()); - try { - this.fieldType.putAttribute( - PARAMETERS, - XContentFactory.jsonBuilder().map(knnLibraryIndexingContext.getLibraryParameters()).toString() - ); - } catch (IOException ioe) { - throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe)); - } - - if (useLuceneBasedVectorField) { - int adjustedDimension = mappedFieldType.vectorDataType == VectorDataType.BINARY - ? knnMappingConfig.getDimension() / 8 - : knnMappingConfig.getDimension(); - final VectorEncoding encoding = mappedFieldType.vectorDataType == VectorDataType.FLOAT - ? 
VectorEncoding.FLOAT32 - : VectorEncoding.BYTE; - fieldType.setVectorAttributes( - adjustedDimension, - encoding, - SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() - ); - } else { - fieldType.setDocValuesType(DocValuesType.BINARY); - } - - this.fieldType.freeze(); - this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); - this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator(); - this.vectorValidator = knnLibraryIndexingContext.getVectorValidator(); - this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer(); - } - - @Override - protected VectorValidator getVectorValidator() { - return vectorValidator; - } - - @Override - protected PerDimensionValidator getPerDimensionValidator() { - return perDimensionValidator; - } - - @Override - protected PerDimensionProcessor getPerDimensionProcessor() { - return perDimensionProcessor; - } - - @Override - protected VectorTransformer getVectorTransformer() { - return vectorTransformer; - } -} diff --git a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java index 4db11996c6..36424a45d0 100644 --- a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java @@ -10,17 +10,27 @@ import lombok.Getter; import lombok.NonNull; import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.ToChildBlockJoinQuery; +import org.opensearch.index.mapper.ObjectMapper; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.search.NestedHelper; +import 
org.opensearch.index.search.OpenSearchToParentBlockJoinQuery; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.rescore.RescoreContext; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Deque; +import java.util.LinkedList; +import java.util.List; import java.util.Map; import java.util.Optional; @@ -90,20 +100,73 @@ protected static Query getFilterQuery(BaseQueryFactory.CreateQueryRequest create createQueryRequest.getFieldName() ) ); + + // preserve nestedStack + Deque nestedLevelStack = new LinkedList<>(); + ObjectMapper objectMapper = null; + if (queryShardContext.nestedScope() != null) { + while ((objectMapper = queryShardContext.nestedScope().getObjectMapper()) != null) { + nestedLevelStack.push(objectMapper); + queryShardContext.nestedScope().previousLevel(); + } + } + final Query filterQuery; try { filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext); } catch (IOException e) { throw new RuntimeException("Cannot create query with filter", e); + } finally { + while ((objectMapper = nestedLevelStack.peek()) != null) { + queryShardContext.nestedScope().nextLevel(objectMapper); + nestedLevelStack.pop(); + } } BitSetProducer parentFilter = queryShardContext.getParentFilter(); if (parentFilter != null) { boolean mightMatch = new NestedHelper(queryShardContext.getMapperService()).mightMatchNestedDocs(filterQuery); if (mightMatch) { return filterQuery; + } else if (filterQuery instanceof OpenSearchToParentBlockJoinQuery) { + // this case would happen when path = null, and filter is nested + return ((OpenSearchToParentBlockJoinQuery) filterQuery).getChildQuery(); + } else if (filterQuery instanceof BooleanQuery) { + KNNQueryVisitor knnQueryVisitor = new KNNQueryVisitor(); + filterQuery.visit(knnQueryVisitor); + BooleanQuery.Builder builder = (new BooleanQuery.Builder()).add( + new 
ToChildBlockJoinQuery(filterQuery, parentFilter), + BooleanClause.Occur.FILTER + ); + for (Query q : knnQueryVisitor.nestedQuery) { + builder.add(q, BooleanClause.Occur.FILTER); + } + return builder.build(); } return new ToChildBlockJoinQuery(filterQuery, parentFilter); } return filterQuery; } + + @Getter + static class KNNQueryVisitor extends QueryVisitor { + List nestedQuery; + + public KNNQueryVisitor() { + nestedQuery = new ArrayList<>(); + } + + public QueryVisitor getSubVisitor(BooleanClause.Occur occur, Query parent) { + if (parent instanceof BooleanQuery && occur == BooleanClause.Occur.FILTER) { + Collection collection = ((BooleanQuery) parent).getClauses(BooleanClause.Occur.FILTER); + for (Query q : collection) { + if (q instanceof OpenSearchToParentBlockJoinQuery) { + nestedQuery.add(((OpenSearchToParentBlockJoinQuery) q).getChildQuery()); + } else { + q.visit(this); + } + } + } + return this; + } + } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 50ad8ba701..3775166ba6 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -53,6 +53,9 @@ public class KNNQuery extends Query { private BitSetProducer parentsFilter; private Float radius; private Context context; + @Setter + @Getter + private boolean explain; // Note: ideally query should not have to deal with shard level information. 
Adding it for logging purposes only // TODO: ThreadContext does not work with logger, remove this from here once its figured out diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 25416ccd4d..ba43b9ffb2 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -38,6 +38,7 @@ import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.query.ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder; +import org.opensearch.knn.index.query.explain.KnnExplanation; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; @@ -77,6 +78,7 @@ public class KNNWeight extends Weight { private static ExactSearcher DEFAULT_EXACT_SEARCHER; private final QuantizationService quantizationService; + private final KnnExplanation knnExplanation; public KNNWeight(KNNQuery query, float boost) { super(query); @@ -86,6 +88,7 @@ public KNNWeight(KNNQuery query, float boost) { this.filterWeight = null; this.exactSearcher = DEFAULT_EXACT_SEARCHER; this.quantizationService = QuantizationService.getInstance(); + this.knnExplanation = new KnnExplanation(); } public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { @@ -96,6 +99,7 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { this.filterWeight = filterWeight; this.exactSearcher = DEFAULT_EXACT_SEARCHER; this.quantizationService = QuantizationService.getInstance(); + this.knnExplanation = new KnnExplanation(); } public static void initialize(ModelDao modelDao) { @@ -108,9 +112,148 @@ static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) { KNNWeight.DEFAULT_EXACT_SEARCHER = exactSearcher; } + @VisibleForTesting + KnnExplanation getKnnExplanation() 
{ + return knnExplanation; + } + @Override + // This method is called in case of Radial-Search public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.match(1.0f, "No Explanation"); + return explain(context, doc, 0); + } + + // This method is called for ANN/Exact/Disk-based/Efficient-filtering search + public Explanation explain(LeafReaderContext context, int doc, float score) { + knnQuery.setExplain(true); + try { + final KNNScorer knnScorer = getOrCreateKnnScorer(context); + if (score == 0) { + score = getKnnScore(knnScorer, doc); + } + } catch (IOException e) { + throw new RuntimeException(String.format("Error while explaining KNN score for doc [%d], score [%f]", doc, score), e); + } + final String highLevelExplanation = getHighLevelExplanation(); + final StringBuilder leafLevelExplanation = getLeafLevelExplanation(context); + + final SegmentReader reader = Lucene.segmentReader(context.reader()); + final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField()); + if (fieldInfo == null) { + return Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString())); + } + final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo); + leafLevelExplanation.append(", spaceType = ").append(spaceType.getValue()); + + final Float rawScore = knnExplanation.getRawScore(doc); + Explanation rawScoreDetail = null; + if (rawScore != null && knnQuery.getRescoreContext() == null) { + leafLevelExplanation.append(" where score is computed as ") + .append(spaceType.explainScoreTranslation(rawScore)) + .append(" from:"); + rawScoreDetail = Explanation.match( + rawScore, + "rawScore, returned from " + FieldInfoExtractor.extractKNNEngine(fieldInfo) + " library" + ); + } + + return rawScoreDetail != null + ? 
Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString(), rawScoreDetail)) + : Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString())); + } + + private StringBuilder getLeafLevelExplanation(LeafReaderContext context) { + int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); + int cardinality = knnExplanation.getCardinality(); + final StringBuilder sb = new StringBuilder("the type of knn search executed at leaf was "); + if (filterWeight != null) { + if (isFilterIdCountLessThanK(cardinality)) { + sb.append(KNNConstants.EXACT_SEARCH) + .append(" since filteredIds = ") + .append(cardinality) + .append(" is less than or equal to K = ") + .append(knnQuery.getK()); + } else if (isExactSearchThresholdSettingSet(filterThresholdValue) && (filterThresholdValue >= cardinality)) { + sb.append(KNNConstants.EXACT_SEARCH) + .append(" since filtered threshold value = ") + .append(filterThresholdValue) + .append(" is greater than or equal to cardinality = ") + .append(cardinality); + } else if (!isExactSearchThresholdSettingSet(filterThresholdValue) && isMDCGreaterThanFilterIdCnt(cardinality)) { + sb.append(KNNConstants.EXACT_SEARCH) + .append(" since max distance computation = ") + .append(KNNConstants.MAX_DISTANCE_COMPUTATIONS) + .append(" is greater than or equal to cardinality = ") + .append(cardinality); + } + } + final Integer annResult = knnExplanation.getAnnResult(context.id()); + if (annResult != null && annResult == 0 && isMissingNativeEngineFiles(context)) { + sb.append(KNNConstants.EXACT_SEARCH).append(" since no native engine files are available"); + } + if (annResult != null && isFilteredExactSearchRequireAfterANNSearch(cardinality, annResult)) { + sb.append(KNNConstants.EXACT_SEARCH) + .append(" since the number of documents returned are less than K = ") + .append(knnQuery.getK()) + .append(" and there are more than K filtered 
Ids = ") + .append(cardinality); + } + if (annResult != null && annResult > 0 && !isFilteredExactSearchRequireAfterANNSearch(cardinality, annResult)) { + sb.append(KNNConstants.ANN_SEARCH); + } + sb.append(" with vectorDataType = ").append(knnQuery.getVectorDataType()); + return sb; + } + + private String getHighLevelExplanation() { + final StringBuilder sb = new StringBuilder("the type of knn search executed was "); + if (knnQuery.getRescoreContext() != null) { + sb.append(buildDiskBasedSearchExplanation()); + } else if (knnQuery.getRadius() != null) { + sb.append(KNNConstants.RADIAL_SEARCH).append(" with the radius of ").append(knnQuery.getRadius()); + } else { + sb.append(KNNConstants.ANN_SEARCH); + } + return sb.toString(); + } + + private String buildDiskBasedSearchExplanation() { + StringBuilder sb = new StringBuilder(KNNConstants.DISK_BASED_SEARCH); + boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()); + if (!knnQuery.getRescoreContext().isRescoreEnabled()) { + isShardLevelRescoringDisabled = true; + } + int dimension = knnQuery.getQueryVector().length; + int firstPassK = knnQuery.getRescoreContext().getFirstPassK(knnQuery.getK(), isShardLevelRescoringDisabled, dimension); + sb.append(" and the first pass k was ") + .append(firstPassK) + .append(" with vector dimension of ") + .append(dimension) + .append(", over sampling factor of ") + .append(knnQuery.getRescoreContext().getOversampleFactor()); + if (isShardLevelRescoringDisabled) { + sb.append(", shard level rescoring disabled"); + } else { + sb.append(", shard level rescoring enabled"); + } + return sb.toString(); + } + + private KNNScorer getOrCreateKnnScorer(LeafReaderContext context) throws IOException { + // First try to get the cached scorer + KNNScorer scorer = knnExplanation.getKnnScorer(context); + + // If no cached scorer exists, create and cache a new one + if (scorer == null) { + scorer = (KNNScorer) scorer(context); + 
knnExplanation.addKnnScorer(context, scorer); + } + + return scorer; + } + + private float getKnnScore(KNNScorer knnScorer, int doc) throws IOException { + return (knnScorer.iterator().advance(doc) == doc) ? knnScorer.score() : 0; } @Override @@ -161,6 +304,9 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep if (filterWeight != null && cardinality == 0) { return PerLeafResult.EMPTY_RESULT; } + if (knnQuery.isExplain()) { + knnExplanation.setCardinality(cardinality); + } /* * The idea for this optimization is to get K results, we need to at least look at K vectors in the HNSW graph * . Hence, if filtered results are less than K and filter query is present we should shift to exact search. @@ -180,7 +326,9 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep StopWatch annStopWatch = startStopWatch(); final Map docIdsToScoreMap = doANNSearch(reader, context, annFilter, cardinality, k); stopStopWatchAndLog(annStopWatch, "ANN search", segmentName); - + if (knnQuery.isExplain()) { + knnExplanation.addLeafResult(context.id(), docIdsToScoreMap.size()); + } // See whether we have to perform exact search based on approx search results // This is required if there are no native engine files or if approximate search returned // results less than K, though we have more than k filtered docs @@ -417,6 +565,15 @@ private Map doANNSearch( log.debug("[KNN] Query yielded 0 results"); return Collections.emptyMap(); } + if (knnQuery.isExplain()) { + Arrays.stream(results).forEach(result -> { + if (KNNEngine.FAISS.getName().equals(knnEngine.getName()) && SpaceType.INNER_PRODUCT.equals(spaceType)) { + knnExplanation.addRawScore(result.getId(), -1 * result.getScore()); + } else { + knnExplanation.addRawScore(result.getId(), result.getScore()); + } + }); + } if (quantizedVector != null) { return Arrays.stream(results) @@ -463,12 +620,13 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { ); int 
filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); // Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic - if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) { - return true; - } + if (isFilterIdCountLessThanK(filterIdsCount)) return true; // See user has defined Exact Search filtered threshold. if yes, then use that setting. if (isExactSearchThresholdSettingSet(filterThresholdValue)) { - return filterThresholdValue >= filterIdsCount; + if (filterThresholdValue >= filterIdsCount) { + return true; + } + return false; } // if no setting is set, then use the default max distance computation value to see if we can do exact search. @@ -476,7 +634,7 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { * TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index * is cheaper than computation cost for non binary vector */ - return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * getQueryVectorLength(); + return isMDCGreaterThanFilterIdCnt(filterIdsCount); } /** @@ -495,6 +653,16 @@ private int getQueryVectorLength() { ); } + private boolean isMDCGreaterThanFilterIdCnt(int filterIdsCount) { + return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT + ? knnQuery.getQueryVector().length + : knnQuery.getByteQueryVector().length); + } + + private boolean isFilterIdCountLessThanK(int filterIdsCount) { + return knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK(); + } + /** * This function validates if {@link KNNSettings#ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD} is set or not. This * is done by validating if the setting value is equal to the default value. 
diff --git a/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationUtil.java b/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationUtil.java index 46db8bb6b4..a381409898 100644 --- a/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationUtil.java +++ b/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationUtil.java @@ -9,6 +9,8 @@ import org.apache.lucene.index.LeafReader; import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector; import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.profiler.SegmentProfileKNNCollector; +import org.opensearch.knn.profiler.SegmentProfilerState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -57,4 +59,23 @@ static QuantizationState getQuantizationState(final LeafReader leafReader, Strin } return tempCollector.getQuantizationState(); } + + /** + * A utility function to get {@link SegmentProfilerState} for a given segment and field. + * This needs to be public as we are accessing this on a transport action + * TODO: move this out of this Util class and into another one. 
+ * @param leafReader {@link LeafReader} + * @param fieldName {@link String} + * @return {@link SegmentProfilerState} + * @throws IOException exception during reading the {@link SegmentProfilerState} + */ + public static SegmentProfilerState getSegmentProfileState(final LeafReader leafReader, String fieldName) throws IOException { + final SegmentProfileKNNCollector tempCollector = new SegmentProfileKNNCollector(); + leafReader.searchNearestVectors(fieldName, new float[0], tempCollector, null); + if (tempCollector.getSegmentProfilerState() == null) { + throw new IllegalStateException(String.format(Locale.ROOT, "No segment state found for field %s", fieldName)); + } + return tempCollector.getSegmentProfilerState(); + } + } diff --git a/src/main/java/org/opensearch/knn/index/query/SegmentProfilerUtil.java b/src/main/java/org/opensearch/knn/index/query/SegmentProfilerUtil.java new file mode 100644 index 0000000000..0aaf925331 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/SegmentProfilerUtil.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import lombok.experimental.UtilityClass; +import org.apache.lucene.index.LeafReader; +import org.opensearch.knn.profiler.SegmentProfileKNNCollector; +import org.opensearch.knn.profiler.SegmentProfilerState; + +import java.io.IOException; +import java.util.Locale; + +/** + * Utility class to get segment profiler state for a given field + */ +@UtilityClass +public class SegmentProfilerUtil { + + /** + * Gets the segment profile state for a given field + * @param leafReader The leaf reader to query + * @param fieldName The field name to profile + * @return The segment profiler state + * @throws IOException If there's an error reading the segment + */ + public static SegmentProfilerState getSegmentProfileState(final LeafReader leafReader, String fieldName) throws IOException { + final SegmentProfileKNNCollector 
tempCollector = new SegmentProfileKNNCollector(); + leafReader.searchNearestVectors(fieldName, new float[0], tempCollector, null); + if (tempCollector.getSegmentProfilerState() == null) { + throw new IllegalStateException(String.format(Locale.ROOT, "No segment state found for field %s", fieldName)); + } + return tempCollector.getSegmentProfilerState(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java b/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java index 0aee8dd8b7..88cd103d28 100644 --- a/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java @@ -16,6 +16,7 @@ import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; import org.opensearch.knn.index.query.KNNScorer; +import org.opensearch.knn.index.query.KNNWeight; import java.io.IOException; import java.util.Arrays; @@ -30,13 +31,15 @@ final class DocAndScoreQuery extends Query { private final float[] scores; private final int[] segmentStarts; private final Object contextIdentity; + private final KNNWeight knnWeight; - public DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { + public DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity, KNNWeight knnWeight) { this.k = k; this.docs = docs; this.scores = scores; this.segmentStarts = segmentStarts; this.contextIdentity = contextIdentity; + this.knnWeight = knnWeight; } @Override @@ -52,7 +55,19 @@ public Explanation explain(LeafReaderContext context, int doc) { if (found < 0) { return Explanation.noMatch("not in top " + k); } - return Explanation.match(scores[found] * boost, "within top " + k); + float score = 0; + try { + final Scorer scorer = scorer(context); + assert scorer != null; + int resDoc = scorer.iterator().advance(doc); + if (resDoc == doc) { + score = scorer.score(); 
+ } + } catch (IOException e) { + throw new RuntimeException(e); + } + + return knnWeight.explain(context, doc, score); } @Override diff --git a/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java b/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java index 5fc0fb077e..ce823c229f 100644 --- a/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java +++ b/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java @@ -18,6 +18,7 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; +import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.index.query.iterators.GroupedNestedDocIdSetIterator; import java.io.IOException; @@ -46,6 +47,10 @@ public class QueryUtils { * @return a query representing the given TopDocs */ public Query createDocAndScoreQuery(final IndexReader reader, final TopDocs topDocs) { + return createDocAndScoreQuery(reader, topDocs, null); + } + + public Query createDocAndScoreQuery(final IndexReader reader, final TopDocs topDocs, final KNNWeight knnWeight) { int len = topDocs.scoreDocs.length; Arrays.sort(topDocs.scoreDocs, Comparator.comparingInt(a -> a.doc)); int[] docs = new int[len]; @@ -55,7 +60,7 @@ public Query createDocAndScoreQuery(final IndexReader reader, final TopDocs topD scores[i] = topDocs.scoreDocs[i].score; } int[] segmentStarts = findSegmentStarts(reader, docs); - return new DocAndScoreQuery(len, docs, scores, segmentStarts, reader.getContext().id()); + return new DocAndScoreQuery(len, docs, scores, segmentStarts, reader.getContext().id(), knnWeight); } private int[] findSegmentStarts(final IndexReader reader, final int[] docs) { diff --git a/src/main/java/org/opensearch/knn/index/query/explain/KnnExplanation.java b/src/main/java/org/opensearch/knn/index/query/explain/KnnExplanation.java new file mode 100644 index 0000000000..7241f93a16 --- /dev/null +++ 
b/src/main/java/org/opensearch/knn/index/query/explain/KnnExplanation.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.explain; + +import lombok.Getter; +import lombok.Setter; +import org.opensearch.knn.index.query.KNNScorer; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * This class captures details around knn explain queries that is used + * by explain API to generate explanation for knn queries + */ +public class KnnExplanation { + + private final Map annResultPerLeaf; + + private final Map rawScores; + + private final Map knnScorerPerLeaf; + + @Setter + @Getter + private int cardinality; + + public KnnExplanation() { + this.annResultPerLeaf = new ConcurrentHashMap<>(); + this.rawScores = new ConcurrentHashMap<>(); + this.knnScorerPerLeaf = new ConcurrentHashMap<>(); + this.cardinality = 0; + } + + public void addLeafResult(Object leafId, int annResult) { + this.annResultPerLeaf.put(leafId, annResult); + } + + public void addRawScore(int docId, float rawScore) { + this.rawScores.put(docId, rawScore); + } + + public void addKnnScorer(Object leafId, KNNScorer knnScorer) { + this.knnScorerPerLeaf.put(leafId, knnScorer); + } + + public Integer getAnnResult(Object leafId) { + return this.annResultPerLeaf.get(leafId); + } + + public Float getRawScore(int docId) { + return this.rawScores.get(docId); + } + + public KNNScorer getKnnScorer(Object leafId) { + return this.knnScorerPerLeaf.get(leafId); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 1ffaa804db..ba16cd5011 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -98,7 +98,7 @@ public 
Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo if (topK.scoreDocs.length == 0) { return new MatchNoDocsQuery().createWeight(indexSearcher, scoreMode, boost); } - return queryUtils.createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost); + return queryUtils.createDocAndScoreQuery(reader, topK, knnWeight).createWeight(indexSearcher, scoreMode, boost); } /** diff --git a/src/main/java/org/opensearch/knn/index/remote/RemoteIndexPoller.java b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexPoller.java index 2de35915d8..4f8fb35013 100644 --- a/src/main/java/org/opensearch/knn/index/remote/RemoteIndexPoller.java +++ b/src/main/java/org/opensearch/knn/index/remote/RemoteIndexPoller.java @@ -17,6 +17,8 @@ import static org.opensearch.knn.index.KNNSettings.getRemoteBuildClientPollInterval; import static org.opensearch.knn.index.KNNSettings.getRemoteBuildClientTimeout; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.STATUS_REQUEST_FAILURE_COUNT; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.STATUS_REQUEST_SUCCESS_COUNT; /** * Implementation of a {@link RemoteIndexWaiter} that awaits the vector build by polling. 
@@ -63,7 +65,14 @@ public RemoteBuildStatusResponse awaitVectorBuild(RemoteBuildStatusRequest remot sleepWithJitter(pollInterval * INITIAL_DELAY_FACTOR); while (System.nanoTime() - startTime < timeout) { - RemoteBuildStatusResponse remoteBuildStatusResponse = client.getBuildStatus(remoteBuildStatusRequest); + RemoteBuildStatusResponse remoteBuildStatusResponse; + try { + remoteBuildStatusResponse = client.getBuildStatus(remoteBuildStatusRequest); + } catch (IOException e) { + STATUS_REQUEST_FAILURE_COUNT.increment(); + throw e; + } + STATUS_REQUEST_SUCCESS_COUNT.increment(); String taskStatus = remoteBuildStatusResponse.getTaskStatus(); if (StringUtils.isBlank(taskStatus)) { throw new IOException(String.format("Invalid response format, missing %s", TASK_STATUS)); diff --git a/src/main/java/org/opensearch/knn/jni/PlatformUtils.java b/src/main/java/org/opensearch/knn/jni/PlatformUtils.java index a67a88487e..0ec544a1fd 100644 --- a/src/main/java/org/opensearch/knn/jni/PlatformUtils.java +++ b/src/main/java/org/opensearch/knn/jni/PlatformUtils.java @@ -15,6 +15,7 @@ import org.apache.commons.lang.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; + import oshi.util.platform.mac.SysctlUtil; import java.nio.file.Files; @@ -27,9 +28,18 @@ import java.util.stream.Stream; public class PlatformUtils { - private static final Logger logger = LogManager.getLogger(PlatformUtils.class); + private static volatile Boolean isAVX2Supported; + private static volatile Boolean isAVX512Supported; + private static volatile Boolean isAVX512SPRSupported; + + static void reset() { + isAVX2Supported = null; + isAVX512Supported = null; + isAVX512SPRSupported = null; + } + /** * Verify if the underlying system supports AVX2 SIMD Optimization or not * 1. If the architecture is not x86 return false. 
@@ -41,22 +51,26 @@ public class PlatformUtils { */ public static boolean isAVX2SupportedBySystem() { if (!Platform.isIntel() || Platform.isWindows()) { - return false; + isAVX2Supported = false; } - if (Platform.isMac()) { + if (isAVX2Supported != null) { + return isAVX2Supported; + } + if (Platform.isMac()) { // sysctl or system control retrieves system info and allows processes with appropriate privileges // to set system info. This system info contains the machine dependent cpu features that are supported by it. // On MacOS, if the underlying processor supports AVX2 instruction set, it will be listed under the "leaf7" // subset of instructions ("sysctl -a | grep machdep.cpu.leaf7_features"). // https://developer.apple.com/library/archive/documentation/System/Conceptual/ManPages_iPhoneOS/man3/sysctl.3.html try { - return AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + isAVX2Supported = AccessController.doPrivileged((PrivilegedExceptionAction) () -> { String flags = SysctlUtil.sysctl("machdep.cpu.leaf7_features", "empty"); return (flags.toLowerCase(Locale.ROOT)).contains("avx2"); }); } catch (Exception e) { + isAVX2Supported = false; logger.error("[KNN] Error fetching cpu flags info. [{}]", e.getMessage(), e); } @@ -70,25 +84,32 @@ public static boolean isAVX2SupportedBySystem() { // https://ark.intel.com/content/www/us/en/ark/products/199285/intel-pentium-gold-g6600-processor-4m-cache-4-20-ghz.html String fileName = "/proc/cpuinfo"; try { - return AccessController.doPrivileged( + isAVX2Supported = AccessController.doPrivileged( (PrivilegedExceptionAction) () -> (Boolean) Files.lines(Paths.get(fileName)) .filter(s -> s.startsWith("flags")) .anyMatch(s -> StringUtils.containsIgnoreCase(s, "avx2")) ); } catch (Exception e) { + isAVX2Supported = false; logger.error("[KNN] Error reading file [{}]. 
[{}]", fileName, e.getMessage(), e); } } - return false; + return isAVX2Supported; } public static boolean isAVX512SupportedBySystem() { - return areAVX512FlagsAvailable(new String[] { "avx512f", "avx512cd", "avx512vl", "avx512dq", "avx512bw" }); + if (isAVX512Supported == null) { + isAVX512Supported = areAVX512FlagsAvailable(new String[] { "avx512f", "avx512cd", "avx512vl", "avx512dq", "avx512bw" }); + } + return isAVX512Supported; } public static boolean isAVX512SPRSupportedBySystem() { - return areAVX512FlagsAvailable(new String[] { "avx512_fp16", "avx512_bf16", "avx512_vpopcntdq" }); + if (isAVX512SPRSupported == null) { + isAVX512SPRSupported = areAVX512FlagsAvailable(new String[] { "avx512_fp16", "avx512_bf16", "avx512_vpopcntdq" }); + } + return isAVX512SPRSupported; } private static boolean areAVX512FlagsAvailable(String[] avx512) { diff --git a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissMemoryOptimizedSearcher.java b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissMemoryOptimizedSearcher.java index 2c0d2e9c38..8ed4d12f87 100644 --- a/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissMemoryOptimizedSearcher.java +++ b/src/main/java/org/opensearch/knn/memoryoptsearch/faiss/FaissMemoryOptimizedSearcher.java @@ -26,14 +26,13 @@ public class FaissMemoryOptimizedSearcher implements VectorSearcher { private static final FlatVectorsScorer VECTOR_SCORER = FlatVectorScorerUtil.getLucene99FlatVectorsScorer(); private final IndexInput indexInput; - private FaissIndex faissIndex; - private FaissHnswGraph faissHnswGraph; + private final FaissIndex faissIndex; + private final FaissHNSW hnsw; public FaissMemoryOptimizedSearcher(IndexInput indexInput) throws IOException { this.indexInput = indexInput; this.faissIndex = FaissIndex.load(indexInput); - final FaissHNSW hnsw = extractFaissHnsw(faissIndex); - this.faissHnswGraph = new FaissHnswGraph(hnsw, indexInput); + this.hnsw = extractFaissHnsw(faissIndex); } private static FaissHNSW 
extractFaissHnsw(final FaissIndex faissIndex) { @@ -105,7 +104,7 @@ private void search( if (knnCollector.k() < scorer.maxOrd()) { // Do ANN search with Lucene's HNSW graph searcher. - HnswGraphSearcher.search(scorer, collector, faissHnswGraph, acceptedOrds); + HnswGraphSearcher.search(scorer, collector, new FaissHnswGraph(hnsw, indexInput), acceptedOrds); } else { // If k is larger than the number of vectors, we can just iterate over all vectors // and collect them. diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index e673c5d728..a549a1fe7d 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -30,6 +30,7 @@ import org.opensearch.index.codec.CodecServiceFactory; import org.opensearch.index.engine.EngineFactory; import org.opensearch.index.mapper.Mapper; +import org.opensearch.index.shard.IndexSettingProvider; import org.opensearch.indices.SystemIndexDescriptor; import org.opensearch.knn.common.featureflags.KNNFeatureFlags; import org.opensearch.knn.index.KNNCircuitBreaker; @@ -47,6 +48,7 @@ import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelGraveyard; +import org.opensearch.knn.jni.PlatformUtils; import org.opensearch.knn.plugin.rest.RestClearCacheHandler; import org.opensearch.knn.plugin.rest.RestDeleteModelHandler; import org.opensearch.knn.plugin.rest.RestGetModelHandler; @@ -57,6 +59,9 @@ import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; import org.opensearch.knn.plugin.search.KNNConcurrentSearchRequestDecider; import org.opensearch.knn.plugin.stats.KNNStats; +//import org.opensearch.knn.plugin.transport.*; +import org.opensearch.knn.plugin.transport.KNNProfileTransportAction; +import org.opensearch.knn.plugin.transport.KNNProfileAction; import org.opensearch.knn.plugin.transport.ClearCacheAction; import 
org.opensearch.knn.plugin.transport.ClearCacheTransportAction; import org.opensearch.knn.plugin.transport.DeleteModelAction; @@ -82,6 +87,7 @@ import org.opensearch.knn.plugin.transport.UpdateModelGraveyardTransportAction; import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction; +import org.opensearch.knn.profiler.RestKNNProfileHandler; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache; import org.opensearch.knn.training.TrainingJobClusterStateListener; import org.opensearch.knn.training.TrainingJobRunner; @@ -117,6 +123,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.ForkJoinPool; import java.util.function.Supplier; import static java.util.Collections.singletonList; @@ -124,6 +131,7 @@ import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; import static org.opensearch.knn.common.KNNConstants.TRAIN_THREAD_POOL; import static org.opensearch.knn.index.KNNCircuitBreaker.KNN_CIRCUIT_BREAKER_TIER; +import static org.opensearch.knn.index.KNNSettings.KNN_DERIVED_SOURCE_ENABLED; /** * Entry point for the KNN plugin where we define mapper for knn_vector type @@ -174,6 +182,14 @@ public class KNNPlugin extends Plugin private ClusterService clusterService; private Supplier repositoriesServiceSupplier; + static { + ForkJoinPool.commonPool().execute(() -> { + PlatformUtils.isAVX2SupportedBySystem(); + PlatformUtils.isAVX512SupportedBySystem(); + PlatformUtils.isAVX512SPRSupportedBySystem(); + }); + } + @Override public Map getMappers() { return Collections.singletonMap( @@ -232,6 +248,20 @@ public List> getSettings() { return KNNSettings.state().getSettings(); } + @Override + public Collection getAdditionalIndexSettingProviders() { + // Default derived source feature to true for knn indices. 
+ return ImmutableList.of(new IndexSettingProvider() { + @Override + public Settings getAdditionalIndexSettings(String indexName, boolean isDataStreamIndex, Settings templateAndRequestSettings) { + if (templateAndRequestSettings.getAsBoolean(KNNSettings.KNN_INDEX, false)) { + return Settings.builder().put(KNN_DERIVED_SOURCE_ENABLED, true).build(); + } + return Settings.EMPTY; + } + }); + } + public List getRestHandlers( Settings settings, RestController restController, @@ -249,6 +279,12 @@ public List getRestHandlers( clusterService, indexNameExpressionResolver ); + RestKNNProfileHandler restKNNProfileHandler = new RestKNNProfileHandler( + settings, + restController, + clusterService, + indexNameExpressionResolver + ); RestGetModelHandler restGetModelHandler = new RestGetModelHandler(); RestDeleteModelHandler restDeleteModelHandler = new RestDeleteModelHandler(); RestTrainModelHandler restTrainModelHandler = new RestTrainModelHandler(); @@ -258,6 +294,7 @@ public List getRestHandlers( return ImmutableList.of( restKNNStatsHandler, restKNNWarmupHandler, + restKNNProfileHandler, restGetModelHandler, restDeleteModelHandler, restTrainModelHandler, @@ -274,6 +311,7 @@ public List getRestHandlers( return Arrays.asList( new ActionHandler<>(KNNStatsAction.INSTANCE, KNNStatsTransportAction.class), new ActionHandler<>(KNNWarmupAction.INSTANCE, KNNWarmupTransportAction.class), + new ActionHandler<>(KNNProfileAction.INSTANCE, KNNProfileTransportAction.class), new ActionHandler<>(UpdateModelMetadataAction.INSTANCE, UpdateModelMetadataTransportAction.class), new ActionHandler<>(TrainingJobRouteDecisionInfoAction.INSTANCE, TrainingJobRouteDecisionInfoTransportAction.class), new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class), diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java index d7a84817b2..1817ca73a1 100644 --- 
a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java @@ -114,9 +114,9 @@ public double execute(ScoreScript.ExplanationHolder explanationHolder) { * KNNVectors with float[] type. The query value passed in is expected to be float[]. The fieldType of the docs * being searched over are expected to be KNNVector type. */ - public static class KNNVectorType extends KNNScoreScript { + public static class KNNFloatVectorType extends KNNScoreScript { - public KNNVectorType( + public KNNFloatVectorType( Map params, float[] queryValue, String field, @@ -136,8 +136,45 @@ public KNNVectorType( * @return score of the vector to the query vector */ @Override + @SuppressWarnings("unchecked") public double execute(ScoreScript.ExplanationHolder explanationHolder) { - KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field); + KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field); + if (scriptDocValues.isEmpty()) { + return 0.0; + } + return this.scoringMethod.apply(this.queryValue, scriptDocValues.getValue()); + } + } + + /** + * KNNVectors with byte[] type. The query value passed in is expected to be byte[]. The fieldType of the docs + * being searched over are expected to be KNNVector type. + */ + public static class KNNByteVectorType extends KNNScoreScript { + + public KNNByteVectorType( + Map params, + byte[] queryValue, + String field, + BiFunction scoringMethod, + SearchLookup lookup, + LeafReaderContext leafContext, + IndexSearcher searcher + ) throws IOException { + super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher); + } + + /** + * This function called for each doc in the segment. 
We evaluate the score of the vector in the doc + * + * @param explanationHolder A helper to take in an explanation from a script and turn + * it into an {@link org.apache.lucene.search.Explanation} + * @return score of the vector to the query vector + */ + @Override + @SuppressWarnings("unchecked") + public double execute(ScoreScript.ExplanationHolder explanationHolder) { + KNNVectorScriptDocValues scriptDocValues = (KNNVectorScriptDocValues) getDoc().get(this.field); if (scriptDocValues.isEmpty()) { return 0.0; } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index b613efab21..b77b0f4751 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -31,6 +31,7 @@ import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isLongFieldType; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToBigInteger; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToFloatArray; +import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToByteArray; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToLong; public interface KNNScoringSpace { @@ -54,9 +55,9 @@ ScoreScript getScoreScript(Map params, String field, SearchLooku abstract class KNNFieldSpace implements KNNScoringSpace { public static final Set DATA_TYPES_DEFAULT = Set.of(VectorDataType.FLOAT, VectorDataType.BYTE); - private float[] processedQuery; + private Object processedQuery; @Getter - private BiFunction scoringMethod; + private BiFunction scoringMethod; public KNNFieldSpace(final Object query, final MappedFieldType fieldType, final String spaceName) { this(query, fieldType, spaceName, DATA_TYPES_DEFAULT); @@ -73,6 +74,7 @@ public KNNFieldSpace( this.scoringMethod = getScoringMethod(this.processedQuery, 
knnVectorFieldType.getKnnMappingConfig().getIndexCreatedVersion()); } + @SuppressWarnings("unchecked") public ScoreScript getScoreScript( Map params, String field, @@ -80,7 +82,31 @@ public ScoreScript getScoreScript( LeafReaderContext ctx, IndexSearcher searcher ) throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher); + if (processedQuery instanceof float[]) { + return new KNNScoreScript.KNNFloatVectorType( + params, + (float[]) this.processedQuery, + field, + (BiFunction) this.scoringMethod, + lookup, + ctx, + searcher + ); + } else if (processedQuery instanceof byte[]) { + return new KNNScoreScript.KNNByteVectorType( + params, + (byte[]) this.processedQuery, + field, + (BiFunction) this.scoringMethod, + lookup, + ctx, + searcher + ); + } else { + throw new IllegalStateException( + "Unexpected type for processedQuery. Expected float[] or byte[], but got: " + processedQuery.getClass().getName() + ); + } } private KNNVectorFieldType toKNNVectorFieldType( @@ -113,17 +139,27 @@ private KNNVectorFieldType toKNNVectorFieldType( return knnVectorFieldType; } - protected float[] getProcessedQuery(final Object query, final KNNVectorFieldType knnVectorFieldType) { - return parseToFloatArray( + protected Object getProcessedQuery(final Object query, final KNNVectorFieldType knnVectorFieldType) { + VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType() == null + ? 
VectorDataType.FLOAT + : knnVectorFieldType.getVectorDataType(); + if (vectorDataType == VectorDataType.FLOAT) { + return parseToFloatArray( + query, + KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType), + knnVectorFieldType.getVectorDataType() + ); + } + return parseToByteArray( query, KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType), knnVectorFieldType.getVectorDataType() ); } - protected abstract BiFunction getScoringMethod(final float[] processedQuery); + public abstract BiFunction getScoringMethod(final Object processedQuery); - protected BiFunction getScoringMethod(final float[] processedQuery, Version indexCreatedVersion) { + protected BiFunction getScoringMethod(final Object processedQuery, Version indexCreatedVersion) { return getScoringMethod(processedQuery); } @@ -135,8 +171,12 @@ public L2(final Object query, final MappedFieldType fieldType) { } @Override - public BiFunction getScoringMethod(final float[] processedQuery) { - return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); + public BiFunction getScoringMethod(final Object processedQuery) { + if (processedQuery instanceof float[]) { + return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); + } else { + return (byte[] q, byte[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); + } } } @@ -146,30 +186,35 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) { } @Override - protected BiFunction getScoringMethod(float[] processedQuery) { + public BiFunction getScoringMethod(Object processedQuery) { return getScoringMethod(processedQuery, Version.CURRENT); } @Override - protected BiFunction getScoringMethod(final float[] processedQuery, Version indexCreatedVersion) { - SpaceType.COSINESIMIL.validateVector(processedQuery); - float qVectorSquaredMagnitude = getVectorMagnitudeSquared(processedQuery); - if (indexCreatedVersion.onOrAfter(Version.V_2_19_0)) { - // To be consistent, we will be using same formula used by 
lucene as mentioned below - // https://github.com/apache/lucene/blob/0494c824e0ac8049b757582f60d085932a890800/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java#L73 - // for indices that are created on or after 2.19.0 - // - // OS Score = ( 2 − cosineSimil) / 2 - // However cosineSimil = 1 - cos θ, after applying this to above formula, - // OS Score = ( 2 − ( 1 − cos θ ) ) / 2 - // which simplifies to - // OS Score = ( 1 + cos θ ) / 2 - return (float[] q, float[] v) -> Math.max( - ((1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude)) / 2.0F), - 0 - ); + protected BiFunction getScoringMethod(final Object processedQuery, Version indexCreatedVersion) { + if (processedQuery instanceof float[]) { + SpaceType.COSINESIMIL.validateVector((float[]) processedQuery); + float qVectorSquaredMagnitude = getVectorMagnitudeSquared((float[]) processedQuery); + if (indexCreatedVersion.onOrAfter(Version.V_2_19_0)) { + // To be consistent, we will be using same formula used by lucene as mentioned below + // https://github.com/apache/lucene/blob/0494c824e0ac8049b757582f60d085932a890800/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java#L73 + // for indices that are created on or after 2.19.0 + // + // OS Score = ( 2 − cosineSimil) / 2 + // However cosineSimil = 1 - cos θ, after applying this to above formula, + // OS Score = ( 2 − ( 1 − cos θ ) ) / 2 + // which simplifies to + // OS Score = ( 1 + cos θ ) / 2 + return (float[] q, float[] v) -> Math.max( + ((1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude)) / 2.0F), + 0 + ); + } + return (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude); + } else { + SpaceType.COSINESIMIL.validateVector((byte[]) processedQuery); + return (byte[] q, byte[] v) -> 1 + KNNScoringUtil.cosinesimil(q, v); } - return (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude); } } @@ -179,8 
+224,12 @@ public L1(Object query, MappedFieldType fieldType) { } @Override - protected BiFunction getScoringMethod(final float[] processedQuery) { - return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); + public BiFunction getScoringMethod(final Object processedQuery) { + if (processedQuery instanceof float[]) { + return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); + } else { + return (byte[] q, byte[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); + } } } @@ -190,8 +239,12 @@ public LInf(Object query, MappedFieldType fieldType) { } @Override - protected BiFunction getScoringMethod(final float[] processedQuery) { - return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); + public BiFunction getScoringMethod(final Object processedQuery) { + if (processedQuery instanceof float[]) { + return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); + } else { + return (byte[] q, byte[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); + } } } @@ -201,8 +254,12 @@ public InnerProd(Object query, MappedFieldType fieldType) { } @Override - protected BiFunction getScoringMethod(final float[] processedQuery) { - return (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v)); + public BiFunction getScoringMethod(final Object processedQuery) { + if (processedQuery instanceof float[]) { + return (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v)); + } else { + return (byte[] q, byte[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v)); + } } } @@ -214,17 +271,8 @@ public Hamming(Object query, MappedFieldType fieldType) { } @Override - protected BiFunction getScoringMethod(final float[] processedQuery) { - // TODO we want to avoid converting back and forth between byte and float - return (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.calculateHammingBit(toByte(q), toByte(v))); - } - - private byte[] toByte(final float[] vector) { - 
byte[] bytes = new byte[vector.length]; - for (int i = 0; i < vector.length; i++) { - bytes[i] = (byte) vector[i]; - } - return bytes; + public BiFunction getScoringMethod(final Object processedQuery) { + return (byte[] q, byte[] v) -> 1 / (1 + KNNScoringUtil.calculateHammingBit(q, v)); } } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index 7a97fdb058..7699403b3b 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -134,6 +134,48 @@ public static float[] convertVectorToPrimitive(Object vector, VectorDataType vec return primitiveVector; } + /** + * Convert an Object to a byte array. + * + * @param object Object to be converted to a byte array + * @param expectedVectorLength int representing the expected vector length of this array. + * @return byte[] of the object + */ + public static byte[] parseToByteArray(Object object, int expectedVectorLength, VectorDataType vectorDataType) { + byte[] byteArray = convertVectorToByteArray(object, vectorDataType); + if (expectedVectorLength != byteArray.length) { + KNNCounter.SCRIPT_QUERY_ERRORS.increment(); + throw new IllegalStateException( + "Object's length=" + byteArray.length + " does not match the " + "expected length=" + expectedVectorLength + "." 
+ ); + } + return byteArray; + } + + /** + * Converts Object vector to byte[] + * + * Expects all numbers in the Object vector to be in the byte range of [-128 to 127] + * @param vector input vector + * @return Byte array representing the vector + */ + @SuppressWarnings("unchecked") + public static byte[] convertVectorToByteArray(Object vector, VectorDataType vectorDataType) { + byte[] byteVector = null; + if (vector != null) { + final List tmp = (List) vector; + byteVector = new byte[tmp.size()]; + for (int i = 0; i < byteVector.length; i++) { + float value = tmp.get(i).floatValue(); + if (VectorDataType.BYTE == vectorDataType || VectorDataType.BINARY == vectorDataType) { + validateByteVectorValue(value, vectorDataType); + } + byteVector[i] = tmp.get(i).byteValue(); + } + } + return byteVector; + } + /** * Calculates the magnitude of given vector * diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java index f61ae4349e..e14058e004 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java @@ -99,6 +99,18 @@ public static float l2Squared(float[] queryVector, float[] inputVector) { return VectorUtil.squareDistance(queryVector, inputVector); } + /** + * This method calculates L2 squared distance between byte query vector + * and byte input vector + * + * @param queryVector byte query vector + * @param inputVector byte input vector + * @return L2 score + */ + public static float l2Squared(byte[] queryVector, byte[] inputVector) { + return VectorUtil.squareDistance(queryVector, inputVector); + } + private static float[] toFloat(final List inputVector, final VectorDataType vectorDataType) { Objects.requireNonNull(inputVector); float[] value = new float[inputVector.size()]; @@ -144,6 +156,23 @@ public static float cosinesimil(float[] queryVector, float[] inputVector) { } } + /** + * 
This method calculates cosine similarity + * + * @param queryVector byte query vector + * @param inputVector byte input vector + * @return cosine score + */ + public static float cosinesimil(byte[] queryVector, byte[] inputVector) { + requireEqualDimension(queryVector, inputVector); + try { + return VectorUtil.cosine(queryVector, inputVector); + } catch (IllegalArgumentException | AssertionError e) { + logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end"); + return 0.0f; + } + } + /** * This method can be used script to avoid repeated calculation of normalization * for query vector for each filtered documents @@ -222,6 +251,24 @@ public static float l1Norm(float[] queryVector, float[] inputVector) { return distance; } + /** + * This method calculates L1 distance between byte query vector + * and byte input vector + * + * @param queryVector byte query vector + * @param inputVector byte input vector + * @return L1 score + */ + public static float l1Norm(byte[] queryVector, byte[] inputVector) { + requireEqualDimension(queryVector, inputVector); + float distance = 0; + for (int i = 0; i < inputVector.length; i++) { + float diff = queryVector[i] - inputVector[i]; + distance += Math.abs(diff); + } + return distance; + } + /** * This method calculates L-inf distance between query vector * and input vector @@ -240,6 +287,24 @@ public static float lInfNorm(float[] queryVector, float[] inputVector) { return distance; } + /** + * This method calculates L-inf distance between byte query vector + * and input vector + * + * @param queryVector byte query vector + * @param inputVector byte input vector + * @return L-inf score + */ + public static float lInfNorm(byte[] queryVector, byte[] inputVector) { + requireEqualDimension(queryVector, inputVector); + float distance = 0; + for (int i = 0; i < inputVector.length; i++) { + float diff = queryVector[i] - inputVector[i]; + distance = Math.max(Math.abs(diff), distance); + } + return distance; 
+ } + /** * This method calculates dot product distance between query vector * and input vector @@ -253,6 +318,19 @@ public static float innerProduct(float[] queryVector, float[] inputVector) { return VectorUtil.dotProduct(queryVector, inputVector); } + /** + * This method calculates dot product distance between byte query vector + * and byte input vector + * + * @param queryVector query vector + * @param inputVector input vector + * @return dot product score + */ + public static float innerProduct(byte[] queryVector, byte[] inputVector) { + requireEqualDimension(queryVector, inputVector); + return VectorUtil.dotProduct(queryVector, inputVector); + } + /** ********************************************************************************************* * Functions to be used in painless script which is defined in knn_allowlist.txt @@ -275,9 +353,13 @@ public static float innerProduct(float[] queryVector, float[] inputVector) { * @param docValues script doc values * @return L2 score */ - public static float l2Squared(List queryVector, KNNVectorScriptDocValues docValues) { - requireNonBinaryType("l2Squared", docValues.getVectorDataType()); - return l2Squared(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); + public static float l2Squared(List queryVector, KNNVectorScriptDocValues docValues) { + final VectorDataType vectorDataType = docValues.getVectorDataType(); + requireNonBinaryType("l2Squared", vectorDataType); + if (VectorDataType.FLOAT == vectorDataType) { + return l2Squared(toFloat(queryVector, docValues.getVectorDataType()), (float[]) docValues.getValue()); + } + return l2Squared(toByte(queryVector, docValues.getVectorDataType()), (byte[]) docValues.getValue()); } /** @@ -296,9 +378,13 @@ public static float l2Squared(List queryVector, KNNVectorScriptDocValues * @param docValues script doc values * @return L-inf score */ - public static float lInfNorm(List queryVector, KNNVectorScriptDocValues docValues) { - 
requireNonBinaryType("lInfNorm", docValues.getVectorDataType()); - return lInfNorm(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); + public static float lInfNorm(List queryVector, KNNVectorScriptDocValues docValues) { + final VectorDataType vectorDataType = docValues.getVectorDataType(); + requireNonBinaryType("lInfNorm", vectorDataType); + if (VectorDataType.FLOAT == vectorDataType) { + return lInfNorm(toFloat(queryVector, docValues.getVectorDataType()), (float[]) docValues.getValue()); + } + return lInfNorm(toByte(queryVector, docValues.getVectorDataType()), (byte[]) docValues.getValue()); } /** @@ -317,9 +403,13 @@ public static float lInfNorm(List queryVector, KNNVectorScriptDocValues * @param docValues script doc values * @return L1 score */ - public static float l1Norm(List queryVector, KNNVectorScriptDocValues docValues) { - requireNonBinaryType("l1Norm", docValues.getVectorDataType()); - return l1Norm(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); + public static float l1Norm(List queryVector, KNNVectorScriptDocValues docValues) { + final VectorDataType vectorDataType = docValues.getVectorDataType(); + requireNonBinaryType("l1Norm", vectorDataType); + if (VectorDataType.FLOAT == vectorDataType) { + return l1Norm(toFloat(queryVector, docValues.getVectorDataType()), (float[]) docValues.getValue()); + } + return l1Norm(toByte(queryVector, docValues.getVectorDataType()), (byte[]) docValues.getValue()); } /** @@ -338,9 +428,13 @@ public static float l1Norm(List queryVector, KNNVectorScriptDocValues do * @param docValues script doc values * @return inner product score */ - public static float innerProduct(List queryVector, KNNVectorScriptDocValues docValues) { - requireNonBinaryType("innerProduct", docValues.getVectorDataType()); - return innerProduct(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); + public static float innerProduct(List queryVector, KNNVectorScriptDocValues 
docValues) { + final VectorDataType vectorDataType = docValues.getVectorDataType(); + requireNonBinaryType("innerProduct", vectorDataType); + if (VectorDataType.FLOAT == vectorDataType) { + return innerProduct(toFloat(queryVector, docValues.getVectorDataType()), (float[]) docValues.getValue()); + } + return innerProduct(toByte(queryVector, docValues.getVectorDataType()), (byte[]) docValues.getValue()); } /** @@ -359,11 +453,18 @@ public static float innerProduct(List queryVector, KNNVectorScriptDocVal * @param docValues script doc values * @return cosine score */ - public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues) { - requireNonBinaryType("cosineSimilarity", docValues.getVectorDataType()); - float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); - SpaceType.COSINESIMIL.validateVector(inputVector); - return cosinesimil(inputVector, docValues.getValue()); + public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues) { + final VectorDataType vectorDataType = docValues.getVectorDataType(); + requireNonBinaryType("cosineSimilarity", vectorDataType); + if (VectorDataType.FLOAT == vectorDataType) { + float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); + SpaceType.COSINESIMIL.validateVector(inputVector); + return cosinesimil(inputVector, (float[]) docValues.getValue()); + } else { + byte[] inputVector = toByte(queryVector, docValues.getVectorDataType()); + SpaceType.COSINESIMIL.validateVector(inputVector); + return cosinesimil(inputVector, (byte[]) docValues.getValue()); + } } /** @@ -383,11 +484,21 @@ public static float cosineSimilarity(List queryVector, KNNVectorScriptDo * @param queryVectorMagnitude the magnitude of the query vector. 
* @return cosine score */ - public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) { - requireNonBinaryType("cosineSimilarity", docValues.getVectorDataType()); + public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) { + final VectorDataType vectorDataType = docValues.getVectorDataType(); + requireNonBinaryType("cosineSimilarity", vectorDataType); float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); SpaceType.COSINESIMIL.validateVector(inputVector); - return cosinesimilOptimized(inputVector, docValues.getValue(), queryVectorMagnitude.floatValue()); + if (VectorDataType.FLOAT == vectorDataType) { + return cosinesimilOptimized(inputVector, (float[]) docValues.getValue(), queryVectorMagnitude.floatValue()); + } else { + byte[] docVectorInByte = (byte[]) docValues.getValue(); + float[] docVectorInFloat = new float[docVectorInByte.length]; + for (int i = 0; i < docVectorInByte.length; i++) { + docVectorInFloat[i] = docVectorInByte[i]; + } + return cosinesimilOptimized(inputVector, docVectorInFloat, queryVectorMagnitude.floatValue()); + } } /** @@ -406,17 +517,9 @@ public static float cosineSimilarity(List queryVector, KNNVectorScriptDo * @param docValues script doc values * @return hamming score */ - public static float hamming(List queryVector, KNNVectorScriptDocValues docValues) { + public static float hamming(List queryVector, KNNVectorScriptDocValues docValues) { requireBinaryType("hamming", docValues.getVectorDataType()); byte[] queryVectorInByte = toByte(queryVector, docValues.getVectorDataType()); - - // TODO Optimization need be done for doc value to return byte[] instead of float[] - float[] docVectorInFloat = docValues.getValue(); - byte[] docVectorInByte = new byte[docVectorInFloat.length]; - for (int i = 0; i < docVectorInByte.length; i++) { - docVectorInByte[i] = (byte) docVectorInFloat[i]; - } - - return 
calculateHammingBit(queryVectorInByte, docVectorInByte); + return calculateHammingBit(queryVectorInByte, (byte[]) docValues.getValue()); } } diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNRemoteIndexBuildValue.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNRemoteIndexBuildValue.java new file mode 100644 index 0000000000..2ca4af399b --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNRemoteIndexBuildValue.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.stats; + +import lombok.Getter; + +import java.util.concurrent.atomic.LongAdder; + +public enum KNNRemoteIndexBuildValue { + + // Repository Accumulating Stats + WRITE_SUCCESS_COUNT("write_success_count"), + WRITE_FAILURE_COUNT("write_failure_count"), + WRITE_TIME("successful_write_time_in_millis"), + READ_SUCCESS_COUNT("read_success_count"), + READ_FAILURE_COUNT("read_failure_count"), + READ_TIME("successful_read_time_in_millis"), + + // Remote Index Build Stats + REMOTE_INDEX_BUILD_CURRENT_MERGE_OPERATIONS("remote_index_build_current_merge_operations"), + REMOTE_INDEX_BUILD_CURRENT_FLUSH_OPERATIONS("remote_index_build_current_flush_operations"), + REMOTE_INDEX_BUILD_CURRENT_MERGE_SIZE("remote_index_build_current_merge_size"), + REMOTE_INDEX_BUILD_CURRENT_FLUSH_SIZE("remote_index_build_current_flush_size"), + REMOTE_INDEX_BUILD_MERGE_TIME("remote_index_build_merge_time_in_millis"), + REMOTE_INDEX_BUILD_FLUSH_TIME("remote_index_build_flush_time_in_millis"), + + // Client Stats + BUILD_REQUEST_SUCCESS_COUNT("build_request_success_count"), + BUILD_REQUEST_FAILURE_COUNT("build_request_failure_count"), + STATUS_REQUEST_SUCCESS_COUNT("status_request_success_count"), + STATUS_REQUEST_FAILURE_COUNT("status_request_failure_count"), + INDEX_BUILD_SUCCESS_COUNT("index_build_success_count"), + INDEX_BUILD_FAILURE_COUNT("index_build_failure_count"), + WAITING_TIME("waiting_time_in_ms"); + + 
@Getter + private final String name; + private final LongAdder value; + + /** + * Constructor + * + * @param name name of the value + */ + KNNRemoteIndexBuildValue(String name) { + this.name = name; + this.value = new LongAdder(); + } + + /** + * Get the value + * @return value + */ + public Long getValue() { + return value.longValue(); + } + + /** + * Increment the value + */ + public void increment() { + value.increment(); + } + + /** + * Decrement the value + */ + public void decrement() { + value.decrement(); + } + + /** + * Increment the value by a specified amount + * + * @param delta The amount to increment + */ + public void incrementBy(long delta) { + value.add(delta); + } + + /** + * Decrement the value by a specified amount + * + * @param delta The amount to decrement + */ + public void decrementBy(long delta) { + value.add(delta * -1); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java index bcd419ea68..9fc03f55d2 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java @@ -8,8 +8,8 @@ import com.google.common.cache.CacheStats; import com.google.common.collect.ImmutableMap; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.plugin.stats.suppliers.EventOccurredWithinThresholdSupplier; @@ -86,6 +86,7 @@ private Map> buildStatsMap() { addScriptStats(builder); addModelStats(builder); addGraphStats(builder); + addRemoteIndexBuildStats(builder); return builder.build(); } @@ -218,4 +219,77 @@ private Map> createGraphStatsMap() { graphStatsMap.put(StatNames.REFRESH.getName(), refreshMap); return 
graphStatsMap; } + + private void addRemoteIndexBuildStats(ImmutableMap.Builder> builder) { + builder.put(StatNames.REMOTE_VECTOR_INDEX_BUILD_STATS.getName(), new KNNStat<>(false, this::createRemoteIndexStatsMap)); + } + + private Map> createRemoteIndexStatsMap() { + Map clientStatsMap = new HashMap<>(); + clientStatsMap.put( + KNNRemoteIndexBuildValue.BUILD_REQUEST_SUCCESS_COUNT.getName(), + KNNRemoteIndexBuildValue.BUILD_REQUEST_SUCCESS_COUNT.getValue() + ); + clientStatsMap.put( + KNNRemoteIndexBuildValue.BUILD_REQUEST_FAILURE_COUNT.getName(), + KNNRemoteIndexBuildValue.BUILD_REQUEST_FAILURE_COUNT.getValue() + ); + clientStatsMap.put( + KNNRemoteIndexBuildValue.STATUS_REQUEST_SUCCESS_COUNT.getName(), + KNNRemoteIndexBuildValue.STATUS_REQUEST_SUCCESS_COUNT.getValue() + ); + clientStatsMap.put( + KNNRemoteIndexBuildValue.STATUS_REQUEST_FAILURE_COUNT.getName(), + KNNRemoteIndexBuildValue.STATUS_REQUEST_FAILURE_COUNT.getValue() + ); + clientStatsMap.put( + KNNRemoteIndexBuildValue.INDEX_BUILD_SUCCESS_COUNT.getName(), + KNNRemoteIndexBuildValue.INDEX_BUILD_SUCCESS_COUNT.getValue() + ); + clientStatsMap.put( + KNNRemoteIndexBuildValue.INDEX_BUILD_FAILURE_COUNT.getName(), + KNNRemoteIndexBuildValue.INDEX_BUILD_FAILURE_COUNT.getValue() + ); + clientStatsMap.put(KNNRemoteIndexBuildValue.WAITING_TIME.getName(), KNNRemoteIndexBuildValue.WAITING_TIME.getValue()); + + Map repoStatsMap = new HashMap<>(); + repoStatsMap.put(KNNRemoteIndexBuildValue.WRITE_SUCCESS_COUNT.getName(), KNNRemoteIndexBuildValue.WRITE_SUCCESS_COUNT.getValue()); + repoStatsMap.put(KNNRemoteIndexBuildValue.WRITE_FAILURE_COUNT.getName(), KNNRemoteIndexBuildValue.WRITE_FAILURE_COUNT.getValue()); + repoStatsMap.put(KNNRemoteIndexBuildValue.WRITE_TIME.getName(), KNNRemoteIndexBuildValue.WRITE_TIME.getValue()); + repoStatsMap.put(KNNRemoteIndexBuildValue.READ_SUCCESS_COUNT.getName(), KNNRemoteIndexBuildValue.READ_SUCCESS_COUNT.getValue()); + 
repoStatsMap.put(KNNRemoteIndexBuildValue.READ_FAILURE_COUNT.getName(), KNNRemoteIndexBuildValue.READ_FAILURE_COUNT.getValue()); + repoStatsMap.put(KNNRemoteIndexBuildValue.READ_TIME.getName(), KNNRemoteIndexBuildValue.READ_TIME.getValue()); + + Map buildStatsMap = new HashMap<>(); + buildStatsMap.put( + KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_CURRENT_FLUSH_OPERATIONS.getName(), + KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_CURRENT_FLUSH_OPERATIONS.getValue() + ); + buildStatsMap.put( + KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_CURRENT_MERGE_OPERATIONS.getName(), + KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_CURRENT_MERGE_OPERATIONS.getValue() + ); + buildStatsMap.put( + KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_CURRENT_FLUSH_SIZE.getName(), + KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_CURRENT_FLUSH_SIZE.getValue() + ); + buildStatsMap.put( + KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_CURRENT_MERGE_SIZE.getName(), + KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_CURRENT_MERGE_SIZE.getValue() + ); + buildStatsMap.put( + KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_FLUSH_TIME.getName(), + KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_FLUSH_TIME.getValue() + ); + buildStatsMap.put( + KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_MERGE_TIME.getName(), + KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_MERGE_TIME.getValue() + ); + + Map> remoteIndexBuildStatsMap = new HashMap<>(); + remoteIndexBuildStatsMap.put(StatNames.BUILD_STATS.getName(), buildStatsMap); + remoteIndexBuildStatsMap.put(StatNames.CLIENT_STATS.getName(), clientStatsMap); + remoteIndexBuildStatsMap.put(StatNames.REPOSITORY_STATS.getName(), repoStatsMap); + return remoteIndexBuildStatsMap; + } } diff --git a/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java b/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java index e7f4fd4a2d..1ed3385898 100644 --- a/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java +++ b/src/main/java/org/opensearch/knn/plugin/stats/StatNames.java @@ -45,6 
+45,10 @@ public enum StatNames { GRAPH_STATS("graph_stats"), REFRESH("refresh"), MERGE("merge"), + REMOTE_VECTOR_INDEX_BUILD_STATS("remote_vector_index_build_stats"), + CLIENT_STATS("client_stats"), + REPOSITORY_STATS("repository_stats"), + BUILD_STATS("build_stats"), MIN_SCORE_QUERY_REQUESTS(KNNCounter.MIN_SCORE_QUERY_REQUESTS.getName()), MIN_SCORE_QUERY_WITH_FILTER_REQUESTS(KNNCounter.MIN_SCORE_QUERY_WITH_FILTER_REQUESTS.getName()), MAX_DISTANCE_QUERY_REQUESTS(KNNCounter.MAX_DISTANCE_QUERY_REQUESTS.getName()), diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileAction.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileAction.java new file mode 100644 index 0000000000..6e5a9c836e --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileAction.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.ActionType; + +/** + * Action for profiling KNN vectors in an index + */ +public class KNNProfileAction extends ActionType { + public static final String NAME = "indices:knn/vector/profile"; + public static final KNNProfileAction INSTANCE = new KNNProfileAction(); + + private KNNProfileAction() { + super(NAME, KNNProfileResponse::new); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileRequest.java new file mode 100644 index 0000000000..c16891b5ec --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileRequest.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.support.broadcast.BroadcastRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import 
java.io.IOException; + +/** + * Request for KNN profile operation + */ +public class KNNProfileRequest extends BroadcastRequest { + private String fieldName; + + /** + * Constructor + */ + public KNNProfileRequest() { + super(); + } + + /** + * Constructor with indices + * + * @param indices Indices to profile + */ + public KNNProfileRequest(String... indices) { + super(indices); + } + + /** + * Constructor from StreamInput + * + * @param in StreamInput + * @throws IOException if there's an error reading from stream + */ + public KNNProfileRequest(StreamInput in) throws IOException { + super(in); + this.fieldName = in.readOptionalString(); + } + + /** + * Write to StreamOutput + * + * @param out StreamOutput + * @throws IOException if there's an error writing to stream + */ + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalString(fieldName); + } + + /** + * Get field name to profile + * + * @return field name + */ + public String getFieldName() { + return fieldName; + } + + /** + * Set field name to profile + * + * @param fieldName field name + */ + public void setFieldName(String fieldName) { + this.fieldName = fieldName; + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileResponse.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileResponse.java new file mode 100644 index 0000000000..17bef461f4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileResponse.java @@ -0,0 +1,172 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import lombok.extern.log4j.Log4j2; +import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues; +import org.opensearch.action.support.broadcast.BroadcastResponse; +import org.opensearch.core.action.support.DefaultShardOperationFailedException; +import org.opensearch.core.common.io.stream.StreamInput; +import 
org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.HashMap; + +/** + * Response for KNN profile request + */ +@Log4j2 +public class KNNProfileResponse extends BroadcastResponse implements ToXContentObject { + private final List shardResults; + + /** + * Constructor + */ + public KNNProfileResponse( + int totalShards, + int successfulShards, + int failedShards, + List shardResults, + List shardFailures + ) { + super(totalShards, successfulShards, failedShards, shardFailures); + this.shardResults = shardResults != null ? shardResults : List.of(); + } + + /** + * Constructor for serialization + */ + public KNNProfileResponse(StreamInput in) throws IOException { + super(in); + int size = in.readVInt(); + shardResults = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + shardResults.add(new KNNProfileShardResult(in)); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeVInt(shardResults.size()); + for (KNNProfileShardResult result : shardResults) { + result.writeTo(out); + } + } + + /** + * Get aggregated dimension statistics by index + */ + private Map>> getAggregatedStats() { + Map>> indexDimensions = new HashMap<>(); + + for (KNNProfileShardResult result : shardResults) { + String indexName = result.getShardId().getIndexName(); + List stats = result.getDimensionStats(); + + if (stats == null || stats.isEmpty()) { + continue; + } + + Map> dimensions = indexDimensions.computeIfAbsent(indexName, k -> new HashMap<>()); + + for (int i = 0; i < stats.size(); i++) { + StatisticalSummaryValues stat = stats.get(i); + if (stat == null) { + continue; + } + + int finalI = i; + Map dimension = dimensions.computeIfAbsent(i, k -> { + Map newDim = new 
HashMap<>(); + newDim.put("dimension", finalI); + newDim.put("count", 0L); + newDim.put("sum", 0.0); + return newDim; + }); + + long oldCount = (Long) dimension.get("count"); + double oldSum = (Double) dimension.get("sum"); + long newCount = oldCount + stat.getN(); + double newSum = oldSum + stat.getSum(); + + dimension.put("count", newCount); + dimension.put("sum", newSum); + dimension.put("mean", newCount > 0 ? newSum / newCount : 0); + + if (!dimension.containsKey("min") || stat.getMin() < (Double) dimension.get("min")) { + dimension.put("min", stat.getMin()); + } + if (!dimension.containsKey("max") || stat.getMax() > (Double) dimension.get("max")) { + dimension.put("max", stat.getMax()); + } + + if (newCount > 0) { + dimension.put("variance", stat.getVariance()); + dimension.put("std_deviation", Math.sqrt(stat.getVariance())); + } + } + } + + return indexDimensions; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("total_shards", getTotalShards()); + builder.field("successful_shards", getSuccessfulShards()); + builder.field("failed_shards", getFailedShards()); + + builder.startObject("profile_results"); + + try { + Map>> indexDimensions = getAggregatedStats(); + + for (Map.Entry>> indexEntry : indexDimensions.entrySet()) { + String indexName = indexEntry.getKey(); + Map> dimensions = indexEntry.getValue(); + + builder.startObject(indexName); + builder.startArray("dimensions"); + + List dimensionIndices = new ArrayList<>(dimensions.keySet()); + Collections.sort(dimensionIndices); + + for (Integer dimIndex : dimensionIndices) { + builder.map(dimensions.get(dimIndex)); + } + + builder.endArray(); + builder.endObject(); + } + + } catch (Exception e) { + log.error("Error generating profile results", e); + } + + builder.endObject(); + + if (getShardFailures() != null && getShardFailures().length > 0) { + builder.startArray("failures"); + for 
(DefaultShardOperationFailedException failure : getShardFailures()) { + builder.startObject(); + failure.toXContent(builder, params); + builder.endObject(); + } + builder.endArray(); + } + + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileShardResult.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileShardResult.java new file mode 100644 index 0000000000..e3fe9e2f2a --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileShardResult.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.index.shard.ShardId; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.opensearch.knn.plugin.transport.KNNWarmupTransportAction.logger; + +/** + * Shard-level result for KNN profiling + */ +public class KNNProfileShardResult implements Writeable { + private final ShardId shardId; + private final List dimensionStats; + + /** + * Constructor + * @param shardId the shard ID + * @param dimensionStats statistical summaries for each dimension + */ + public KNNProfileShardResult(ShardId shardId, List dimensionStats) { + this.shardId = shardId; + this.dimensionStats = dimensionStats; + logger.info("[KNN] Created KNNProfileShardResult for shard {} with {} stats", shardId, this.dimensionStats.size()); + } + + /** + * Constructor for serialization + * @param in stream input + * @throws IOException if there's an error reading from the stream + */ + public KNNProfileShardResult(StreamInput in) throws IOException { + this.shardId = new ShardId(in); + int size 
= in.readVInt(); + this.dimensionStats = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + double mean = in.readDouble(); + double variance = in.readDouble(); + long n = in.readVLong(); + double min = in.readDouble(); + double max = in.readDouble(); + double sum = in.readDouble(); + this.dimensionStats.add(new StatisticalSummaryValues(mean, variance, n, min, max, sum)); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + shardId.writeTo(out); + out.writeVInt(dimensionStats.size()); + for (StatisticalSummaryValues stats : dimensionStats) { + out.writeDouble(stats.getMean()); + out.writeDouble(stats.getVariance()); + out.writeVLong(stats.getN()); + out.writeDouble(stats.getMin()); + out.writeDouble(stats.getMax()); + out.writeDouble(stats.getSum()); + } + } + + /** + * Get the shard ID + * @return shard ID + */ + public ShardId getShardId() { + return shardId; + } + + /** + * Get the statistical summaries for each dimension + * @return list of statistical summaries + */ + public List getDimensionStats() { + return dimensionStats; + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileTransportAction.java new file mode 100644 index 0000000000..614c6ee419 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNProfileTransportAction.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.broadcast.node.TransportBroadcastByNodeAction; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlockException; +import 
org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.routing.ShardRouting; +import org.opensearch.cluster.routing.ShardsIterator; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.support.DefaultShardOperationFailedException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.indices.IndicesService; +import org.opensearch.knn.index.KNNIndexShard; +import org.opensearch.transport.TransportService; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.List; + +/** + * Transport action for profiling KNN vectors in an index + */ +public class KNNProfileTransportAction extends TransportBroadcastByNodeAction< + KNNProfileRequest, + KNNProfileResponse, + KNNProfileShardResult> { + + public static Logger logger = LogManager.getLogger(KNNProfileTransportAction.class); + private final IndicesService indicesService; + + @Inject + public KNNProfileTransportAction( + ClusterService clusterService, + TransportService transportService, + IndicesService indicesService, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver + ) { + super( + KNNProfileAction.NAME, + clusterService, + transportService, + actionFilters, + indexNameExpressionResolver, + KNNProfileRequest::new, + ThreadPool.Names.SEARCH + ); + this.indicesService = indicesService; + } + + @Override + protected KNNProfileShardResult readShardResult(StreamInput in) throws IOException { + return new KNNProfileShardResult(in); + } + + @Override + protected KNNProfileResponse newResponse( + KNNProfileRequest request, + int totalShards, + int successfulShards, + int failedShards, + List shardResults, + List shardFailures, + ClusterState clusterState + ) { + return new KNNProfileResponse(totalShards, successfulShards, failedShards, shardResults, 
shardFailures); + } + + @Override + protected KNNProfileRequest readRequestFrom(StreamInput in) throws IOException { + return new KNNProfileRequest(in); + } + + @Override + protected KNNProfileShardResult shardOperation(KNNProfileRequest request, ShardRouting shardRouting) throws IOException { + KNNIndexShard knnIndexShard = new KNNIndexShard( + indicesService.indexServiceSafe(shardRouting.shardId().getIndex()).getShard(shardRouting.shardId().id()) + ); + + List profileResults = knnIndexShard.profile(request.getFieldName()); + logger.info( + "[KNN] Profile completed for field: {} on shard: {} - stats count: {}", + request.getFieldName(), + shardRouting.shardId(), + profileResults != null ? profileResults.size() : 0 + ); + return new KNNProfileShardResult(shardRouting.shardId(), profileResults); + } + + @Override + protected ShardsIterator shards(ClusterState state, KNNProfileRequest request, String[] concreteIndices) { + return state.routingTable().allShards(concreteIndices); + } + + @Override + protected ClusterBlockException checkGlobalBlock(ClusterState state, KNNProfileRequest request) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_READ); + } + + @Override + protected ClusterBlockException checkRequestBlock(ClusterState state, KNNProfileRequest request, String[] concreteIndices) { + return state.blocks().indicesBlockedException(ClusterBlockLevel.METADATA_READ, concreteIndices); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsRequest.java index 2e245e5a33..dabc50fdc5 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNStatsRequest.java @@ -8,6 +8,7 @@ import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import 
org.opensearch.knn.common.featureflags.KNNFeatureFlags; import org.opensearch.knn.plugin.stats.StatNames; import java.io.IOException; @@ -32,7 +33,7 @@ public class KNNStatsRequest extends BaseNodesRequest { */ public KNNStatsRequest() { super((String[]) null); - validStats = StatNames.getNames(); + validStats = getValidStats(); statsToBeRetrieved = new HashSet<>(); } @@ -55,7 +56,7 @@ public KNNStatsRequest(StreamInput in) throws IOException { */ public KNNStatsRequest(String... nodeIds) { super(nodeIds); - validStats = StatNames.getNames(); + validStats = getValidStats(); statsToBeRetrieved = new HashSet<>(); } @@ -95,6 +96,17 @@ public Set getStatsToBeRetrieved() { return statsToBeRetrieved; } + /** + * Get all valid stats, possibly omitting stats associated with disabled features + */ + private Set getValidStats() { + Set stats = StatNames.getNames(); + if (!KNNFeatureFlags.isKNNRemoteVectorBuildEnabled()) { + stats.remove(StatNames.REMOTE_VECTOR_INDEX_BUILD_STATS.getName()); + } + return stats; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/KNNWarmupTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/KNNWarmupTransportAction.java index a738527ff5..70c73421e4 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/KNNWarmupTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/KNNWarmupTransportAction.java @@ -89,6 +89,7 @@ protected EmptyResult shardOperation(KNNWarmupRequest request, ShardRouting shar KNNIndexShard knnIndexShard = new KNNIndexShard( indicesService.indexServiceSafe(shardRouting.shardId().getIndex()).getShard(shardRouting.shardId().id()) ); + knnIndexShard.warmup(); return EmptyResult.INSTANCE; } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 
9906ab490b..0a3daa5a72 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -134,7 +134,7 @@ public TrainingModelRequest( .mode(mode) .build(); - KNNEngine knnEngine = EngineResolver.INSTANCE.resolveEngine(knnMethodConfigContext, knnMethodContext, true); + KNNEngine knnEngine = EngineResolver.INSTANCE.resolveEngine(knnMethodConfigContext, knnMethodContext, true, Version.CURRENT); ResolvedMethodContext resolvedMethodContext = knnEngine.resolveMethod(knnMethodContext, knnMethodConfigContext, true, spaceType); this.knnMethodContext = resolvedMethodContext.getKnnMethodContext(); this.compressionLevel = resolvedMethodContext.getCompressionLevel(); diff --git a/src/main/java/org/opensearch/knn/profiler/KNN990ProfileStateReader.java b/src/main/java/org/opensearch/knn/profiler/KNN990ProfileStateReader.java new file mode 100644 index 0000000000..57c54b4724 --- /dev/null +++ b/src/main/java/org/opensearch/knn/profiler/KNN990ProfileStateReader.java @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.profiler; + +import com.google.common.annotations.VisibleForTesting; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.opensearch.knn.common.KNNConstants; + +import java.io.IOException; + +/** + * Reader class for segment profiler states + */ +@Log4j2 +public final class KNN990ProfileStateReader { + + /** + * Reads a segment profiler state for a given field + * + * @param readConfig config for reading the profiler state + * @return SegmentProfilerState object + * @throws IOException if there's an error reading the state + */ + public static 
SegmentProfilerState read(SegmentProfileStateReadConfig readConfig) throws IOException { + SegmentReadState segmentReadState = readConfig.getSegmentReadState(); + String field = readConfig.getField(); + String stateFileName = getQuantizationStateFileName(segmentReadState); + int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber(); + + try (IndexInput input = segmentReadState.directory.openInput(stateFileName, IOContext.DEFAULT)) { + CodecUtil.retrieveChecksum(input); + int numFields = getNumFields(input); + + long position = -1; + int length = 0; + + // Read each field's metadata from the index section, break when correct field is found + for (int i = 0; i < numFields; i++) { + int tempFieldNumber = input.readInt(); + int tempLength = input.readInt(); + long tempPosition = input.readVLong(); + if (tempFieldNumber == fieldNumber) { + position = tempPosition; + length = tempLength; + break; + } + } + + if (position == -1 || length == 0) { + throw new IllegalArgumentException(String.format("Field %s not found", field)); + } + + byte[] stateBytes = readStateBytes(input, position, length); + return SegmentProfilerState.fromBytes(stateBytes); + } + } + + @VisibleForTesting + static int getNumFields(IndexInput input) throws IOException { + long footerStart = input.length() - CodecUtil.footerLength(); + long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES; + input.seek(markerAndIndexPosition); + long indexStartPosition = input.readLong(); + input.seek(indexStartPosition); + return input.readInt(); + } + + @VisibleForTesting + static byte[] readStateBytes(IndexInput input, long position, int length) throws IOException { + input.seek(position); + byte[] stateBytes = new byte[length]; + input.readBytes(stateBytes, 0, length); + return stateBytes; + } + + @VisibleForTesting + static String getQuantizationStateFileName(SegmentReadState state) { + return IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, 
KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX); + } +} diff --git a/src/main/java/org/opensearch/knn/profiler/RestKNNProfileHandler.java b/src/main/java/org/opensearch/knn/profiler/RestKNNProfileHandler.java new file mode 100644 index 0000000000..7c7e6ae786 --- /dev/null +++ b/src/main/java/org/opensearch/knn/profiler/RestKNNProfileHandler.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.profiler; + +import org.apache.commons.lang.StringUtils; +import org.opensearch.knn.common.exception.KNNInvalidIndicesException; +import org.opensearch.knn.plugin.KNNPlugin; +import org.opensearch.knn.plugin.transport.KNNProfileAction; +import org.opensearch.knn.plugin.transport.KNNProfileRequest; +import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.transport.client.node.NodeClient; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.index.Index; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestController; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; +import static org.opensearch.action.support.IndicesOptions.strictExpandOpen; + +/** + * RestHandler for k-NN index profile API. API provides the ability for a user to get statistical information + * about vector dimensions in specific indices. 
+ */ +public class RestKNNProfileHandler extends BaseRestHandler { + private static final Logger logger = LogManager.getLogger(RestKNNProfileHandler.class); + private static final String URL_PATH = "/profile/{index}"; + public static String NAME = "knn_profile_action"; + private IndexNameExpressionResolver indexNameExpressionResolver; + private ClusterService clusterService; + + public RestKNNProfileHandler( + Settings settings, + RestController controller, + ClusterService clusterService, + IndexNameExpressionResolver indexNameExpressionResolver + ) { + this.clusterService = clusterService; + this.indexNameExpressionResolver = indexNameExpressionResolver; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.GET, KNNPlugin.KNN_BASE_URI + URL_PATH)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + KNNProfileRequest knnProfileRequest = createKNNProfileRequest(request); + logger.info( + "[KNN] Profile started for the following indices: {} and field: {}", + String.join(",", knnProfileRequest.indices()), + knnProfileRequest.getFieldName() + ); + return channel -> client.execute(KNNProfileAction.INSTANCE, knnProfileRequest, new RestToXContentListener<>(channel)); + } + + private KNNProfileRequest createKNNProfileRequest(RestRequest request) throws IOException { + String[] indexNames = StringUtils.split(request.param("index"), ","); + Index[] indices = indexNameExpressionResolver.concreteIndices(clusterService.state(), strictExpandOpen(), indexNames); + List invalidIndexNames = new ArrayList<>(); + + Arrays.stream(indices).forEach(index -> { + if (!"true".equals(clusterService.state().metadata().getIndexSafe(index).getSettings().get(KNN_INDEX))) { + invalidIndexNames.add(index.getName()); + } + }); + + if (invalidIndexNames.size() != 0) { + throw new KNNInvalidIndicesException( + 
invalidIndexNames, + "Profile request rejected. One or more indices have 'index.knn' set to false." + ); + } + + KNNProfileRequest profileRequest = new KNNProfileRequest(indexNames); + + String fieldName = request.param("field_name", "my_vector_field"); + profileRequest.setFieldName(fieldName); + + return profileRequest; + } +} diff --git a/src/main/java/org/opensearch/knn/profiler/SegmentProfileKNNCollector.java b/src/main/java/org/opensearch/knn/profiler/SegmentProfileKNNCollector.java new file mode 100644 index 0000000000..6ed8022a5e --- /dev/null +++ b/src/main/java/org/opensearch/knn/profiler/SegmentProfileKNNCollector.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.profiler; + +import lombok.Getter; +import lombok.Setter; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.TopDocs; + +@Setter +@Getter +public class SegmentProfileKNNCollector implements KnnCollector { + + private SegmentProfilerState segmentProfilerState; + + private final String NATIVE_ENGINE_SEARCH_ERROR_MESSAGE = "Search functionality using codec is not supported with Native Engine Reader"; + + @Override + public boolean earlyTerminated() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public void incVisitedCount(int i) { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public long visitedCount() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public long visitLimit() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public int k() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public boolean collect(int i, float v) { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + 
public float minCompetitiveSimilarity() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } + + @Override + public TopDocs topDocs() { + throw new UnsupportedOperationException(NATIVE_ENGINE_SEARCH_ERROR_MESSAGE); + } +} diff --git a/src/main/java/org/opensearch/knn/profiler/SegmentProfileStateReadConfig.java b/src/main/java/org/opensearch/knn/profiler/SegmentProfileStateReadConfig.java new file mode 100644 index 0000000000..a13c7a1479 --- /dev/null +++ b/src/main/java/org/opensearch/knn/profiler/SegmentProfileStateReadConfig.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.profiler; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.apache.lucene.index.SegmentReadState; + +@Getter +@AllArgsConstructor +public class SegmentProfileStateReadConfig { + private SegmentReadState segmentReadState; + private String field; +} diff --git a/src/main/java/org/opensearch/knn/profiler/SegmentProfilerState.java b/src/main/java/org/opensearch/knn/profiler/SegmentProfilerState.java new file mode 100644 index 0000000000..6324bd8ecb --- /dev/null +++ b/src/main/java/org/opensearch/knn/profiler/SegmentProfilerState.java @@ -0,0 +1,171 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.profiler; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; +import org.apache.commons.math3.stat.descriptive.SummaryStatistics; +import org.opensearch.knn.index.codec.util.KNNCodecUtil; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + 
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * SegmentProfilerState is responsible for analyzing and profiling vector data within segments. + * This class calculates statistical measurements for each dimension of the vectors in a segment. + */ +@Log4j2 +@AllArgsConstructor +public class SegmentProfilerState implements Serializable { + + // Stores statistical summaries for each dimension of the vectors + @Getter + private final List statistics; + + @Getter + private final int dimension; + + /** + * Profiles vectors in a segment by analyzing their statistical values + * @param knnVectorValuesSupplier + * @return SegmentProfilerState + * @throws IOException + */ + public static SegmentProfilerState profileVectors(final Supplier> knnVectorValuesSupplier) throws IOException { + KNNVectorValues vectorValues = knnVectorValuesSupplier.get(); + + if (vectorValues == null) { + log.info("No vector values available"); + return new SegmentProfilerState(new ArrayList<>(), 0); + } + + // Initialize vector values + KNNCodecUtil.initializeVectorValues(vectorValues); + List statistics = new ArrayList<>(); + + // Return empty state if no documents are present + if (vectorValues.docId() == NO_MORE_DOCS) { + log.info("No vectors to profile"); + return new SegmentProfilerState(statistics, vectorValues.dimension()); + } + + int dimension = vectorValues.dimension(); + log.info("Starting vector profiling with dimension: {}", dimension); + + // Initialize statistics collectors for each dimension + for (int i = 0; i < dimension; i++) { + statistics.add(new SummaryStatistics()); + } + + // Process all vectors + int vectorCount = 0; + for (int doc = vectorValues.docId(); doc != NO_MORE_DOCS; doc = vectorValues.nextDoc()) { + vectorCount++; + processVectors(vectorValues.getVector(), statistics); + } + + log.info("Vector profiling completed - processed {} vectors", vectorCount); + + logDimensionStatistics(statistics, dimension); + + return new 
SegmentProfilerState(statistics, vectorValues.dimension()); + } + + /** + * Helper method to process a vector and update statistics + * @param vector + * @param statistics + */ + private static void processVectors(T vector, List statistics) { + if (vector instanceof float[]) { + processFloatVector((float[]) vector, statistics); + } else if (vector instanceof byte[]) { + processByteVector((byte[]) vector, statistics); + } else { + log.warn("Unsupported vector type: {}.", vector.getClass()); + } + } + + /** + * Processes a float vector by updating the statistical summaries for each dimension + * @param vector + * @param statistics + */ + private static void processFloatVector(float[] vector, List statistics) { + for (int j = 0; j < vector.length; j++) { + statistics.get(j).addValue(vector[j]); + } + } + + /** + * Processes a byte vector by updating the statistical summaries for each dimension + * @param vector + * @param statistics + */ + private static void processByteVector(byte[] vector, List statistics) { + for (int j = 0; j < vector.length; j++) { + statistics.get(j).addValue(vector[j] & 0xFF); + } + } + + /** + * Helper method to log statistics for each dimension + * @param statistics + * @param dimension + */ + private static void logDimensionStatistics(final List statistics, final int dimension) { + for (int i = 0; i < dimension; i++) { + SummaryStatistics stats = statistics.get(i); + log.info( + "Dimension {} stats: mean={}, std={}, min={}, max={}", + i, + stats.getMean(), + stats.getStandardDeviation(), + stats.getMin(), + stats.getMax() + ); + } + } + + /** + * Serializes a SegmentProfilerState to a byte array + * @return + */ + public byte[] toByteArray() { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)) { + + oos.writeObject(this); + return baos.toByteArray(); + } catch (IOException e) { + throw new RuntimeException("Failed to serialize SegmentProfilerStates", e); + } + } + + /** + * 
Deserializes a SegmentProfilerState from a byte array + * @param bytes + * @return + */ + public static SegmentProfilerState fromBytes(byte[] bytes) { + try (ByteArrayInputStream bais = new ByteArrayInputStream(bytes); ObjectInputStream ois = new ObjectInputStream(bais)) { + + return (SegmentProfilerState) ois.readObject(); + } catch (IOException | ClassNotFoundException e) { + throw new RuntimeException("Failed to deserialize SegmentProfilerState", e); + } + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java index d2b99fef04..f82eea1966 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCache.java @@ -22,6 +22,8 @@ import java.io.Closeable; import java.io.IOException; import java.time.Instant; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES; @@ -66,7 +68,8 @@ static QuantizationStateCache getInstance() { } private void buildCache() { - this.cache = CacheBuilder.newBuilder().concurrencyLevel(1).maximumWeight(maxCacheSizeInKB).weigher((k, v) -> { + final long maxCacheSizeInBytes = maxCacheSizeInKB * 1024; + this.cache = CacheBuilder.newBuilder().concurrencyLevel(1).maximumWeight(maxCacheSizeInBytes).weigher((k, v) -> { try { return ((QuantizationState) v).toByteArray().length; } catch (IOException e) { @@ -122,17 +125,12 @@ synchronized void rebuildCache() { * @param fieldName The name of the field. * @return The associated QuantizationState, or null if not present. 
*/ - QuantizationState getQuantizationState(String fieldName) { - return cache.getIfPresent(fieldName); - } - - /** - * Adds or updates a quantization state in the cache. - * @param fieldName The name of the field. - * @param quantizationState The quantization state to store. - */ - void addQuantizationState(String fieldName, QuantizationState quantizationState) { - cache.put(fieldName, quantizationState); + QuantizationState getQuantizationState(final String fieldName, final Callable valueLoader) { + try { + return cache.get(fieldName, valueLoader); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } } /** diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java index 63282029a7..363ea0e8fc 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManager.java @@ -42,26 +42,14 @@ public synchronized void rebuildCache() { * @return The associated QuantizationState */ public QuantizationState getQuantizationState(QuantizationStateReadConfig quantizationStateReadConfig) throws IOException { - QuantizationState quantizationState = QuantizationStateCache.getInstance() - .getQuantizationState(quantizationStateReadConfig.getCacheKey()); - if (quantizationState == null) { - quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); - if (quantizationState != null) { - addQuantizationState(quantizationStateReadConfig.getCacheKey(), quantizationState); - } - } + final QuantizationState quantizationState = QuantizationStateCache.getInstance() + .getQuantizationState( + quantizationStateReadConfig.getCacheKey(), + () -> KNN990QuantizationStateReader.read(quantizationStateReadConfig) + ); return 
quantizationState; } - /** - * Adds or updates a quantization state in the cache. - * @param fieldName The name of the field. - * @param quantizationState The quantization state to store. - */ - public void addQuantizationState(String fieldName, QuantizationState quantizationState) { - QuantizationStateCache.getInstance().addQuantizationState(fieldName, quantizationState); - } - /** * Removes the quantization state associated with a given field name. * @param fieldName The name of the field. diff --git a/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt b/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt index 388cdda8ae..0462cab039 100644 --- a/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt +++ b/src/main/resources/org/opensearch/knn/plugin/script/knn_allowlist.txt @@ -4,7 +4,7 @@ # Painless definition of classes used by knn plugin class org.opensearch.knn.index.KNNVectorScriptDocValues { - float[] getValue() + Object getValue() } static_import { float l2Squared(List, org.opensearch.knn.index.KNNVectorScriptDocValues) from_class org.opensearch.knn.plugin.script.KNNScoringUtil diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java index 2ec9ce6b50..928a7ea7b0 100644 --- a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java @@ -41,6 +41,9 @@ import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.test.hamcrest.OpenSearchAssertions; +import com.carrotsearch.randomizedtesting.ThreadFilter; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters; + import java.io.IOException; import java.util.Base64; import java.util.Collection; @@ -63,7 +66,18 @@ import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; 
+@ThreadLeakFilters(defaultFilters = true, filters = { KNNSingleNodeTestCase.ForkJoinFilter.class }) public class KNNSingleNodeTestCase extends OpenSearchSingleNodeTestCase { + /** + * The the ForkJoinPool.commonPool() never terminates until program shutdown. + */ + public static final class ForkJoinFilter implements ThreadFilter { + @Override + public boolean reject(Thread t) { + return t.getName().startsWith("ForkJoinPool.commonPool-worker"); + } + } + @Override public void setUp() throws Exception { super.setUp(); diff --git a/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java b/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java index b2d69dbb1b..427f42ee94 100644 --- a/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java +++ b/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java @@ -58,7 +58,7 @@ public class AdvancedFilteringUseCasesIT extends KNNRestTestCase { private static final String TERM_FIELD = "term"; - private static final int k = 20; + private static final int k = 100; private static final String FIELD_NAME_METADATA = "parking"; @@ -448,7 +448,7 @@ private void validateFilterSearch(final String query, final String engine) throw String response = EntityUtils.toString(performSearch(INDEX_NAME, query).getEntity()); // Validate number of documents returned as the expected number of documents Assert.assertEquals("For engine " + engine + ", hits: ", DOCUMENT_IN_RESPONSE, parseHits(response)); - Assert.assertEquals("For engine " + engine + ", totalSearchHits: ", k, parseTotalSearchHits(response)); + Assert.assertEquals("For engine " + engine + ", totalSearchHits: ", NUM_DOCS / 2, parseTotalSearchHits(response)); if (KNNEngine.getEngine(engine) == KNNEngine.FAISS) { // Update the filter threshold to 0 to ensure that we are hitting ANN Search use case for FAISS updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 0)); @@ 
-456,7 +456,11 @@ private void validateFilterSearch(final String query, final String engine) throw // Validate number of documents returned as the expected number of documents Assert.assertEquals("For engine " + engine + ", hits with ANN search :", DOCUMENT_IN_RESPONSE, parseHits(response)); - Assert.assertEquals("For engine " + engine + ", totalSearchHits with ANN search :", k, parseTotalSearchHits(response)); + Assert.assertEquals( + "For engine " + engine + ", totalSearchHits with ANN search :", + NUM_DOCS / 2, + parseTotalSearchHits(response) + ); } } diff --git a/src/test/java/org/opensearch/knn/index/ExplainIT.java b/src/test/java/org/opensearch/knn/index/ExplainIT.java new file mode 100644 index 0000000000..db2c0e700f --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/ExplainIT.java @@ -0,0 +1,333 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + +import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import org.apache.hc.core5.http.ParseException; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.index.mapper.Mode; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.query.parser.RescoreParser; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.ANN_SEARCH; +import static 
org.opensearch.knn.common.KNNConstants.COMPRESSION_LEVEL_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.DISK_BASED_SEARCH; +import static org.opensearch.knn.common.KNNConstants.EXACT_SEARCH; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MODE_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.RADIAL_SEARCH; + +public class ExplainIT extends KNNRestTestCase { + + @SneakyThrows + public void testAnnSearch() { + int dimension = 128; + int numDocs = 100; + createDefaultKnnIndex(dimension); + indexTestData(INDEX_NAME, FIELD_NAME, dimension, numDocs); + float[] queryVector = new float[dimension]; + Arrays.fill(queryVector, (float) numDocs); + XContentBuilder queryBuilder = buildSearchQuery(FIELD_NAME, 10, queryVector, null); + // validate primaries are working + validateExplainSearchResponse( + queryBuilder, + ANN_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + SpaceType.L2.explainScoreTranslation(0) + ); + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testANNWithExactSearch() { + createDefaultKnnIndex(2); + indexTestData(INDEX_NAME, FIELD_NAME, 2, 2); + + // Execute the search request with a match all query to ensure exact logic gets called + updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 1000)); + + float[] queryVector = new float[] { 1.0f, 1.0f }; + + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, 2, QueryBuilders.matchAllQuery()); + XContentBuilder queryBuilder = 
XContentFactory.jsonBuilder().startObject().startObject("query"); + knnQueryBuilder.doXContent(queryBuilder, ToXContent.EMPTY_PARAMS); + queryBuilder.endObject().endObject(); + + validateExplainSearchResponse( + queryBuilder, + ANN_SEARCH, + EXACT_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + "since filteredIds", + "is less than or equal to K" + ); + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testRadialWithANNSearch() { + int dimension = 128; + int numDocs = 100; + createDefaultKnnIndex(dimension); + indexTestData(INDEX_NAME, FIELD_NAME, dimension, numDocs); + float[] queryVector = new float[dimension]; + Arrays.fill(queryVector, (float) numDocs); + + float distance = 15f; + XContentBuilder queryBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(FIELD_NAME) + .field("vector", queryVector) + .field(MAX_DISTANCE, distance) + .endObject() + .endObject() + .endObject() + .endObject(); + + validateExplainSearchResponse( + queryBuilder, + RADIAL_SEARCH, + ANN_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + SpaceType.L2.explainScoreTranslation(0), + String.valueOf(distance) + ); + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testRadialWithExactSearch() { + setupKNNIndexForFilterQuery(); + + final float[] queryVector = new float[] { 3.3f, 3.0f, 5.0f }; + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("color", "red"); + float distance = 15f; + + XContentBuilder queryBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(FIELD_NAME) + .field("vector", queryVector) + .field(MAX_DISTANCE, distance) + .field("filter", termQueryBuilder) + .endObject() + .endObject() + .endObject() + .endObject(); + + validateExplainSearchResponse( + queryBuilder, + RADIAL_SEARCH, + EXACT_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + 
String.valueOf(distance) + ); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testDiskBasedSearchWithDefaultRescoring() { + int dimension = 16; + float[] queryVector = new float[] { + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f }; + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .field(MODE_PARAMETER, Mode.ON_DISK.getName()) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x32.getName()) + .endObject() + .endObject() + .endObject(); + createKnnIndex(INDEX_NAME, builder.toString()); + addKNNDocs(INDEX_NAME, FIELD_NAME, dimension, 0, 5); + + // Search with default rescore + XContentBuilder queryBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(FIELD_NAME) + .field("vector", queryVector) + .field("k", 5) + .endObject() + .endObject() + .endObject() + .endObject(); + + validateExplainSearchResponse( + queryBuilder, + DISK_BASED_SEARCH, + ANN_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + "shard level rescoring enabled", + String.valueOf(dimension) + ); + } + + @SneakyThrows + public void testDiskBasedSearchWithRescoringDisabled() { + int dimension = 16; + float[] queryVector = new float[] { + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f, + 1.0f, + 2.0f }; + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .field(MODE_PARAMETER, Mode.ON_DISK.getName()) + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x32.getName()) + .endObject() + .endObject() + .endObject(); + createKnnIndex(INDEX_NAME, 
builder.toString()); + addKNNDocs(INDEX_NAME, FIELD_NAME, dimension, 0, 5); + + // Search without rescore + XContentBuilder queryBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(FIELD_NAME) + .field("vector", queryVector) + .field("k", 5) + .field(RescoreParser.RESCORE_PARAMETER, false) + .endObject() + .endObject() + .endObject() + .endObject(); + + validateExplainSearchResponse( + queryBuilder, + DISK_BASED_SEARCH, + ANN_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + "shard level rescoring disabled", + String.valueOf(dimension) + ); + } + + private void validateExplainSearchResponse(XContentBuilder queryBuilder, String... descriptions) throws IOException, ParseException { + String responseBody = EntityUtils.toString(performSearch(INDEX_NAME, queryBuilder.toString(), "explain=true").getEntity()); + List searchResponseHits = parseSearchResponseHits(responseBody); + searchResponseHits.stream().forEach(hit -> { + Map hitMap = (Map) hit; + Double score = (Double) hitMap.get("_score"); + String explanation = hitMap.get("_explanation").toString(); + assertNotNull(explanation); + for (String description : descriptions) { + assertTrue(explanation.contains(description)); + } + assertTrue(explanation.contains(String.valueOf(score))); + }); + } + + private void createDefaultKnnIndex(int dimension) throws IOException { + // Create Mappings + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .endObject() + .endObject(); + final String mapping = builder.toString(); + createKnnIndex(INDEX_NAME, getKNNDefaultIndexSettings(), mapping); + } + + private void 
setupKNNIndexForFilterQuery() throws Exception { + createDefaultKnnIndex(3); + addKnnDocWithAttributes("doc1", new float[] { 6.0f, 7.9f, 3.1f }, ImmutableMap.of("color", "red", "taste", "sweet")); + addKnnDocWithAttributes("doc2", new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of("color", "green")); + addKnnDocWithAttributes("doc3", new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of("color", "red")); + + refreshIndex(INDEX_NAME); + } + + private void indexTestData(final String indexName, final String fieldName, final int dimension, final int numDocs) throws Exception { + for (int i = 0; i < numDocs; i++) { + float[] indexVector = new float[dimension]; + Arrays.fill(indexVector, (float) i); + addKnnDocWithAttributes(indexName, Integer.toString(i), fieldName, indexVector, ImmutableMap.of("rating", String.valueOf(i))); + } + + // Assert that all docs are ingested + refreshAllNonSystemIndices(); + assertEquals(numDocs, getDocCount(indexName)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java index 3f671afcfc..5d90c0fe44 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java @@ -58,20 +58,22 @@ public void tearDown() throws Exception { directory.close(); } + @SuppressWarnings("unchecked") public void testGetScriptValues() { KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData( leafReaderContext.reader(), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT ); - ScriptDocValues scriptValues = leafFieldData.getScriptValues(); + ScriptDocValues scriptValues = (ScriptDocValues) leafFieldData.getScriptValues(); assertNotNull(scriptValues); assertTrue(scriptValues instanceof KNNVectorScriptDocValues); } + @SuppressWarnings("unchecked") public void testGetScriptValuesWrongFieldName() { KNNVectorDVLeafFieldData leafFieldData = new 
KNNVectorDVLeafFieldData(leafReaderContext.reader(), "invalid", VectorDataType.FLOAT); - ScriptDocValues scriptValues = leafFieldData.getScriptValues(); + ScriptDocValues scriptValues = (ScriptDocValues) leafFieldData.getScriptValues(); assertNotNull(scriptValues); } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index 71aab0b092..737daaf315 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -51,37 +51,39 @@ public void tearDown() throws Exception { /** Test for Float Vector Values */ @Test + @SuppressWarnings("unchecked") public void testFloatVectorValues() throws IOException { createKNNVectorDocument(directory, FloatVectorValues.class); reader = DirectoryReader.open(directory); LeafReader leafReader = reader.leaves().get(0).reader(); // Separate scriptDocValues instance for this test - KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT ); scriptDocValues.setNextDocId(0); - Assert.assertArrayEquals(SAMPLE_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); + Assert.assertArrayEquals(SAMPLE_VECTOR_DATA, ((KNNVectorScriptDocValues) scriptDocValues).getValue(), 0.1f); } /** Test for Byte Vector Values */ @Test + @SuppressWarnings("unchecked") public void testByteVectorValues() throws IOException { createKNNVectorDocument(directory, ByteVectorValues.class); reader = DirectoryReader.open(directory); LeafReader leafReader = reader.leaves().get(0).reader(); - KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( 
leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME), MOCK_INDEX_FIELD_NAME, VectorDataType.BYTE ); scriptDocValues.setNextDocId(0); - Assert.assertArrayEquals(new float[] { SAMPLE_BYTE_VECTOR_DATA[0], SAMPLE_BYTE_VECTOR_DATA[1] }, scriptDocValues.getValue(), 0.1f); + Assert.assertArrayEquals(SAMPLE_BYTE_VECTOR_DATA, ((KNNVectorScriptDocValues) scriptDocValues).getValue()); } /** Test for Binary Vector Values */ @@ -91,7 +93,7 @@ public void testBinaryVectorValues() throws IOException { reader = DirectoryReader.open(directory); LeafReader leafReader = reader.leaves().get(0).reader(); - KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( leafReader.getBinaryDocValues(MOCK_INDEX_FIELD_NAME), MOCK_INDEX_FIELD_NAME, VectorDataType.BINARY @@ -108,7 +110,7 @@ public void testGetValueFails() throws IOException { reader = DirectoryReader.open(directory); LeafReader leafReader = reader.leaves().get(0).reader(); - KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( + KNNVectorScriptDocValues scriptDocValues = KNNVectorScriptDocValues.create( leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT @@ -156,7 +158,7 @@ public void testUnsupportedValues() throws IOException { /** Ensure empty values case */ @Test public void testEmptyValues() throws IOException { - KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); + KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); assertEquals(0, values.size()); } diff --git a/src/test/java/org/opensearch/knn/index/RemoteBuildIT.java b/src/test/java/org/opensearch/knn/index/RemoteBuildIT.java new file mode 100644 index 0000000000..70a40e0b86 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/RemoteBuildIT.java @@ -0,0 +1,397 @@ +/* + * 
SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.Floats; +import lombok.AllArgsConstructor; +import lombok.SneakyThrows; +import org.apache.hc.core5.http.ParseException; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.Before; +import org.junit.BeforeClass; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.knn.KNNJsonQueryBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.knn.KNNResult; +import org.opensearch.knn.TestUtils; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.common.featureflags.KNNFeatureFlags; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.plugin.script.KNNScoringUtil; +import org.apache.lucene.util.VectorUtil; + +import java.io.IOException; +import java.net.URL; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.Collection; +import java.util.function.BiFunction; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static 
org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; +import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD; + +@AllArgsConstructor +public class RemoteBuildIT extends KNNRestTestCase { + private static TestUtils.TestData testData; + private String description; + private SpaceType spaceType; + + @BeforeClass + public static void setUpClass() throws IOException { + if (RemoteBuildIT.class.getClassLoader() == null) { + throw new IllegalStateException("ClassLoader of RemoteBuildIT Class is null"); + } + URL testIndexVectors = RemoteBuildIT.class.getClassLoader().getResource("data/test_vectors_1000x128.json"); + URL testQueries = RemoteBuildIT.class.getClassLoader().getResource("data/test_queries_100x128.csv"); + assert testIndexVectors != null; + assert testQueries != null; + testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath()); + } + + @Before + public void setupAdditionalRemoteIndexBuildSettings() throws Exception { + updateClusterSettings(KNNFeatureFlags.KNN_REMOTE_VECTOR_BUILD_SETTING.getKey(), true); + 
updateClusterSettings(KNNSettings.KNN_REMOTE_VECTOR_REPO, "integ-test-repo"); + updateClusterSettings(KNNSettings.KNN_REMOTE_BUILD_SERVICE_ENDPOINT, "http://0.0.0.0:80"); + updateClusterSettings(KNNSettings.KNN_REMOTE_BUILD_CLIENT_POLL_INTERVAL, 0); + setupRepository("integ-test-repo"); + } + + @ParametersFactory(argumentFormatting = "description:%1$s; spaceType:%2$s") + public static Collection parameters() throws IOException { + return Arrays.asList( + $$( + $("SpaceType L2", SpaceType.L2), + $("SpaceType INNER_PRODUCT", SpaceType.INNER_PRODUCT), + $("SpaceType COSINESIMIL", SpaceType.COSINESIMIL) + ) + ); + } + + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHNSWFlat_thenSucceed() { + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, randomFrom(mValues)) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, randomFrom(efConstructionValues)) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, randomFrom(efSearchValues)) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + final Settings knnIndexSettings = buildKNNIndexSettingsRemoteBuild(0); + createKnnIndex(INDEX_NAME, knnIndexSettings, mapping); + + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(INDEX_NAME))); + + // Index 
the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + INDEX_NAME, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); + + float distance = 300000000000f; + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, distance, null, spaceType, null, null); + + forceMergeKnnIndex(INDEX_NAME, 1); + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, distance, null, spaceType, null, null); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testFilteredSearchWithFaissHnsw_whenFiltersMatchAllDocs_thenReturnCorrectResults() { + String filterFieldName = "color"; + final int expectResultSize = randomIntBetween(1, 3); + final String filterValue = "red"; + final Settings knnIndexSettings = buildKNNIndexSettingsRemoteBuild(0); + createKnnIndex(INDEX_NAME, knnIndexSettings, createKnnIndexMapping(FIELD_NAME, 3, METHOD_HNSW, FAISS_NAME, spaceType.getValue())); + + // ingest 5 vector docs into the index with the same field {"color": "red"} + for (int i = 0; i < 5; i++) { + addKnnDocWithAttributes(String.valueOf(i), new float[] { i + 1, i + 1, i + 1 }, ImmutableMap.of(filterFieldName, filterValue)); + } + + refreshIndex(INDEX_NAME); + forceMergeKnnIndex(INDEX_NAME); + + updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 0)); + + Float[] queryVector = { 3f, 3f, 3f }; + // All docs in one segment will match the filters value + String query = KNNJsonQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(expectResultSize) + .filterFieldName(filterFieldName) + .filterValue(filterValue) + .build() + .getQueryString(); + Response response = searchKNNIndex(INDEX_NAME, query, 
expectResultSize); + String entity = EntityUtils.toString(response.getEntity()); + List docIds = parseIds(entity); + assertEquals(expectResultSize, docIds.size()); + assertEquals(expectResultSize, parseTotalSearchHits(entity)); + } + + @SneakyThrows + public void testHNSW_whenIndexedAndQueried_thenSucceed() { + String indexName = "test-index-hnsw"; + String fieldName = "test-field-hnsw"; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + int dimension = 128; + int numDocs = 100; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, randomFrom(mValues)) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, randomFrom(efConstructionValues)) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, randomFrom(efSearchValues)) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + final Settings knnIndexSettings = buildKNNIndexSettingsRemoteBuild(0); + + createKnnIndex(indexName, knnIndexSettings, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + indexTestData(indexName, fieldName, dimension, numDocs); + + refreshIndex(indexName); + forceMergeKnnIndex(indexName); + + queryTestData(indexName, fieldName, dimension, numDocs); + deleteKNNIndex(indexName); + validateGraphEviction(); + } + + private void indexTestData(final String indexName, final String fieldName, final int dimension, final int numDocs) throws Exception { + 
for (int i = 0; i < numDocs; i++) { + float[] indexVector = new float[dimension]; + Arrays.fill(indexVector, (float) i + 1); + addKnnDoc(indexName, Integer.toString(i), fieldName, indexVector); + } + + // Assert that all docs are ingested + refreshAllNonSystemIndices(); + assertEquals(numDocs, getDocCount(indexName)); + } + + @SneakyThrows + private void queryTestData(final String indexName, final String fieldName, final int dimension, final int numDocs) throws IOException, + ParseException { + float[] queryVector = new float[dimension]; + Arrays.fill(queryVector, (float) numDocs); + int k = 10; + + Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, k, queryVector, null), k); + final String responseBody = EntityUtils.toString(searchResponse.getEntity()); + List results = parseSearchResponse(responseBody, fieldName); + assertEquals(k, results.size()); + + if (spaceType == SpaceType.COSINESIMIL) { + final List actualScores = parseSearchResponseScore(responseBody, fieldName); + final BiFunction scoringFunction = VectorUtil::cosine; + + for (int j = 0; j < k; j++) { + final float[] primitiveArray = results.get(j).getVector(); + assertEquals( + KNNEngine.FAISS.score(scoringFunction.apply(queryVector, primitiveArray), SpaceType.COSINESIMIL), + actualScores.get(j), + 0.0001 + ); + } + } else { + for (int i = 0; i < k; i++) { + assertEquals(numDocs - i - 1, Integer.parseInt(results.get(i).getDocId())); + } + } + } + + private void validateGraphEviction() throws Exception { + // Search every 5 seconds 14 times to confirm graph gets evicted + int intervals = 14; + for (int i = 0; i < intervals; i++) { + if (getTotalGraphsInCache() == 0) { + return; + } + + Thread.sleep(5 * 1000); + } + + fail("Graphs are not getting evicted"); + } + + private List> validateRadiusSearchResults( + String indexName, + String fieldName, + float[][] queryVectors, + Float distanceThreshold, + Float scoreThreshold, + final SpaceType spaceType, + TermQueryBuilder 
filterQuery, + Map methodParameters + ) throws IOException, ParseException { + List> queryResults = new ArrayList<>(); + for (float[] queryVector : queryVectors) { + XContentBuilder queryBuilder = XContentFactory.jsonBuilder().startObject().startObject("query"); + queryBuilder.startObject("knn"); + queryBuilder.startObject(fieldName); + queryBuilder.field("vector", queryVector); + if (distanceThreshold != null) { + queryBuilder.field(MAX_DISTANCE, distanceThreshold); + } else if (scoreThreshold != null) { + queryBuilder.field(MIN_SCORE, scoreThreshold); + } else { + throw new IllegalArgumentException("Invalid threshold"); + } + if (filterQuery != null) { + queryBuilder.field("filter", filterQuery); + } + if (methodParameters != null && methodParameters.size() > 0) { + queryBuilder.startObject(METHOD_PARAMETER); + for (Map.Entry entry : methodParameters.entrySet()) { + queryBuilder.field(entry.getKey(), entry.getValue()); + } + queryBuilder.endObject(); + } + queryBuilder.endObject(); + queryBuilder.endObject(); + queryBuilder.endObject().endObject(); + final String responseBody = EntityUtils.toString(searchKNNIndex(indexName, queryBuilder, 10).getEntity()); + + List knnResults = parseSearchResponse(responseBody, fieldName); + + for (KNNResult knnResult : knnResults) { + float[] vector = knnResult.getVector(); + float distance = TestUtils.computeDistFromSpaceType(spaceType, vector, queryVector); + if (spaceType == SpaceType.L2) { + assertTrue(KNNScoringUtil.l2Squared(queryVector, vector) <= distance); + } else if (spaceType == SpaceType.INNER_PRODUCT) { + assertTrue(KNNScoringUtil.innerProduct(queryVector, vector) >= distance); + } else if (spaceType == SpaceType.COSINESIMIL) { + assertTrue(KNNScoringUtil.cosinesimil(queryVector, vector) >= distance); + } else { + throw new IllegalArgumentException("Invalid space type"); + } + } + queryResults.add(knnResults); + } + return queryResults; + } + + @SneakyThrows + protected void setupRepository(String repository) { + 
final String bucket = System.getProperty("test.bucket", null); + final String base_path = System.getProperty("test.base_path", null); + + Settings.Builder builder = Settings.builder() + .put("bucket", bucket) + .put("base_path", base_path) + .put("region", "us-east-1") + .put("s3_upload_retry_enabled", false); + + final String remoteBuild = System.getProperty("test.remoteBuild", null); + if (remoteBuild != null && remoteBuild.equals("s3.localStack")) { + builder.put("endpoint", "http://s3.localhost.localstack.cloud:4566"); + } + + registerRepository(repository, "s3", false, builder.build()); + + } + + protected Settings buildKNNIndexSettingsRemoteBuild(int approximateThreshold) { + return Settings.builder() + .put("number_of_shards", 1) + .put("number_of_replicas", 0) + .put(KNN_INDEX, true) + .put(INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, approximateThreshold) + .put(KNN_INDEX_REMOTE_VECTOR_BUILD, true) + .put(KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD, "0kb") + .build(); + } +} diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java index 86aac950cf..32f3f8ecf6 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -32,7 +32,7 @@ public class VectorDataTypeTests extends KNNTestCase { @SneakyThrows public void testGetDocValuesWithFloatVectorDataType() { - KNNVectorScriptDocValues scriptDocValues = getKNNFloatVectorScriptDocValues(); + KNNVectorScriptDocValues scriptDocValues = getKNNFloatVectorScriptDocValues(); scriptDocValues.setNextDocId(0); Assert.assertArrayEquals(SAMPLE_FLOAT_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); @@ -43,35 +43,37 @@ public void testGetDocValuesWithFloatVectorDataType() { @SneakyThrows public void testGetDocValuesWithByteVectorDataType() { - KNNVectorScriptDocValues scriptDocValues = getKNNByteVectorScriptDocValues(); + KNNVectorScriptDocValues 
scriptDocValues = getKNNByteVectorScriptDocValues(); scriptDocValues.setNextDocId(0); - Assert.assertArrayEquals(SAMPLE_FLOAT_VECTOR_DATA, scriptDocValues.getValue(), 0.1f); + Assert.assertArrayEquals(SAMPLE_BYTE_VECTOR_DATA, scriptDocValues.getValue()); reader.close(); directory.close(); } + @SuppressWarnings("unchecked") @SneakyThrows - private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() { + private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() { directory = newDirectory(); createKNNFloatVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return KNNVectorScriptDocValues.create( + return (KNNVectorScriptDocValues) KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME, VectorDataType.FLOAT ); } + @SuppressWarnings("unchecked") @SneakyThrows - private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() { + private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() { directory = newDirectory(); createKNNByteVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return KNNVectorScriptDocValues.create( + return (KNNVectorScriptDocValues) KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME, VectorDataType.BYTE @@ -101,8 +103,7 @@ private void createKNNByteVectorDocument(Directory directory) throws IOException public void testGetVectorFromBytesRef_whenBinary_thenException() { byte[] vector = { 1, 2, 3 }; - float[] expected = { 1, 2, 3 }; BytesRef bytesRef = new BytesRef(vector); - assertArrayEquals(expected, VectorDataType.BINARY.getVectorFromBytesRef(bytesRef), 0.01f); + assertArrayEquals(vector, 
VectorDataType.BINARY.getVectorFromBytesRef(bytesRef)); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/DerivedSourceVectorInjectorTests.java b/src/test/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/DerivedSourceVectorInjectorTests.java index 801b1053af..d769f0656e 100644 --- a/src/test/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/DerivedSourceVectorInjectorTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/backward_codecs/KNN9120Codec/DerivedSourceVectorInjectorTests.java @@ -59,7 +59,7 @@ public void testInjectVectors() { }); DerivedSourceVectorInjector derivedSourceVectorInjector = new DerivedSourceVectorInjector( - new KNN9120DerivedSourceReadersSupplier(s -> null, s -> null, s -> null, s -> null), + new KNN9120DerivedSourceReaders(null, null, null, null), null, fields ); @@ -117,20 +117,18 @@ public void testShouldInject() { KNNCodecTestUtil.FieldInfoBuilder.builder("test3").build() ); - try ( - DerivedSourceVectorInjector vectorInjector = new DerivedSourceVectorInjector( - new KNN9120DerivedSourceReadersSupplier(s -> null, s -> null, s -> null, s -> null), - null, - fields - ) - ) { - assertTrue(vectorInjector.shouldInject(null, null)); - assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, null)); - assertTrue(vectorInjector.shouldInject(new String[] { "test1", "test2", "test3" }, null)); - assertTrue(vectorInjector.shouldInject(null, new String[] { "test2" })); - assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, new String[] { "test2" })); - assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, new String[] { "test2", "test3" })); - assertFalse(vectorInjector.shouldInject(null, new String[] { "test1", "test2", "test3" })); - } + DerivedSourceVectorInjector vectorInjector = new DerivedSourceVectorInjector( + new KNN9120DerivedSourceReaders(null, null, null, null), + null, + fields + ); + 
assertTrue(vectorInjector.shouldInject(null, null)); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, null)); + assertTrue(vectorInjector.shouldInject(new String[] { "test1", "test2", "test3" }, null)); + assertTrue(vectorInjector.shouldInject(null, new String[] { "test2" })); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, new String[] { "test2" })); + assertTrue(vectorInjector.shouldInject(new String[] { "test1" }, new String[] { "test2", "test3" })); + assertFalse(vectorInjector.shouldInject(null, new String[] { "test1", "test2", "test3" })); + } } diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceLuceneHelperTests.java b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceLuceneHelperTests.java new file mode 100644 index 0000000000..c63637661e --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceLuceneHelperTests.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.SneakyThrows; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.index.SegmentReadState; +import org.junit.Before; +import org.mockito.Mock; +import org.opensearch.knn.KNNTestCase; + +import java.io.IOException; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.codec.derivedsource.DerivedSourceLuceneHelper.NO_CHILDREN_INDICATOR; + +public class DerivedSourceLuceneHelperTests extends KNNTestCase { + @Mock + private FieldInfos fieldInfos; + + @Mock + private FieldInfo fieldInfo; + + @Mock + private DerivedSourceReaders derivedSourceReaders; + + @Mock + private DocValuesProducer 
docValuesProducer; + + @Mock + private NumericDocValues numericDocValues; + + private SegmentReadState segmentReadState; + private DerivedSourceLuceneHelper helper; + + @Before + public void setUp() throws Exception { + super.setUp(); + segmentReadState = new SegmentReadState(null, null, fieldInfos, null, null); + + when(fieldInfos.fieldInfo("_primary_term")).thenReturn(fieldInfo); + when(derivedSourceReaders.getDocValuesProducer()).thenReturn(docValuesProducer); + when(docValuesProducer.getNumeric(fieldInfo)).thenReturn(numericDocValues); + helper = new DerivedSourceLuceneHelper(derivedSourceReaders, segmentReadState); + } + + @SneakyThrows + public void testGetFirstChild_WhenNoDocumentsBeforeParent() throws IOException { + int parentDocId = 5; + int startingPoint = 0; + when(numericDocValues.advance(startingPoint)).thenReturn(10); // First doc is after parent + when(numericDocValues.docID()).thenReturn(10, NO_MORE_DOCS); + + int result = helper.getFirstChild(parentDocId, startingPoint); + + assertEquals(0, result); + } + + @SneakyThrows + public void testGetFirstChild_WhenNoChildren() { + int parentDocId = 5; + int startingPoint = 0; + when(numericDocValues.advance(startingPoint)).thenReturn(4); + when(numericDocValues.docID()).thenReturn(4, 4, 4, 5, 5); + when(numericDocValues.nextDoc()).thenReturn(5); + + int result = helper.getFirstChild(parentDocId, startingPoint); + + assertEquals(NO_CHILDREN_INDICATOR, result); + } + + @SneakyThrows + public void testGetFirstChild_WhenChildrenExist() { + int parentDocId = 10; + int startingPoint = 0; + when(numericDocValues.advance(startingPoint)).thenReturn(5); + when(numericDocValues.docID()).thenReturn(5, 5, 5, 10, 10); + when(numericDocValues.nextDoc()).thenReturn(10); + + int result = helper.getFirstChild(parentDocId, startingPoint); + + assertEquals(6, result); // Should return previousParentDocId + 1 + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersTests.java 
b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersTests.java new file mode 100644 index 0000000000..0b1b8b55f7 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReadersTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.derivedsource; + +import lombok.SneakyThrows; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.mockito.Mock; +import org.opensearch.knn.KNNTestCase; + +import static org.mockito.Mockito.verify; + +public class DerivedSourceReadersTests extends KNNTestCase { + + @Mock + private KnnVectorsReader mockKnnVectorsReader; + @Mock + private DocValuesProducer mockDocValuesProducer; + + private DerivedSourceReaders readers; + + @SneakyThrows + public void testInitialReferenceCount() { + readers = new DerivedSourceReaders(mockKnnVectorsReader, mockDocValuesProducer); + + // Initial reference count is 1, so closing once should trigger actual close + readers.close(); + + verify(mockKnnVectorsReader).close(); + verify(mockDocValuesProducer).close(); + } + + @SneakyThrows + public void testNullReaders() { + // Test with null readers to ensure no NPE + DerivedSourceReaders nullReaders = new DerivedSourceReaders(null, null); + nullReaders.close(); // Should not throw any exception + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategyTests.java index 931031d2eb..a9487f5a36 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategyTests.java @@ -16,6 +16,7 @@ import org.opensearch.knn.common.KNNConstants; import 
org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue; import org.opensearch.remoteindexbuild.model.RemoteBuildRequest; import org.opensearch.repositories.RepositoriesService; import org.opensearch.repositories.RepositoryMissingException; @@ -42,6 +43,8 @@ import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_THRESHOLD_SETTING; import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_VECTOR_REPO_SETTING; import static org.opensearch.knn.index.SpaceType.INNER_PRODUCT; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_FLUSH_TIME; +import static org.opensearch.knn.plugin.stats.KNNRemoteIndexBuildValue.REMOTE_INDEX_BUILD_MERGE_TIME; import static org.opensearch.remoteindexbuild.constants.KNNRemoteConstants.DOC_ID_FILE_EXTENSION; import static org.opensearch.remoteindexbuild.constants.KNNRemoteConstants.METHOD_PARAMETER_ENCODER; import static org.opensearch.remoteindexbuild.constants.KNNRemoteConstants.S3; @@ -71,6 +74,17 @@ public void testRemoteIndexBuildStrategyFallback() throws IOException { ); objectUnderTest.buildAndWriteIndex(buildIndexParams); assertTrue(fallback.get()); + for (KNNRemoteIndexBuildValue value : KNNRemoteIndexBuildValue.values()) { + if (value == REMOTE_INDEX_BUILD_FLUSH_TIME && buildIndexParams.isFlush()) { + assertTrue(value.getValue() >= 0L); + } else if (value == REMOTE_INDEX_BUILD_MERGE_TIME && !buildIndexParams.isFlush()) { + assertTrue(value.getValue() >= 0L); + } else if (value == KNNRemoteIndexBuildValue.INDEX_BUILD_FAILURE_COUNT) { + assertEquals(1L, (long) value.getValue()); + } else { + assertEquals(0L, (long) value.getValue()); + } + } } public void testShouldBuildIndexRemotely() { diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildTests.java 
b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildTests.java index 1ce0f107ff..e15610baeb 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildTests.java @@ -100,6 +100,7 @@ abstract class RemoteIndexBuildTests extends KNNTestCase { .knnVectorValuesSupplier(knnVectorValuesSupplier) .totalLiveDocs((int) knnVectorValues.totalLiveDocs()) .segmentWriteState(segmentWriteState) + .isFlush(randomBoolean()) .build(); record TestIndexBuildStrategy(SetOnce fallback) implements NativeIndexBuildStrategy { diff --git a/src/test/java/org/opensearch/knn/index/engine/EngineResolverTests.java b/src/test/java/org/opensearch/knn/index/engine/EngineResolverTests.java index 291f0c671a..376908ce8d 100644 --- a/src/test/java/org/opensearch/knn/index/engine/EngineResolverTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/EngineResolverTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.engine; +import org.opensearch.Version; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.mapper.CompressionLevel; @@ -41,10 +42,23 @@ public void testResolveEngine_whenModeAndCompressionAreFalse_thenDefault() { ); } - public void testResolveEngine_whenModeSpecifiedAndCompressionIsNotSpecified_thenNMSLIB() { + public void testResolveEngine_whenModeSpecifiedAndCompressionIsNotSpecified_whenVersionBefore2_19_thenNMSLIB() { assertEquals(KNNEngine.DEFAULT, ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().build(), null, false)); assertEquals( KNNEngine.NMSLIB, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().mode(Mode.IN_MEMORY).build(), + new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.UNDEFINED, MethodComponentContext.EMPTY, false), + false, + Version.V_2_18_0 + ) + ); + } + + public void 
testResolveEngine_whenModeSpecifiedAndCompressionIsNotSpecified_thenFAISS() { + assertEquals(KNNEngine.DEFAULT, ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().build(), null, false)); + assertEquals( + KNNEngine.FAISS, ENGINE_RESOLVER.resolveEngine( KNNMethodConfigContext.builder().mode(Mode.IN_MEMORY).build(), new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.UNDEFINED, MethodComponentContext.EMPTY, false), @@ -63,9 +77,18 @@ public void testResolveEngine_whenCompressionIs1x_thenEngineBasedOnMode() { ) ); assertEquals( - KNNEngine.NMSLIB, + KNNEngine.FAISS, ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x1).build(), null, false) ); + assertEquals( + KNNEngine.NMSLIB, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x1).build(), + null, + false, + Version.V_2_18_0 + ) + ); } public void testResolveEngine_whenCompressionIs4x_thenEngineIsLucene() { diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 4e03f88a52..20bf7411f5 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -11,6 +11,7 @@ import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.BytesRef; import org.junit.Assert; import org.mockito.MockedStatic; @@ -176,7 +177,7 @@ public void testTypeParser_build_fromKnnMethodContext() throws IOException { Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof 
MethodFieldMapper); + assertTrue(knnVectorFieldMapper instanceof EngineFieldMapper); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent()); assertEquals(spaceType, knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType()); assertEquals( @@ -261,6 +262,95 @@ public void testKNNVectorFieldMapper_withBlockedKNNEngine() throws IOException { assertNotNull(builderWithFaiss); } + public void testKNNVectorFieldMapperLucene_docValueDefaults() throws IOException { + String fieldName = "test-field-name"; + String indexName = "test-index"; + + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); + ModelDao modelDao = mock(ModelDao.class); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); + + // Creating a mapping before version 3.0.0 (doc values should be true) + XContentBuilder legacyDocValuesContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, 128) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, KNNEngine.LUCENE.getName()) + .endObject() + .endObject(); + + // Should be true for versions before 3.0.0 + KNNVectorFieldMapper.Builder builderBeforeV3 = (KNNVectorFieldMapper.Builder) typeParser.parse( + fieldName, + xContentBuilderToMap(legacyDocValuesContentBuilder), + buildLegacyParserContext(indexName, settings, Version.V_2_19_0) // Version < 3.0.0 + ); + assertNotNull(builderBeforeV3); + assertTrue(builderBeforeV3.hasDocValues.getValue()); + + // Creating a mapping with Lucene on or after version 3.0.0 (doc values should default to false) + XContentBuilder currentDocValuesContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, 128) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + 
.field(KNN_ENGINE, KNNEngine.LUCENE.getName()) + .endObject() + .endObject(); + + KNNVectorFieldMapper.Builder builderAfterV3 = (KNNVectorFieldMapper.Builder) typeParser.parse( + fieldName, + xContentBuilderToMap(currentDocValuesContentBuilder), + buildParserContext(indexName, settings) // Version >= 3.0.0 + ); + assertNotNull(builderAfterV3); + assertFalse(builderAfterV3.hasDocValues.getValue()); + } + + public void testKNNVectorFieldMapperModel_docValueDefaults() throws IOException { + String fieldName = "test-field-name"; + String indexName = "test-index"; + String modelId = "test-model-id"; + + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); + ModelDao modelDao = mock(ModelDao.class); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); + + // Creating a model mapping before version 3.0.0 (doc values should be true) + XContentBuilder legacyDocValuesContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(MODEL_ID, modelId) + .endObject(); + + // Should be true for versions before 3.0.0 + KNNVectorFieldMapper.Builder builderBeforeV3 = (KNNVectorFieldMapper.Builder) typeParser.parse( + fieldName, + xContentBuilderToMap(legacyDocValuesContentBuilder), + buildLegacyParserContext(indexName, settings, Version.V_2_19_0) // Version < 3.0.0 + ); + assertNotNull(builderBeforeV3); + assertTrue(builderBeforeV3.hasDocValues.getValue()); + + // Creating a mapping with model on or after version 3.0.0 (doc values should default to false) + XContentBuilder currentDocValuesContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(MODEL_ID, modelId) + .endObject(); + + KNNVectorFieldMapper.Builder builderAfterV3 = (KNNVectorFieldMapper.Builder) typeParser.parse( + fieldName, + xContentBuilderToMap(currentDocValuesContentBuilder), + buildParserContext(indexName, 
settings) // Version >= 3.0.0 + ); + assertNotNull(builderAfterV3); + assertFalse(builderAfterV3.hasDocValues.getValue()); + } + public void testTypeParser_withDifferentSpaceTypeCombinations_thenSuccess() throws IOException { // Check that knnMethodContext takes precedent over both model and legacy ModelDao modelDao = mock(ModelDao.class); @@ -282,7 +372,7 @@ public void testTypeParser_withDifferentSpaceTypeCombinations_thenSuccess() thro Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); + assertTrue(knnVectorFieldMapper instanceof EngineFieldMapper); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent()); assertEquals(topLevelSpaceType, knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType()); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty()); @@ -298,7 +388,7 @@ public void testTypeParser_withDifferentSpaceTypeCombinations_thenSuccess() thro builderContext = new Mapper.BuilderContext(settings, new ContentPath()); knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); + assertTrue(knnVectorFieldMapper instanceof EngineFieldMapper); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent()); assertEquals(SpaceType.DEFAULT, knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType()); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty()); @@ -313,7 +403,7 @@ public void testTypeParser_withDifferentSpaceTypeCombinations_thenSuccess() thro builderContext = new Mapper.BuilderContext(settings, new ContentPath()); knnVectorFieldMapper = builder.build(builderContext); - 
assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); + assertTrue(knnVectorFieldMapper instanceof EngineFieldMapper); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent()); assertEquals(topLevelSpaceType, knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType()); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty()); @@ -337,7 +427,7 @@ public void testTypeParser_withDifferentSpaceTypeCombinations_thenSuccess() thro builderContext = new Mapper.BuilderContext(settings, new ContentPath()); knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); + assertTrue(knnVectorFieldMapper instanceof EngineFieldMapper); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent()); assertEquals( SpaceType.DEFAULT_BINARY, @@ -363,7 +453,7 @@ public void testTypeParser_withDifferentSpaceTypeCombinations_thenSuccess() thro builderContext = new Mapper.BuilderContext(settings, new ContentPath()); knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); + assertTrue(knnVectorFieldMapper instanceof EngineFieldMapper); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent()); assertEquals( topLevelSpaceType, @@ -408,7 +498,7 @@ public void testTypeParser_withSpaceTypeAndMode_thenSuccess() throws IOException Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); + assertTrue(knnVectorFieldMapper instanceof EngineFieldMapper); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent()); assertEquals(topLevelSpaceType, 
knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType()); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty()); @@ -477,7 +567,7 @@ public void testSpaceType_build_fromLegacy() throws IOException { // Setup settings Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); + assertTrue(knnVectorFieldMapper instanceof EngineFieldMapper); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent()); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty()); assertEquals(SpaceType.L2, knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType()); @@ -507,7 +597,7 @@ public void testBuilder_build_fromLegacy() throws IOException { // Setup settings Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); + assertTrue(knnVectorFieldMapper instanceof EngineFieldMapper); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent()); assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty()); assertEquals(SpaceType.L2, knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType()); @@ -1188,6 +1278,89 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { expectThrows(IllegalArgumentException.class, () -> knnVectorFieldMapper1.merge(knnVectorFieldMapper3)); } + @SneakyThrows + public void testMethodFieldMapper_saveBestMatchingVectorSimilarityFunction() { + final MethodComponentContext methodComponentContext = new 
MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + + doTestMethodFieldMapper_saveBestMatchingVectorSimilarityFunction(methodComponentContext, SpaceType.INNER_PRODUCT, Version.V_2_19_0); + + doTestMethodFieldMapper_saveBestMatchingVectorSimilarityFunction(methodComponentContext, SpaceType.HAMMING, Version.V_2_19_0); + + doTestMethodFieldMapper_saveBestMatchingVectorSimilarityFunction(methodComponentContext, SpaceType.INNER_PRODUCT, Version.V_3_0_0); + + doTestMethodFieldMapper_saveBestMatchingVectorSimilarityFunction(methodComponentContext, SpaceType.HAMMING, Version.V_3_0_0); + } + + @SneakyThrows + private void doTestMethodFieldMapper_saveBestMatchingVectorSimilarityFunction( + MethodComponentContext methodComponentContext, + SpaceType spaceType, + final Version version + ) { + try (MockedStatic utilMockedStatic = Mockito.mockStatic(KNNVectorFieldMapperUtil.class)) { + final VectorDataType dataType = VectorDataType.FLOAT; + final int dimension = TEST_VECTOR.length; + + final KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(dataType) + .versionCreated(version) + .dimension(dimension) + .build(); + final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, spaceType, methodComponentContext); + + final IndexSettings indexSettingsMock = mock(IndexSettings.class); + when(indexSettingsMock.getSettings()).thenReturn(Settings.EMPTY); + final ParseContext.Document document = new ParseContext.Document(); + final ContentPath contentPath = new ContentPath(); + final ParseContext parseContext = mock(ParseContext.class); + when(parseContext.doc()).thenReturn(document); + when(parseContext.path()).thenReturn(contentPath); + when(parseContext.parser()).thenReturn(createXContentParser(dataType)); + when(parseContext.indexSettings()).thenReturn(indexSettingsMock); + + utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(true); + 
utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useFullFieldNameValidation(Mockito.any())).thenReturn(true); + + final OriginalMappingParameters originalMappingParameters = new OriginalMappingParameters( + dataType, + dimension, + knnMethodContext, + Mode.NOT_CONFIGURED.getName(), + CompressionLevel.NOT_CONFIGURED.getName(), + null, + SpaceType.UNDEFINED.getValue() + ); + originalMappingParameters.setResolvedKnnMethodContext(knnMethodContext); + EngineFieldMapper methodFieldMapper = EngineFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + TEST_FIELD_NAME, + Collections.emptyMap(), + knnMethodConfigContext, + FieldMapper.MultiFields.empty(), + FieldMapper.CopyTo.empty(), + new Explicit<>(true, true), + false, + false, + originalMappingParameters + ); + methodFieldMapper.parseCreateField(parseContext, dimension, dataType); + final IndexableField field1 = document.getFields().get(0); + + VectorSimilarityFunction similarityFunction = SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction(); + + if (version.onOrAfter(Version.V_3_0_0)) { + // If version >= 3.0, then it should find the best matching function. 
+ try { + similarityFunction = spaceType.getKnnVectorSimilarityFunction().getVectorSimilarityFunction(); + } catch (Exception e) { + // Ignore + } + } + + assertEquals(similarityFunction, field1.fieldType().vectorSimilarityFunction()); + } + } + @SneakyThrows public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldTypes() { try (MockedStatic utilMockedStatic = Mockito.mockStatic(KNNVectorFieldMapperUtil.class)) { @@ -1226,7 +1399,7 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT SpaceType.UNDEFINED.getValue() ); originalMappingParameters.setResolvedKnnMethodContext(knnMethodContext); - MethodFieldMapper methodFieldMapper = MethodFieldMapper.createFieldMapper( + EngineFieldMapper methodFieldMapper = EngineFieldMapper.createFieldMapper( TEST_FIELD_NAME, TEST_FIELD_NAME, Collections.emptyMap(), @@ -1253,10 +1426,10 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT assertEquals(field1.fieldType().vectorDimension(), adjustDimensionForSearch(dimension, dataType)); assertEquals(Integer.parseInt(field1.fieldType().getAttributes().get(DIMENSION_FIELD_NAME)), dimension); - assertEquals( - field1.fieldType().vectorSimilarityFunction(), - SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() - ); + final VectorSimilarityFunction similarityFunction = spaceType != SpaceType.HAMMING + ? 
spaceType.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() + : SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction(); + assertEquals(field1.fieldType().vectorSimilarityFunction(), similarityFunction); utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(false); utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useFullFieldNameValidation(Mockito.any())).thenReturn(false); @@ -1267,7 +1440,7 @@ public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldT when(parseContext.path()).thenReturn(contentPath); when(parseContext.parser()).thenReturn(createXContentParser(dataType)); when(parseContext.indexSettings()).thenReturn(indexSettingsMock); - methodFieldMapper = MethodFieldMapper.createFieldMapper( + methodFieldMapper = EngineFieldMapper.createFieldMapper( TEST_FIELD_NAME, TEST_FIELD_NAME, Collections.emptyMap(), @@ -1412,9 +1585,6 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy @SneakyThrows public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { - // Create a lucene field mapper that creates a binary doc values field as well as KnnVectorField - LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder inputBuilder = - createLuceneFieldMapperInputBuilder(); IndexSettings indexSettingsMock = mock(IndexSettings.class); when(indexSettingsMock.getSettings()).thenReturn(Settings.EMPTY); ParseContext.Document document = new ParseContext.Document(); @@ -1430,10 +1600,16 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { .dimension(TEST_DIMENSION) .build(); + KNNMethodContext luceneMethodContext = new KNNMethodContext( + KNNEngine.LUCENE, + SpaceType.DEFAULT, + new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) + ); + OriginalMappingParameters originalMappingParameters = new OriginalMappingParameters( VectorDataType.FLOAT, TEST_DIMENSION, - 
getDefaultKNNMethodContext(), + luceneMethodContext, Mode.NOT_CONFIGURED.getName(), CompressionLevel.NOT_CONFIGURED.getName(), null, @@ -1441,11 +1617,16 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { ); originalMappingParameters.setResolvedKnnMethodContext(originalMappingParameters.getKnnMethodContext()); - LuceneFieldMapper luceneFieldMapper = LuceneFieldMapper.createFieldMapper( + EngineFieldMapper luceneFieldMapper = EngineFieldMapper.createFieldMapper( + TEST_FIELD_NAME, TEST_FIELD_NAME, Collections.emptyMap(), knnMethodConfigContext, - inputBuilder.build(), + FieldMapper.MultiFields.empty(), + FieldMapper.CopyTo.empty(), + new Explicit<>(true, true), + false, + true, originalMappingParameters ); luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); @@ -1482,30 +1663,31 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { when(parseContext.parser()).thenReturn(createXContentParser(VectorDataType.FLOAT)); when(parseContext.indexSettings()).thenReturn(indexSettingsMock); - inputBuilder.hasDocValues(false); - knnMethodConfigContext = KNNMethodConfigContext.builder() .vectorDataType(VectorDataType.FLOAT) .versionCreated(CURRENT) .dimension(TEST_DIMENSION) .build(); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); - KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.DEFAULT, methodComponentContext); originalMappingParameters = new OriginalMappingParameters( VectorDataType.FLOAT, TEST_DIMENSION, - knnMethodContext, + luceneMethodContext, Mode.NOT_CONFIGURED.getName(), CompressionLevel.NOT_CONFIGURED.getName(), null, SpaceType.UNDEFINED.getValue() ); originalMappingParameters.setResolvedKnnMethodContext(originalMappingParameters.getKnnMethodContext()); - luceneFieldMapper = LuceneFieldMapper.createFieldMapper( + luceneFieldMapper = EngineFieldMapper.createFieldMapper( + 
TEST_FIELD_NAME, TEST_FIELD_NAME, Collections.emptyMap(), knnMethodConfigContext, - inputBuilder.build(), + FieldMapper.MultiFields.empty(), + FieldMapper.CopyTo.empty(), + new Explicit<>(true, true), + false, + false, originalMappingParameters ); luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); @@ -1521,10 +1703,6 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { @SneakyThrows public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { - // Create a lucene field mapper that creates a binary doc values field as well as KnnByteVectorField - - LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder inputBuilder = - createLuceneFieldMapperInputBuilder(); IndexSettings indexSettingsMock = mock(IndexSettings.class); when(indexSettingsMock.getSettings()).thenReturn(Settings.EMPTY); ParseContext.Document document = new ParseContext.Document(); @@ -1534,10 +1712,13 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { when(parseContext.path()).thenReturn(contentPath); when(parseContext.indexSettings()).thenReturn(indexSettingsMock); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + KNNMethodContext luceneByteKnnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.DEFAULT, methodComponentContext); + OriginalMappingParameters originalMappingParameters = new OriginalMappingParameters( VectorDataType.BYTE, TEST_DIMENSION, - getDefaultByteKNNMethodContext(), + luceneByteKnnMethodContext, Mode.NOT_CONFIGURED.getName(), CompressionLevel.NOT_CONFIGURED.getName(), null, @@ -1545,8 +1726,9 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { ); originalMappingParameters.setResolvedKnnMethodContext(originalMappingParameters.getKnnMethodContext()); - LuceneFieldMapper luceneFieldMapper = Mockito.spy( - LuceneFieldMapper.createFieldMapper( + 
EngineFieldMapper luceneFieldMapper = Mockito.spy( + EngineFieldMapper.createFieldMapper( + TEST_FIELD_NAME, TEST_FIELD_NAME, Collections.emptyMap(), KNNMethodConfigContext.builder() @@ -1554,7 +1736,11 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { .versionCreated(CURRENT) .dimension(TEST_DIMENSION) .build(), - inputBuilder.build(), + FieldMapper.MultiFields.empty(), + FieldMapper.CopyTo.empty(), + new Explicit<>(true, true), + false, + true, originalMappingParameters ) ); @@ -1594,10 +1780,9 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { when(parseContext.path()).thenReturn(contentPath); when(parseContext.indexSettings()).thenReturn(indexSettingsMock); - inputBuilder.hasDocValues(false); - luceneFieldMapper = Mockito.spy( - LuceneFieldMapper.createFieldMapper( + EngineFieldMapper.createFieldMapper( + TEST_FIELD_NAME, TEST_FIELD_NAME, Collections.emptyMap(), KNNMethodConfigContext.builder() @@ -1605,7 +1790,11 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { .versionCreated(CURRENT) .dimension(TEST_DIMENSION) .build(), - inputBuilder.build(), + FieldMapper.MultiFields.empty(), + FieldMapper.CopyTo.empty(), + new Explicit<>(true, true), + false, + false, originalMappingParameters ) ); @@ -1819,7 +2008,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx assertFalse(builder.getOriginalParameters().isLegacyMapping()); validateBuilderAfterParsing( builder, - KNNEngine.NMSLIB, + KNNEngine.FAISS, SpaceType.L2, VectorDataType.FLOAT, CompressionLevel.x1, @@ -2213,16 +2402,6 @@ private void validateBuilderAfterParsing( } } - private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder createLuceneFieldMapperInputBuilder() { - return LuceneFieldMapper.CreateLuceneFieldMapperInput.builder() - .name(TEST_FIELD_NAME) - .multiFields(FieldMapper.MultiFields.empty()) - .copyTo(FieldMapper.CopyTo.empty()) - .hasDocValues(true) - 
.ignoreMalformed(new Explicit<>(true, true)) - .originalKnnMethodContext(getDefaultKNNMethodContext()); - } - private XContentBuilder createXContentForFieldMapping( SpaceType topLevelSpaceType, SpaceType methodSpaceType, diff --git a/src/test/java/org/opensearch/knn/index/query/ExplainTests.java b/src/test/java/org/opensearch/knn/index/query/ExplainTests.java new file mode 100644 index 0000000000..3fa097e02e --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/ExplainTests.java @@ -0,0 +1,849 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import com.google.common.collect.Comparators; +import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.util.KNNVectorAsCollectionOfFloatsSerializer; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.rescore.RescoreContext; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.indices.ModelDao; +import 
org.opensearch.knn.jni.JNIService; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; +import static org.opensearch.knn.common.KNNConstants.ANN_SEARCH; +import static org.opensearch.knn.common.KNNConstants.EXACT_SEARCH; +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.RADIAL_SEARCH; +import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; + +public class ExplainTests extends KNNWeightTestCase { + + @Mock + private Weight filterQueryWeight; + @Mock + private LeafReaderContext leafReaderContext; + + private void setupTest(final int[] filterDocIds, final Map attributesMap) throws IOException { + setupTest(filterDocIds, attributesMap, filterDocIds != null ? 
filterDocIds.length : 0, SpaceType.L2, true, null, null, null); + } + + private void setupTest( + final int[] filterDocIds, + final Map attributesMap, + final int maxDoc, + final SpaceType spaceType, + final boolean isCompoundFile, + final byte[] byteVector, + final float[] floatVector, + final MockedStatic vectorValuesFactoryMockedStatic + ) throws IOException { + + final Scorer filterScorer = mock(Scorer.class); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + + Bits liveDocsBits = null; + if (filterDocIds != null) { + FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length); + for (int docId : filterDocIds) { + filterBitSet.set(docId); + } + liveDocsBits = mock(Bits.class); + for (int filterDocId : filterDocIds) { + when(liveDocsBits.get(filterDocId)).thenReturn(true); + } + when(liveDocsBits.length()).thenReturn(1000); + + when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); + when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1)); + } + final SegmentReader reader = mockSegmentReader(isCompoundFile); + when(reader.maxDoc()).thenReturn(maxDoc); + when(reader.getLiveDocs()).thenReturn(liveDocsBits); + + when(leafReaderContext.reader()).thenReturn(reader); + when(leafReaderContext.id()).thenReturn(new Object()); + + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(spaceType.getValue()); + when(fieldInfo.getName()).thenReturn(FIELD_NAME); + + if (floatVector != null) { + final BinaryDocValues binaryDocValues = mock(BinaryDocValues.class); + when(reader.getBinaryDocValues(FIELD_NAME)).thenReturn(binaryDocValues); + when(binaryDocValues.advance(0)).thenReturn(0); + BytesRef vectorByteRef = new 
BytesRef(KNNVectorAsCollectionOfFloatsSerializer.INSTANCE.floatToByteArray(floatVector)); + when(binaryDocValues.binaryValue()).thenReturn(vectorByteRef); + } + + if (byteVector != null) { + final KNNBinaryVectorValues knnBinaryVectorValues = mock(KNNBinaryVectorValues.class); + vectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, reader)) + .thenReturn(knnBinaryVectorValues); + when(knnBinaryVectorValues.advance(0)).thenReturn(0); + when(knnBinaryVectorValues.getVector()).thenReturn(byteVector); + } + } + + private void assertExplanation(Explanation explanation, float expectedScore, String topSearch, String... leafDescription) { + assertNotNull(explanation); + assertTrue(explanation.isMatch()); + assertEquals(expectedScore, explanation.getValue().floatValue(), 0.01f); + assertTrue(explanation.getDescription().contains(topSearch)); + assertEquals(1, explanation.getDetails().length); + Explanation explanationDetail = explanation.getDetails()[0]; + assertEquals(expectedScore, explanationDetail.getValue().floatValue(), 0.01f); + for (String description : leafDescription) { + assertTrue(explanationDetail.getDescription().contains(description)); + } + } + + private void assertDiskSearchExplanation(Explanation explanation, String[] topSearchDesc, String...
leafDescription) { + assertNotNull(explanation); + assertTrue(explanation.isMatch()); + for (String description : topSearchDesc) { + assertTrue(explanation.getDescription().contains(description)); + } + assertEquals(1, explanation.getDetails().length); + Explanation explanationDetail = explanation.getDetails()[0]; + for (String description : leafDescription) { + assertTrue(explanationDetail.getDescription().contains(description)); + } + } + + @SneakyThrows + public void testDiskBasedSearchWithShardRescoringEnabledANN() { + int k = 3; + knnSettingsMockedStatic.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME)).thenReturn(false); + + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), eq(null), anyInt(), any()) + ).thenReturn(getFilteredKNNQueryResults()); + + RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(RescoreContext.MIN_OVERSAMPLE_FACTOR - 1).build(); + + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.getValue() + ); + + setupTest(filterDocIds, attributesMap); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .rescoreContext(rescoreContext) + .explain(true) + .build(); + query.setExplain(true); + + final float boost = 1; + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + + // When + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + // Then + assertNotNull(knnScorer); + knnWeight.getKnnExplanation().addKnnScorer(leafReaderContext, knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + 
assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()), + times(1) + ); + + final List actualDocIds = new ArrayList<>(); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = translatedScores.get(docId) * boost; + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score); + String[] expectedTopDescription = new String[] { + KNNConstants.DISK_BASED_SEARCH, + "the first pass k was " + rescoreContext.getFirstPassK(k, false, QUERY_VECTOR.length), + "over sampling factor of " + rescoreContext.getOversampleFactor(), + "with vector dimension of " + QUERY_VECTOR.length, + "shard level rescoring enabled" }; + assertDiskSearchExplanation( + explanation, + expectedTopDescription, + ANN_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue() + ); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testDiskBasedSearchWithShardRescoringDisabledExact() { + knnSettingsMockedStatic.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME)).thenReturn(true); + RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(RescoreContext.MAX_OVERSAMPLE_FACTOR - 1).build(); + + ExactSearcher mockedExactSearcher = mock(ExactSearcher.class); + KNNWeight.initialize(null, mockedExactSearcher); + + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT); + + Map attributesMap = Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), 
+ PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + setupTest(null, attributesMap, 1, spaceType, false, null, null, null); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .indexName(INDEX_NAME) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .rescoreContext(rescoreContext) + .explain(true) + .build(); + final KNNWeight knnWeight = new KNNWeight(query, 1.0f); + + final ExactSearcher.ExactSearcherContext exactSearchContext = ExactSearcher.ExactSearcherContext.builder() + .isParentHits(true) + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(true) + .knnQuery(query) + .build(); + when(mockedExactSearcher.searchLeaf(leafReaderContext, exactSearchContext)).thenReturn(DOC_ID_TO_SCORES); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + knnWeight.getKnnExplanation().addKnnScorer(leafReaderContext, knnScorer); + + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = DOC_ID_TO_SCORES.get(docId); + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score); + String[] expectedTopDescription = new String[] { + KNNConstants.DISK_BASED_SEARCH, + "the first pass k was " + rescoreContext.getFirstPassK(0, true, queryVector.length), + "over sampling factor of " + rescoreContext.getOversampleFactor(), + "with vector dimension of " + queryVector.length, + "shard level rescoring disabled" }; + assertDiskSearchExplanation( + explanation, + expectedTopDescription, + EXACT_SEARCH, + VectorDataType.FLOAT.name(), + 
spaceType.getValue(), + "no native engine files" + ); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + // verify JNI Service is not called + jniServiceMockedStatic.verifyNoInteractions(); + verify(mockedExactSearcher).searchLeaf(leafReaderContext, exactSearchContext); + } + + @SneakyThrows + public void testDefaultANNSearch() { + // Given + int k = 3; + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), eq(null), anyInt(), any()) + ).thenReturn(getFilteredKNNQueryResults()); + + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.getValue() + ); + + setupTest(filterDocIds, attributesMap); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .explain(true) + .build(); + query.setExplain(true); + + final float boost = 1; + + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + + // When + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + // Then + assertNotNull(knnScorer); + knnWeight.getKnnExplanation().addKnnScorer(leafReaderContext, knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()), + times(1) + ); + + final List actualDocIds = new ArrayList<>(); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + for 
(int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = translatedScores.get(docId) * boost; + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score); + assertExplanation( + explanation, + score, + ANN_SEARCH, + ANN_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + SpaceType.L2.explainScoreTranslation(DOC_ID_TO_SCORES.get(docId)) + ); + Explanation nestedDetail = explanation.getDetails()[0].getDetails()[0]; + assertTrue(nestedDetail.getDescription().contains(KNNEngine.FAISS.name())); + assertEquals(DOC_ID_TO_SCORES.get(docId), nestedDetail.getValue().floatValue(), 0.01f); + assertEquals(score, knnScorer.score(), 0.01f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testANN_FilteredExactSearchAfterANN() { + ExactSearcher mockedExactSearcher = mock(ExactSearcher.class); + KNNWeight.initialize(null, mockedExactSearcher); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + when(mockedExactSearcher.searchLeaf(any(), any())).thenReturn(translatedScores); + // Given + int k = 4; + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), eq(null), anyInt(), any()) + ).thenReturn(getFilteredKNNQueryResults()); + + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.getValue() + ); + + setupTest(filterDocIds, attributesMap); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .explain(true) + .build(); 
+ query.setExplain(true); + + final float boost = 1; + KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + + // When + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + // Then + assertNotNull(knnScorer); + knnWeight.getKnnExplanation().addKnnScorer(leafReaderContext, knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()), + times(1) + ); + + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = translatedScores.get(docId) * boost; + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score); + assertExplanation( + explanation, + score, + ANN_SEARCH, + EXACT_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + "since the number of documents returned are less than K", + "there are more than K filtered Ids" + ); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testANN_whenNoEngineFiles_thenPerformExactSearch() { + ExactSearcher mockedExactSearcher = mock(ExactSearcher.class); + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT); + KNNWeight.initialize(null, mockedExactSearcher); + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .indexName(INDEX_NAME) + .methodParameters(HNSW_METHOD_PARAMETERS) + .vectorDataType(VectorDataType.FLOAT) + .explain(true) + .build(); + final KNNWeight knnWeight = new 
KNNWeight(query, 1.0f); + + Map attributesMap = Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + setupTest(null, attributesMap, 1, spaceType, false, null, null, null); + + final ExactSearcher.ExactSearcherContext exactSearchContext = ExactSearcher.ExactSearcherContext.builder() + .isParentHits(true) + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(true) + .knnQuery(query) + .build(); + when(mockedExactSearcher.searchLeaf(leafReaderContext, exactSearchContext)).thenReturn(DOC_ID_TO_SCORES); + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + knnWeight.getKnnExplanation().addKnnScorer(leafReaderContext, knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = DOC_ID_TO_SCORES.get(docId); + assertEquals(score, knnScorer.score(), 0.00000001f); + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score); + assertExplanation( + explanation, + score, + ANN_SEARCH, + EXACT_SEARCH, + VectorDataType.FLOAT.name(), + spaceType.getValue(), + "no native engine files" + ); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + // verify JNI Service is not called + jniServiceMockedStatic.verifyNoInteractions(); + verify(mockedExactSearcher).searchLeaf(leafReaderContext, exactSearchContext); + } + + @SneakyThrows + public void testANNWithFilterQuery_whenFTVGreaterThanFilterId() { + + KNNWeight.initialize(null); + 
knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(10); + byte[] vector = new byte[] { 1, 3 }; + int k = 1; + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.HAMMING.name(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "BHNSW32") + ); + + try (MockedStatic vectorValuesFactoryMockedStatic = Mockito.mockStatic(KNNVectorValuesFactory.class)) { + setupTest(filterDocIds, attributesMap, 100, SpaceType.HAMMING, true, vector, null, vectorValuesFactoryMockedStatic); + final KNNQuery query = new KNNQuery( + FIELD_NAME, + BYTE_QUERY_VECTOR, + k, + INDEX_NAME, + FILTER_QUERY, + null, + VectorDataType.BINARY, + null + ); + + query.setExplain(true); + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + knnWeight.getKnnExplanation().addKnnScorer(leafReaderContext, knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = BINARY_EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost; + assertEquals(score, knnScorer.score(), 0.01f); + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score); + assertExplanation( + explanation, + score, + ANN_SEARCH, + EXACT_SEARCH, + VectorDataType.BINARY.name(), + SpaceType.HAMMING.getValue(), + "is greater than or equal to cardinality", + "since filtered threshold value" + ); + } + 
assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + } + + @SneakyThrows + public void testANNWithFilterQuery_whenMDCGreaterThanFilterId() { + ModelDao modelDao = mock(ModelDao.class); + KNNWeight.initialize(modelDao); + knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(-1); + float[] vector = new float[] { 0.1f, 0.3f }; + int k = 1; + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.name(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + setupTest(filterDocIds, attributesMap, 100, SpaceType.L2, true, null, vector, null); + + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null, null); + query.setExplain(true); + + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + knnWeight.getKnnExplanation().addKnnScorer(leafReaderContext, knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost; + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f); + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score); + assertExplanation( + explanation, + score, + ANN_SEARCH, + 
EXACT_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + "since max distance computation", + "is greater than or equal to cardinality" + ); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testANNWithFilterQuery_whenFilterIdLessThanK() { + ModelDao modelDao = mock(ModelDao.class); + KNNWeight.initialize(modelDao); + knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(-1); + float[] vector = new float[] { 0.1f, 0.3f }; + final int[] filterDocIds = new int[] { 0 }; + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.name(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + setupTest(filterDocIds, attributesMap, 100, SpaceType.L2, true, null, vector, null); + + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null, null); + query.setExplain(true); + + final float boost = 1; + final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + knnWeight.getKnnExplanation().addKnnScorer(leafReaderContext, knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost; + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f); + Explanation explanation = knnWeight.explain(leafReaderContext, 
docId, score); + assertExplanation( + explanation, + score, + ANN_SEARCH, + EXACT_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + "since filteredIds", + "is less than or equal to K" + ); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testRadialANNSearch() { + final float[] queryVector = new float[] { 0.1f, 0.3f }; + final float radius = 0.5f; + final int maxResults = 1000; + jniServiceMockedStatic.when( + () -> JNIService.radiusQueryIndex( + anyLong(), + eq(queryVector), + eq(radius), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(maxResults), + any(), + anyInt(), + any() + ) + ).thenReturn(getKNNQueryResults()); + + Map attributesMap = Map.of( + SPACE_TYPE, + SpaceType.L2.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + setupTest(null, attributesMap); + + KNNQuery.Context context = mock(KNNQuery.Context.class); + when(context.getMaxResultWindow()).thenReturn(maxResults); + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .radius(radius) + .indexName(INDEX_NAME) + .context(context) + .explain(true) + .vectorDataType(VectorDataType.FLOAT) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + final float boost = 1; + final KNNWeight knnWeight = new KNNWeight(query, boost); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + knnWeight.getKnnExplanation().addKnnScorer(leafReaderContext, knnScorer); + jniServiceMockedStatic.verify( + () -> JNIService.radiusQueryIndex( + anyLong(), + eq(queryVector), + eq(radius), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(maxResults), + any(), + anyInt(), + any() + ) + ); + + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + + final List actualDocIds = 
new ArrayList<>(); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = translatedScores.get(docId) * boost; + assertEquals(score, knnScorer.score(), 0.01f); + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score); + assertExplanation( + explanation, + score, + RADIAL_SEARCH, + ANN_SEARCH, + VectorDataType.FLOAT.name(), + SpaceType.L2.getValue(), + SpaceType.L2.explainScoreTranslation(DOC_ID_TO_SCORES.get(docId)) + ); + Explanation nestedDetail = explanation.getDetails()[0].getDetails()[0]; + assertTrue(nestedDetail.getDescription().contains(KNNEngine.FAISS.name())); + assertEquals(DOC_ID_TO_SCORES.get(docId), nestedDetail.getValue().floatValue(), 0.01f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + + @SneakyThrows + public void testRadialExactSearch() { + ExactSearcher mockedExactSearcher = mock(ExactSearcher.class); + final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT); + KNNWeight.initialize(null, mockedExactSearcher); + + final float[] queryVector = new float[] { 0.1f, 0.3f }; + final float radius = 0.5f; + final int maxResults = 1000; + jniServiceMockedStatic.when( + () -> JNIService.radiusQueryIndex( + anyLong(), + eq(queryVector), + eq(radius), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(maxResults), + any(), + anyInt(), + any() + ) + ).thenReturn(getKNNQueryResults()); + + Map attributesMap = Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); + + setupTest(null, attributesMap, 0, spaceType, false, null, null, null); + + KNNQuery.Context context = mock(KNNQuery.Context.class); + 
when(context.getMaxResultWindow()).thenReturn(maxResults); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .radius(radius) + .indexName(INDEX_NAME) + .context(context) + .explain(true) + .vectorDataType(VectorDataType.FLOAT) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + final float boost = 1; + final KNNWeight knnWeight = new KNNWeight(query, boost); + final ExactSearcher.ExactSearcherContext exactSearchContext = ExactSearcher.ExactSearcherContext.builder() + .isParentHits(true) + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(true) + .knnQuery(query) + .build(); + when(mockedExactSearcher.searchLeaf(leafReaderContext, exactSearchContext)).thenReturn(DOC_ID_TO_SCORES); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + knnWeight.getKnnExplanation().addKnnScorer(leafReaderContext, knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + float score = DOC_ID_TO_SCORES.get(docId) * boost; + assertEquals(score, knnScorer.score(), 0.01f); + Explanation explanation = knnWeight.explain(leafReaderContext, docId, score); + assertExplanation(explanation, score, RADIAL_SEARCH, EXACT_SEARCH, VectorDataType.FLOAT.name(), spaceType.getValue()); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + // verify JNI Service is not called + jniServiceMockedStatic.verifyNoInteractions(); + verify(mockedExactSearcher).searchLeaf(leafReaderContext, exactSearchContext); + } +} diff --git 
a/src/test/java/org/opensearch/knn/index/query/KNNWeightTestCase.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTestCase.java new file mode 100644 index 0000000000..6831c4c7fb --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTestCase.java @@ -0,0 +1,179 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import org.apache.lucene.index.SegmentCommitInfo; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.util.StringHelper; +import org.apache.lucene.util.Version; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.mockito.MockedStatic; +import org.opensearch.common.io.PathUtils; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.KNNCodecVersion; +import org.opensearch.knn.index.memory.NativeMemoryAllocation; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; +import org.opensearch.knn.jni.JNIService; + +import java.nio.file.Path; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; +import 
static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES; +import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_SIZE_LIMIT; + +public class KNNWeightTestCase extends KNNTestCase { + + protected static final String FIELD_NAME = "target_field"; + protected static final float[] QUERY_VECTOR = new float[] { 1.8f, 2.4f }; + protected static final byte[] BYTE_QUERY_VECTOR = new byte[] { 1, 2 }; + protected static final String SEGMENT_NAME = "0"; + protected static final int K = 5; + protected static final Set SEGMENT_FILES_NMSLIB = Set.of("_0.cfe", "_0_2011_target_field.hnswc"); + protected static final Set SEGMENT_FILES_FAISS = Set.of("_0.cfe", "_0_2011_target_field.faissc"); + protected static final Set SEGMENT_FILES_DEFAULT = SEGMENT_FILES_FAISS; + protected static final Set SEGMENT_MULTI_FIELD_FILES_FAISS = Set.of( + "_0.cfe", + "_0_2011_target_field.faissc", + "_0_2011_long_target_field.faissc" + ); + protected static final String CIRCUIT_BREAKER_LIMIT_100KB = "100Kb"; + protected static final Integer EF_SEARCH = 10; + protected static final Map HNSW_METHOD_PARAMETERS = Map.of(METHOD_PARAMETER_EF_SEARCH, EF_SEARCH); + protected static final Map DOC_ID_TO_SCORES = Map.of(10, 0.4f, 101, 0.05f, 100, 0.8f, 50, 0.52f); + protected static final Map FILTERED_DOC_ID_TO_SCORES = Map.of(101, 0.05f, 100, 0.8f, 50, 0.52f); + protected static final Map EXACT_SEARCH_DOC_ID_TO_SCORES = Map.of(0, 0.12048191f); + protected static final Map BINARY_EXACT_SEARCH_DOC_ID_TO_SCORES = Map.of(0, 0.5f); + protected static final Query FILTER_QUERY = new TermQuery(new Term("foo", "fooValue")); + protected static MockedStatic nativeMemoryCacheManagerMockedStatic; + protected static MockedStatic jniServiceMockedStatic; + + protected static MockedStatic knnSettingsMockedStatic; + + @BeforeClass + public static void setUpClass() throws Exception { + final 
KNNSettings knnSettings = mock(KNNSettings.class); + knnSettingsMockedStatic = mockStatic(KNNSettings.class); + when(knnSettings.getSettingValue(eq(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED))).thenReturn(true); + when(knnSettings.getSettingValue(eq(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_CLUSTER_LIMIT))).thenReturn(CIRCUIT_BREAKER_LIMIT_100KB); + when(knnSettings.getSettingValue(eq(KNNSettings.KNN_CACHE_ITEM_EXPIRY_ENABLED))).thenReturn(false); + when(knnSettings.getSettingValue(eq(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES))).thenReturn(TimeValue.timeValueMinutes(10)); + + final ByteSizeValue v = ByteSizeValue.parseBytesSizeValue( + CIRCUIT_BREAKER_LIMIT_100KB, + KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_CLUSTER_LIMIT + ); + knnSettingsMockedStatic.when(KNNSettings::getClusterCbLimit).thenReturn(v); + knnSettingsMockedStatic.when(KNNSettings::state).thenReturn(knnSettings); + ByteSizeValue cacheSize = ByteSizeValue.parseBytesSizeValue("1024kb", QUANTIZATION_STATE_CACHE_SIZE_LIMIT); // Setting 1MB as an + // example + when(knnSettings.getSettingValue(eq(QUANTIZATION_STATE_CACHE_SIZE_LIMIT))).thenReturn(cacheSize); + // Mock QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES setting + TimeValue mockTimeValue = TimeValue.timeValueMinutes(10); + when(knnSettings.getSettingValue(eq(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES))).thenReturn(mockTimeValue); + + nativeMemoryCacheManagerMockedStatic = mockStatic(NativeMemoryCacheManager.class); + + final NativeMemoryCacheManager nativeMemoryCacheManager = mock(NativeMemoryCacheManager.class); + final NativeMemoryAllocation nativeMemoryAllocation = mock(NativeMemoryAllocation.class); + when(nativeMemoryCacheManager.get(any(), anyBoolean())).thenReturn(nativeMemoryAllocation); + + nativeMemoryCacheManagerMockedStatic.when(NativeMemoryCacheManager::getInstance).thenReturn(nativeMemoryCacheManager); + + final MockedStatic pathUtilsMockedStatic = mockStatic(PathUtils.class); + final Path indexPath = mock(Path.class); + 
when(indexPath.toString()).thenReturn("/mydrive/myfolder"); + pathUtilsMockedStatic.when(() -> PathUtils.get(anyString(), anyString())).thenReturn(indexPath); + } + + @Before + public void setupBeforeTest() { + knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(0); + jniServiceMockedStatic = mockStatic(JNIService.class); + } + + @After + public void tearDownAfterTest() { + jniServiceMockedStatic.close(); + } + + protected Map getTranslatedScores(Function scoreTranslator) { + return DOC_ID_TO_SCORES.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreTranslator.apply(entry.getValue()))); + } + + protected KNNQueryResult[] getKNNQueryResults() { + return DOC_ID_TO_SCORES.entrySet() + .stream() + .map(entry -> new KNNQueryResult(entry.getKey(), entry.getValue())) + .collect(Collectors.toList()) + .toArray(new KNNQueryResult[0]); + } + + protected KNNQueryResult[] getFilteredKNNQueryResults() { + return FILTERED_DOC_ID_TO_SCORES.entrySet() + .stream() + .map(entry -> new KNNQueryResult(entry.getKey(), entry.getValue())) + .collect(Collectors.toList()) + .toArray(new KNNQueryResult[0]); + } + + protected SegmentReader mockSegmentReader() { + return mockSegmentReader(true); + } + + protected SegmentReader mockSegmentReader(boolean isCompoundFile) { + Path path = mock(Path.class); + + FSDirectory directory = mock(FSDirectory.class); + when(directory.getDirectory()).thenReturn(path); + + SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + isCompoundFile, + false, + KNNCodecVersion.CURRENT_DEFAULT, + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(SEGMENT_FILES_FAISS); + SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + + SegmentReader reader = mock(SegmentReader.class); + 
when(reader.directory()).thenReturn(directory); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + return reader; + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index a8355e29b0..49da9b066b 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -15,12 +15,9 @@ import org.apache.lucene.index.SegmentCommitInfo; import org.apache.lucene.index.SegmentInfo; import org.apache.lucene.index.SegmentReader; -import org.apache.lucene.index.Term; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.Query; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Sort; -import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.store.FSDirectory; @@ -29,16 +26,9 @@ import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.StringHelper; import org.apache.lucene.util.Version; -import org.junit.After; -import org.junit.Before; -import org.junit.BeforeClass; import org.mockito.MockedConstruction; import org.mockito.MockedStatic; import org.mockito.Mockito; -import org.opensearch.common.io.PathUtils; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.core.common.unit.ByteSizeValue; -import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector; import org.opensearch.knn.index.codec.KNNCodecVersion; @@ -47,8 +37,6 @@ import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.memory.NativeMemoryAllocation; -import org.opensearch.knn.index.memory.NativeMemoryCacheManager; 
import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; @@ -80,7 +68,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyFloat; -import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; @@ -95,90 +82,11 @@ import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; -import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES; -import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_SIZE_LIMIT; - -public class KNNWeightTests extends KNNTestCase { - private static final String FIELD_NAME = "target_field"; - private static final float[] QUERY_VECTOR = new float[] { 1.8f, 2.4f }; - private static final byte[] BYTE_QUERY_VECTOR = new byte[] { 1, 2 }; - private static final String SEGMENT_NAME = "0"; - private static final int K = 5; - private static final Set SEGMENT_FILES_NMSLIB = Set.of("_0.cfe", "_0_2011_target_field.hnswc"); - private static final Set SEGMENT_FILES_FAISS = Set.of("_0.cfe", "_0_2011_target_field.faissc"); - private static final Set SEGMENT_FILES_DEFAULT = SEGMENT_FILES_FAISS; - private static final Set SEGMENT_MULTI_FIELD_FILES_FAISS = Set.of( - "_0.cfe", - 
"_0_2011_target_field.faissc", - "_0_2011_long_target_field.faissc" - ); - private static final String CIRCUIT_BREAKER_LIMIT_100KB = "100Kb"; - private static final Integer EF_SEARCH = 10; - private static final Map HNSW_METHOD_PARAMETERS = Map.of(METHOD_PARAMETER_EF_SEARCH, EF_SEARCH); - - private static final Map DOC_ID_TO_SCORES = Map.of(10, 0.4f, 101, 0.05f, 100, 0.8f, 50, 0.52f); - private static final Map FILTERED_DOC_ID_TO_SCORES = Map.of(101, 0.05f, 100, 0.8f, 50, 0.52f); - private static final Map EXACT_SEARCH_DOC_ID_TO_SCORES = Map.of(0, 0.12048191f); - private static final Map BINARY_EXACT_SEARCH_DOC_ID_TO_SCORES = Map.of(0, 0.5f); - - private static final Query FILTER_QUERY = new TermQuery(new Term("foo", "fooValue")); - - private static MockedStatic nativeMemoryCacheManagerMockedStatic; - private static MockedStatic jniServiceMockedStatic; - - private static MockedStatic knnSettingsMockedStatic; - - @BeforeClass - public static void setUpClass() throws Exception { - final KNNSettings knnSettings = mock(KNNSettings.class); - knnSettingsMockedStatic = mockStatic(KNNSettings.class); - when(knnSettings.getSettingValue(eq(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED))).thenReturn(true); - when(knnSettings.getSettingValue(eq(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_CLUSTER_LIMIT))).thenReturn(CIRCUIT_BREAKER_LIMIT_100KB); - when(knnSettings.getSettingValue(eq(KNNSettings.KNN_CACHE_ITEM_EXPIRY_ENABLED))).thenReturn(false); - when(knnSettings.getSettingValue(eq(KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES))).thenReturn(TimeValue.timeValueMinutes(10)); - - final ByteSizeValue v = ByteSizeValue.parseBytesSizeValue( - CIRCUIT_BREAKER_LIMIT_100KB, - KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_CLUSTER_LIMIT - ); - knnSettingsMockedStatic.when(KNNSettings::getClusterCbLimit).thenReturn(v); - knnSettingsMockedStatic.when(KNNSettings::state).thenReturn(knnSettings); - ByteSizeValue cacheSize = ByteSizeValue.parseBytesSizeValue("1024kb", 
QUANTIZATION_STATE_CACHE_SIZE_LIMIT); // Setting 1MB as an - // example - when(knnSettings.getSettingValue(eq(QUANTIZATION_STATE_CACHE_SIZE_LIMIT))).thenReturn(cacheSize); - // Mock QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES setting - TimeValue mockTimeValue = TimeValue.timeValueMinutes(10); - when(knnSettings.getSettingValue(eq(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES))).thenReturn(mockTimeValue); - - nativeMemoryCacheManagerMockedStatic = mockStatic(NativeMemoryCacheManager.class); - - final NativeMemoryCacheManager nativeMemoryCacheManager = mock(NativeMemoryCacheManager.class); - final NativeMemoryAllocation nativeMemoryAllocation = mock(NativeMemoryAllocation.class); - when(nativeMemoryCacheManager.get(any(), anyBoolean())).thenReturn(nativeMemoryAllocation); - - nativeMemoryCacheManagerMockedStatic.when(NativeMemoryCacheManager::getInstance).thenReturn(nativeMemoryCacheManager); - - final MockedStatic pathUtilsMockedStatic = mockStatic(PathUtils.class); - final Path indexPath = mock(Path.class); - when(indexPath.toString()).thenReturn("/mydrive/myfolder"); - pathUtilsMockedStatic.when(() -> PathUtils.get(anyString(), anyString())).thenReturn(indexPath); - } - - @Before - public void setupBeforeTest() { - knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(0); - jniServiceMockedStatic = mockStatic(JNIService.class); - } - - @After - public void tearDownAfterTest() { - jniServiceMockedStatic.close(); - } +public class KNNWeightTests extends KNNWeightTestCase { @SneakyThrows public void testQueryResultScoreNmslib() { for (SpaceType space : List.of(SpaceType.L2, SpaceType.L1, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT, SpaceType.LINF)) { @@ -847,7 +755,7 @@ public void testANNWithFilterQuery_whenFiltersMatchAllDocs_thenSuccess() { assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); } - private SegmentReader mockSegmentReader() { + protected SegmentReader mockSegmentReader() { 
Path path = mock(Path.class); FSDirectory directory = mock(FSDirectory.class); @@ -1619,28 +1527,6 @@ private void testQueryScore( assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); } - private Map getTranslatedScores(Function scoreTranslator) { - return DOC_ID_TO_SCORES.entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreTranslator.apply(entry.getValue()))); - } - - private KNNQueryResult[] getKNNQueryResults() { - return DOC_ID_TO_SCORES.entrySet() - .stream() - .map(entry -> new KNNQueryResult(entry.getKey(), entry.getValue())) - .collect(Collectors.toList()) - .toArray(new KNNQueryResult[0]); - } - - private KNNQueryResult[] getFilteredKNNQueryResults() { - return FILTERED_DOC_ID_TO_SCORES.entrySet() - .stream() - .map(entry -> new KNNQueryResult(entry.getKey(), entry.getValue())) - .collect(Collectors.toList()) - .toArray(new KNNQueryResult[0]); - } - @SneakyThrows public void testANNWithQuantizationParams_whenStateNotFound_thenFail() { try (MockedStatic quantizationServiceMockedStatic = Mockito.mockStatic(QuantizationService.class)) { diff --git a/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java b/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java index 607699b56e..1aeb63fb04 100644 --- a/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java @@ -14,7 +14,6 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; @@ -23,10 +22,13 @@ import org.apache.lucene.tests.analysis.MockAnalyzer; import org.mockito.Mock; +import org.opensearch.knn.index.query.KNNWeight; import 
org.opensearch.test.OpenSearchTestCase; import java.io.IOException; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.openMocks; @@ -56,7 +58,7 @@ public void testScorer() throws Exception { int[] expectedDocs = { 0, 1, 2, 3, 4 }; float[] expectedScores = { 0.1f, 1.2f, 2.3f, 5.1f, 3.4f }; int[] findSegments = { 0, 2, 5 }; - objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, readerContext.id()); + objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, readerContext.id(), null); // When Scorer scorer1 = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1).scorer(leaf1); @@ -88,19 +90,18 @@ public void testWeight() { int[] expectedDocs = { 0, 1, 2, 3, 4 }; float[] expectedScores = { 0.1f, 1.2f, 2.3f, 5.1f, 3.4f }; int[] findSegments = { 0, 2, 5 }; - Explanation expectedExplanation = Explanation.match(1.2f, "within top 4"); // When - objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, readerContext.id()); + KNNWeight knnWeight = mock(KNNWeight.class); + objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, readerContext.id(), knnWeight); Weight weight = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1); - Explanation explanation = weight.explain(leaf1, 1); + weight.explain(leaf1, 1); // Then assertEquals(objectUnderTest, weight.getQuery()); assertTrue(weight.isCacheable(leaf1)); assertEquals(2, weight.count(leaf1)); - assertEquals(expectedExplanation, explanation); - assertEquals(Explanation.noMatch("not in top 4"), weight.explain(leaf1, 9)); + verify(knnWeight).explain(leaf1, 1, 1.2f); } private IndexReader createTestIndexReader() throws IOException { diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java 
b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 82dc43bf35..274ca33820 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -152,7 +152,6 @@ public int length() { // Set the reader and searcher reader = directoryReader; - ; indexReaderContext = reader.getContext(); // Extract LeafReaderContext List leaves = reader.leaves(); @@ -186,6 +185,44 @@ public int length() { assertEquals(expected, actual.getQuery()); } + @SneakyThrows + public void testExplain() { + + List leaves = reader.leaves(); + assertEquals(1, leaves.size()); + leaf1 = leaves.get(0); + leafReader1 = leaf1.reader(); + + PerLeafResult leafResult = new PerLeafResult(null, new HashMap<>(Map.of(4, 3.4f, 3, 5.1f))); + + when(knnWeight.searchLeaf(leaf1, 4)).thenReturn(leafResult); + + Bits liveDocs = mock(Bits.class); + when(liveDocs.get(anyInt())).thenReturn(true); + when(liveDocs.get(2)).thenReturn(false); + when(liveDocs.get(1)).thenReturn(false); + + // k=4 to make sure we get topk results even if docs are deleted/less in one of the leaves + when(knnQuery.getK()).thenReturn(4); + + TopDocs[] topDocs = { ResultUtil.resultMapToTopDocs(leafResult.getResult(), leaf1.docBase) }; + TopDocs expectedTopDocs = TopDocs.merge(4, topDocs); + + // When + Weight actual = objectUnderTest.createWeight(searcher, scoreMode, 1); + + // Then + Query expected = QueryUtils.INSTANCE.createDocAndScoreQuery(reader, expectedTopDocs); + assertEquals(expected, actual.getQuery()); + for (ScoreDoc scoreDoc : expectedTopDocs.scoreDocs) { + int docId = scoreDoc.doc; + if (docId == 0) continue; + float score = scoreDoc.score; + actual.explain(leaf1, docId); + verify(knnWeight).explain(leaf1, docId, score); + } + } + @SneakyThrows public void testRescoreWhenShardLevelRescoringEnabled() { @@ -220,7 +257,6 @@ public void 
testRescoreWhenShardLevelRescoringEnabled() { // Set the reader and searcher reader = directoryReader; - ; indexReaderContext = reader.getContext(); // Extract LeafReaderContext List leaves = reader.leaves(); @@ -471,7 +507,7 @@ public void testExpandNestedDocs() { QueryUtils queryUtils = mock(QueryUtils.class); when(queryUtils.getAllSiblings(any(), any(), any(), any())).thenReturn(allSiblings); - when(queryUtils.createDocAndScoreQuery(eq(reader), any())).thenReturn(finalQuery); + when(queryUtils.createDocAndScoreQuery(eq(reader), any(), eq(knnWeight))).thenReturn(finalQuery); // Run NativeEngineKnnVectorQuery query = new NativeEngineKnnVectorQuery(knnQuery, queryUtils, true); @@ -482,7 +518,7 @@ public void testExpandNestedDocs() { verify(queryUtils).getAllSiblings(leaf1, perLeafResults.get(0).keySet(), parentFilter, queryFilterBits); verify(queryUtils).getAllSiblings(leaf2, perLeafResults.get(1).keySet(), parentFilter, queryFilterBits); ArgumentCaptor topDocsCaptor = ArgumentCaptor.forClass(TopDocs.class); - verify(queryUtils).createDocAndScoreQuery(eq(reader), topDocsCaptor.capture()); + verify(queryUtils).createDocAndScoreQuery(eq(reader), topDocsCaptor.capture(), eq(knnWeight)); TopDocs capturedTopDocs = topDocsCaptor.getValue(); assertEquals(topK.totalHits, capturedTopDocs.totalHits); for (int i = 0; i < topK.scoreDocs.length; i++) { diff --git a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java index d6526125bb..01dc992c0d 100644 --- a/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java +++ b/src/test/java/org/opensearch/knn/integ/DerivedSourceIT.java @@ -6,6 +6,10 @@ package org.opensearch.knn.integ; import lombok.SneakyThrows; +import org.junit.Before; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.DerivedSourceTestCase; import 
org.opensearch.knn.DerivedSourceUtils; import org.opensearch.knn.Pair; @@ -13,6 +17,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Random; import static org.opensearch.knn.DerivedSourceUtils.DERIVED_ENABLED_WITH_SEGREP_SETTINGS; @@ -25,9 +30,21 @@ */ public class DerivedSourceIT extends DerivedSourceTestCase { + private final String snapshot = "snapshot-test"; + private final String repository = "repo"; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + final String pathRepo = System.getProperty("tests.path.repo"); + Settings repoSettings = Settings.builder().put("compress", randomBoolean()).put("location", pathRepo).build(); + registerRepository(repository, "fs", true, repoSettings); + } + @SneakyThrows public void testFlatFields() { - List indexConfigContexts = getFlatIndexContexts("derivedit", true); + List indexConfigContexts = getFlatIndexContexts("derivedit", true, true); testDerivedSourceE2E(indexConfigContexts); } @@ -52,11 +69,13 @@ public void testDerivedSource_whenSegrepLocal_thenDisabled() { new Pair<>("original-disable-", false) ); List indexConfigContexts = new ArrayList<>(); + long consistentRandomSeed = random().nextLong(); for (Pair index : indexPrefixToEnabled) { + Random random = new Random(consistentRandomSeed); DerivedSourceUtils.IndexConfigContext indexConfigContext = DerivedSourceUtils.IndexConfigContext.builder() .indexName(getIndexName("deriveit", index.getFirst(), false)) .derivedEnabled(index.getSecond()) - .random(new Random(1)) + .random(random) .settings(index.getSecond() ? 
DERIVED_ENABLED_WITH_SEGREP_SETTINGS : null) .fields( List.of( @@ -67,7 +86,7 @@ public void testDerivedSource_whenSegrepLocal_thenDisabled() { DerivedSourceUtils.KNNVectorFieldTypeContext.builder() .fieldPath("nested_1.test_vector") .dimension(TEST_DIMENSION) - .valueSupplier(randomVectorSupplier(new Random(0), TEST_DIMENSION, VectorDataType.BYTE)) + .valueSupplier(randomVectorSupplier(random, TEST_DIMENSION, VectorDataType.BYTE)) .build() ) ) @@ -80,7 +99,7 @@ public void testDerivedSource_whenSegrepLocal_thenDisabled() { DerivedSourceUtils.KNNVectorFieldTypeContext.builder() .fieldPath("nested_2.test_vector") .dimension(TEST_DIMENSION) - .valueSupplier(randomVectorSupplier(new Random(0), TEST_DIMENSION, VectorDataType.BYTE)) + .valueSupplier(randomVectorSupplier(random, TEST_DIMENSION, VectorDataType.BYTE)) .build(), DerivedSourceUtils.NestedFieldContext.builder() .fieldPath("nested_2.nested_3") @@ -89,7 +108,7 @@ public void testDerivedSource_whenSegrepLocal_thenDisabled() { DerivedSourceUtils.KNNVectorFieldTypeContext.builder() .fieldPath("nested_2.nested_3.test_vector") .dimension(TEST_DIMENSION) - .valueSupplier(randomVectorSupplier(new Random(0), TEST_DIMENSION, VectorDataType.BYTE)) + .valueSupplier(randomVectorSupplier(random, TEST_DIMENSION, VectorDataType.BYTE)) .build(), DerivedSourceUtils.IntFieldType.builder().fieldPath("nested_2.nested_3.test-int").build() ) @@ -100,7 +119,7 @@ public void testDerivedSource_whenSegrepLocal_thenDisabled() { .build(), DerivedSourceUtils.KNNVectorFieldTypeContext.builder() .dimension(TEST_DIMENSION) - .valueSupplier(randomVectorSupplier(new Random(0), TEST_DIMENSION, VectorDataType.BYTE)) + .valueSupplier(randomVectorSupplier(random, TEST_DIMENSION, VectorDataType.BYTE)) .fieldPath("test_vector") .build(), DerivedSourceUtils.TextFieldType.builder().fieldPath("test-text").build(), @@ -142,6 +161,30 @@ private void testDerivedSourceE2E(List in // Reindex testReindex(indexConfigContexts); + + // Snapshot restore + 
testSnapshotRestore(repository, snapshot + getTestName().toLowerCase(Locale.ROOT), indexConfigContexts); } + @SneakyThrows + public void testDefaultSetting() { + String indexName = getIndexName("defaults", "test", false); + String fieldName = "test"; + String indexNameDisabled = "disabled"; + int dimension = 16; + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .endObject() + .endObject() + .endObject(); + String mapping = builder.toString(); + createKnnIndex(indexName, mapping); + validateDerivedSetting(indexName, true); + createIndex(indexNameDisabled, Settings.builder().build()); + validateDerivedSetting(indexNameDisabled, false); + } } diff --git a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java index 5a630b5d68..1eceba35fa 100644 --- a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java @@ -69,33 +69,54 @@ public void testKNNL2ScriptScore() throws Exception { testKNNScriptScore(SpaceType.L2); } + public void testKNNL2ByteScriptScore() throws Exception { + testKNNByteScriptScore(SpaceType.L2); + } + public void testKNNL1ScriptScore() throws Exception { testKNNScriptScore(SpaceType.L1); } + public void testKNNL1ByteScriptScore() throws Exception { + testKNNByteScriptScore(SpaceType.L1); + } + public void testKNNLInfScriptScore() throws Exception { testKNNScriptScore(SpaceType.LINF); } + public void testKNNLInfByteScriptScore() throws Exception { + testKNNByteScriptScore(SpaceType.LINF); + } + public void testKNNCosineScriptScore() throws Exception { testKNNScriptScore(SpaceType.COSINESIMIL); } + public void testKNNByteCosineScriptScore() throws Exception { + testKNNByteScriptScore(SpaceType.COSINESIMIL); + } + @SneakyThrows public void 
testKNNHammingScriptScore() { testKNNScriptScoreOnBinaryIndex(SpaceType.HAMMING); } + @SuppressWarnings("unchecked") @SneakyThrows public void testKNNHammingScriptScore_whenNonBinary_thenException() { final int dims = randomIntBetween(2, 10) * 8; final float[] queryVector = randomVector(dims, VectorDataType.BYTE); - final BiFunction scoreFunction = getScoreFunction(SpaceType.HAMMING, queryVector); + final BiFunction scoreFunction = (BiFunction) getScoreFunction( + SpaceType.HAMMING, + queryVector, + VectorDataType.BINARY + ); List nonBinary = List.of(VectorDataType.FLOAT, VectorDataType.BYTE); for (VectorDataType vectorDataType : nonBinary) { Exception e = expectThrows( Exception.class, - () -> createIndexAndAssertScriptScore( + () -> createIndexAndAssertByteScriptScore( createKnnIndexMapping(FIELD_NAME, dims, vectorDataType), SpaceType.HAMMING, scoreFunction, @@ -110,15 +131,20 @@ public void testKNNHammingScriptScore_whenNonBinary_thenException() { } } + @SuppressWarnings("unchecked") public void testKNNNonHammingScriptScore_whenBinary_thenException() { final int dims = randomIntBetween(2, 10) * 8; final float[] queryVector = randomVector(dims, VectorDataType.BINARY); - final BiFunction scoreFunction = getScoreFunction(SpaceType.HAMMING, queryVector); + final BiFunction scoreFunction = (BiFunction) getScoreFunction( + SpaceType.HAMMING, + queryVector, + VectorDataType.BINARY + ); Set spaceTypeToExclude = Set.of(SpaceType.UNDEFINED, SpaceType.HAMMING); Arrays.stream(SpaceType.values()).filter(s -> spaceTypeToExclude.contains(s) == false).forEach(s -> { Exception e = expectThrows( Exception.class, - () -> createIndexAndAssertScriptScore( + () -> createIndexAndAssertByteScriptScore( createKnnIndexMapping(FIELD_NAME, dims, VectorDataType.BINARY), s, scoreFunction, @@ -625,6 +651,7 @@ public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception { assertEquals(1, secondQueryCacheMap.get("hit_count")); } + @SuppressWarnings("unchecked") public void 
testKNNScriptScoreOnModelBasedIndex() throws Exception { int dimensions = randomIntBetween(2, 10); String trainMapping = createKnnIndexMapping(TRAIN_FIELD_PARAMETER, dimensions); @@ -661,7 +688,11 @@ public void testKNNScriptScoreOnModelBasedIndex() throws Exception { continue; } final float[] queryVector = randomVector(dimensions); - final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); + final BiFunction scoreFunction = (BiFunction) getScoreFunction( + spaceType, + queryVector, + VectorDataType.FLOAT + ); createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector, true); } } @@ -688,6 +719,30 @@ private List createMappers(int dimensions) throws Exception { ); } + private List createByteMappers(int dimensions) throws Exception { + return List.of( + createKnnIndexMapping(FIELD_NAME, dimensions, VectorDataType.BYTE), + createKnnIndexMapping( + FIELD_NAME, + dimensions, + KNNConstants.METHOD_HNSW, + KNNEngine.LUCENE.getName(), + SpaceType.DEFAULT.getValue(), + true, + VectorDataType.BYTE + ), + createKnnIndexMapping( + FIELD_NAME, + dimensions, + KNNConstants.METHOD_HNSW, + KNNEngine.LUCENE.getName(), + SpaceType.DEFAULT.getValue(), + false, + VectorDataType.BYTE + ) + ); + } + private List createBinaryIndexMappers(int dimensions) throws Exception { return List.of( createKnnIndexMapping( @@ -745,7 +800,7 @@ private Map createDataset( return dataset; } - private BiFunction getScoreFunction(SpaceType spaceType, float[] queryVector) { + private BiFunction getScoreFunction(SpaceType spaceType, float[] queryVector, VectorDataType vectorDataType) { List target = new ArrayList<>(queryVector.length); for (float f : queryVector) { target.add(f); @@ -756,8 +811,8 @@ private BiFunction getScoreFunction(SpaceType spaceType new KNNVectorFieldType( FIELD_NAME, Collections.emptyMap(), - SpaceType.HAMMING == spaceType ? 
VectorDataType.BINARY : VectorDataType.FLOAT, - getMappingConfigForFlatMapping(SpaceType.HAMMING == spaceType ? queryVector.length * 8 : queryVector.length) + vectorDataType, + getMappingConfigForFlatMapping(vectorDataType == VectorDataType.BINARY ? queryVector.length * 8 : queryVector.length) ) ); switch (spaceType) { @@ -767,35 +822,63 @@ private BiFunction getScoreFunction(SpaceType spaceType case COSINESIMIL: case INNER_PRODUCT: case HAMMING: - return ((KNNScoringSpace.KNNFieldSpace) knnScoringSpace).getScoringMethod(); + if (vectorDataType == VectorDataType.FLOAT) { + return ((KNNScoringSpace.KNNFieldSpace) knnScoringSpace).getScoringMethod(queryVector); + } + return ((KNNScoringSpace.KNNFieldSpace) knnScoringSpace).getScoringMethod(toByte(queryVector)); default: throw new IllegalArgumentException(); } } + @SuppressWarnings("unchecked") private void testKNNScriptScore(SpaceType spaceType) throws Exception { final int dims = randomIntBetween(2, 10); final float[] queryVector = randomVector(dims); - final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); + final BiFunction scoreFunction = (BiFunction) getScoreFunction( + spaceType, + queryVector, + VectorDataType.FLOAT + ); for (String mapper : createMappers(dims)) { createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, true); createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, false); } } + @SuppressWarnings("unchecked") + private void testKNNByteScriptScore(SpaceType spaceType) throws Exception { + final int dims = randomIntBetween(2, 10); + final float[] queryVector = randomVector(dims, VectorDataType.BYTE); + final BiFunction scoreFunction = (BiFunction) getScoreFunction( + spaceType, + queryVector, + VectorDataType.BYTE + ); + for (String mapper : createByteMappers(dims)) { + createIndexAndAssertByteScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, true, true, VectorDataType.BYTE); + 
createIndexAndAssertByteScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, false, true, VectorDataType.BYTE); + } + } + + @SuppressWarnings("unchecked") private void testKNNScriptScoreOnBinaryIndex(SpaceType spaceType) throws Exception { final int dims = randomIntBetween(2, 10) * 8; final float[] queryVector = randomVector(dims, VectorDataType.BINARY); - final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); + final BiFunction scoreFunction = (BiFunction) getScoreFunction( + spaceType, + queryVector, + VectorDataType.BINARY + ); // Test when knn is enabled and engine is Faiss for (String mapper : createBinaryIndexMappers(dims)) { - createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, true, true, VectorDataType.BINARY); - createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, false, true, VectorDataType.BINARY); + createIndexAndAssertByteScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, true, true, VectorDataType.BINARY); + createIndexAndAssertByteScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, false, true, VectorDataType.BINARY); } // Test when knn is disabled and engine is default(Nmslib) - createIndexAndAssertScriptScore( + createIndexAndAssertByteScriptScore( createKnnIndexMapping(FIELD_NAME, dims, VectorDataType.BINARY), spaceType, scoreFunction, @@ -805,7 +888,7 @@ private void testKNNScriptScoreOnBinaryIndex(SpaceType spaceType) throws Excepti false, VectorDataType.BINARY ); - createIndexAndAssertScriptScore( + createIndexAndAssertByteScriptScore( createKnnIndexMapping(FIELD_NAME, dims, VectorDataType.BINARY), spaceType, scoreFunction, @@ -825,13 +908,22 @@ private void createIndexAndAssertScriptScore( float[] queryVector, boolean dense ) throws Exception { - createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dimensions, queryVector, dense, true, VectorDataType.FLOAT); + createIndexAndAssertScriptScore( + mapper, + 
spaceType, + v -> scoreFunction.apply(queryVector, v), + dimensions, + queryVector, + dense, + true, + VectorDataType.FLOAT + ); } private void createIndexAndAssertScriptScore( String mapper, SpaceType spaceType, - BiFunction scoreFunction, + Function scoreFunction, int dimensions, float[] queryVector, boolean dense, @@ -849,13 +941,7 @@ private void createIndexAndAssertScriptScore( createKnnIndex(INDEX_NAME, settings, mapper); try { final int numDocsWithField = randomIntBetween(4, 10); - Map dataset = createDataset( - v -> scoreFunction.apply(queryVector, v), - dimensions, - numDocsWithField, - dense, - vectorDataType - ); + Map dataset = createDataset(scoreFunction, dimensions, numDocsWithField, dense, vectorDataType); float[] dummyVector = new float[1]; dataset.forEach((k, v) -> { final float[] vector = (v != null) ? v.getVector() : dummyVector; @@ -887,6 +973,36 @@ private void createIndexAndAssertScriptScore( } } + private void createIndexAndAssertByteScriptScore( + String mapper, + SpaceType spaceType, + BiFunction scoreFunction, + int dimensions, + float[] queryVector, + boolean dense, + boolean enableKnn, + VectorDataType vectorDataType + ) throws Exception { + createIndexAndAssertScriptScore( + mapper, + spaceType, + v -> scoreFunction.apply(toByte(queryVector), toByte(v)), + dimensions, + queryVector, + dense, + enableKnn, + vectorDataType + ); + } + + private byte[] toByte(final float[] vector) { + byte[] bytes = new byte[vector.length]; + for (int i = 0; i < vector.length; i++) { + bytes[i] = (byte) vector[i]; + } + return bytes; + } + private float[] dummyFloatArrayBasedOnDimension(int dimesion) { return new float[dimesion]; } diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index a82bc15b36..1559a6ae08 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -68,6 +68,7 @@ import static 
org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.memoryoptsearch.FaissHNSWTests.loadHnswBinary; public class JNIServiceTests extends KNNTestCase { static final int FP16_MAX = 65504; @@ -1343,6 +1344,55 @@ public void testQueryIndex_faiss_parentIds() throws IOException { } } + public void testQueryIndex_faissCagra_parentIds() throws IOException { + doTestQueryIndex_faissCagra_parentIds(SpaceType.L2); + doTestQueryIndex_faissCagra_parentIds(SpaceType.INNER_PRODUCT); + doTestQueryIndex_faissCagra_parentIds(SpaceType.COSINESIMIL); + + } + + private void doTestQueryIndex_faissCagra_parentIds(SpaceType spaceType) throws IOException { + + int k = 100; + int efSearch = 100; + + int[] parentIds = toParentIdArray(testDataNested.indexData.docs); + Map idToParentIdMap = toIdToParentIdMap(testDataNested.indexData.docs); + + final long pointer; + // This faiss graph binary was created with the IndexHNSWCagra index (base_level_only==true) containing the + // test_vectors_nested_1000x128.json vectors + try (IndexInput indexInput = loadHnswBinary("data/remoteindexbuild/faiss_hnsw_cagra_nested_float_1000_vectors_128_dims.bin")) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + pointer = JNIService.loadIndex( + indexInputWithBuffer, + ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.FAISS + ); + assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + throw e; + } + + for (float[] query : testDataNested.queries) { + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + Map.of("ef_search", efSearch), + KNNEngine.FAISS, + null, + 0, + parentIds + ); + // Verify there is no more than one result from same parent + Set parentIdSet = toParentIdSet(results, idToParentIdMap); + 
assertEquals(results.length, parentIdSet.size()); + } + + } + public void testQueryIndex_faiss_streaming_parentIds() throws IOException { int k = 100; diff --git a/src/test/java/org/opensearch/knn/jni/PlatformUtilTests.java b/src/test/java/org/opensearch/knn/jni/PlatformUtilTests.java index c524d211dd..bc66b934a4 100644 --- a/src/test/java/org/opensearch/knn/jni/PlatformUtilTests.java +++ b/src/test/java/org/opensearch/knn/jni/PlatformUtilTests.java @@ -12,8 +12,11 @@ package org.opensearch.knn.jni; import com.sun.jna.Platform; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; import org.mockito.MockedStatic; -import org.opensearch.knn.KNNTestCase; import oshi.util.platform.mac.SysctlUtil; import java.nio.file.Files; @@ -25,10 +28,16 @@ import static org.opensearch.knn.jni.PlatformUtils.isAVX512SupportedBySystem; import static org.opensearch.knn.jni.PlatformUtils.isAVX512SPRSupportedBySystem; -public class PlatformUtilTests extends KNNTestCase { +public class PlatformUtilTests extends Assert { public static final String MAC_CPU_FEATURES = "machdep.cpu.leaf7_features"; public static final String LINUX_PROC_CPU_INFO = "/proc/cpuinfo"; + @Before + public void setUp() { + PlatformUtils.reset(); + } + + @Test public void testIsAVX2SupportedBySystem_platformIsNotIntel_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(false); @@ -36,6 +45,7 @@ public void testIsAVX2SupportedBySystem_platformIsNotIntel_returnsFalse() { } } + @Test public void testIsAVX2SupportedBySystem_platformIsIntelWithOSAsWindows_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -44,6 +54,7 @@ public void testIsAVX2SupportedBySystem_platformIsIntelWithOSAsWindows_returnsFa } } + @Test public void testIsAVX2SupportedBySystem_platformIsMac_returnsTrue() { try (MockedStatic mockedPlatform = 
mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -59,6 +70,7 @@ public void testIsAVX2SupportedBySystem_platformIsMac_returnsTrue() { } } + @Test public void testIsAVX2SupportedBySystem_platformIsMac_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -72,6 +84,7 @@ public void testIsAVX2SupportedBySystem_platformIsMac_returnsFalse() { } + @Test public void testIsAVX2SupportedBySystem_platformIsMac_throwsExceptionReturnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -98,6 +111,7 @@ public void testIsAVX2SupportedBySystem_platformIsLinux_returnsTrue() { } } + @Test public void testIsAVX2SupportedBySystem_platformIsLinux_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -112,6 +126,7 @@ public void testIsAVX2SupportedBySystem_platformIsLinux_returnsFalse() { } + @Test public void testIsAVX2SupportedBySystem_platformIsLinux_throwsExceptionReturnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -127,7 +142,7 @@ public void testIsAVX2SupportedBySystem_platformIsLinux_throwsExceptionReturnsFa } // AVX512 tests - + @Test public void testIsAVX512SupportedBySystem_platformIsNotIntel_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(false); @@ -135,6 +150,7 @@ public void testIsAVX512SupportedBySystem_platformIsNotIntel_returnsFalse() { } } + @Test public void testIsAVX512SupportedBySystem_platformIsMac_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isMac).thenReturn(false); @@ -142,6 +158,7 @@ public void 
testIsAVX512SupportedBySystem_platformIsMac_returnsFalse() { } } + @Test public void testIsAVX512SupportedBySystem_platformIsIntelMac_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -150,6 +167,7 @@ public void testIsAVX512SupportedBySystem_platformIsIntelMac_returnsFalse() { } } + @Test public void testIsAVX512SupportedBySystem_platformIsIntelWithOSAsWindows_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -158,6 +176,7 @@ public void testIsAVX512SupportedBySystem_platformIsIntelWithOSAsWindows_returns } } + @Test public void testIsAVX512SupportedBySystem_platformIsLinuxAllAVX512FlagsPresent_returnsTrue() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -171,6 +190,7 @@ public void testIsAVX512SupportedBySystem_platformIsLinuxAllAVX512FlagsPresent_r } } + @Test public void testIsAVX512SupportedBySystem_platformIsLinuxSomeAVX512FlagsPresent_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -185,7 +205,7 @@ public void testIsAVX512SupportedBySystem_platformIsLinuxSomeAVX512FlagsPresent_ } // Tests AVX512 instructions available since Intel(R) Sapphire Rapids. 
- + @Test public void testIsAVX512SPRSupportedBySystem_platformIsNotIntel_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(false); @@ -193,6 +213,7 @@ public void testIsAVX512SPRSupportedBySystem_platformIsNotIntel_returnsFalse() { } } + @Test public void testIsAVX512SPRSupportedBySystem_platformIsMac_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isMac).thenReturn(false); @@ -200,6 +221,7 @@ public void testIsAVX512SPRSupportedBySystem_platformIsMac_returnsFalse() { } } + @Test public void testIsAVX512SPRSupportedBySystem_platformIsIntelMac_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -208,6 +230,7 @@ public void testIsAVX512SPRSupportedBySystem_platformIsIntelMac_returnsFalse() { } } + @Test public void testIsAVX512SPRSupportedBySystem_platformIsIntelWithOSAsWindows_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -216,6 +239,7 @@ public void testIsAVX512SPRSupportedBySystem_platformIsIntelWithOSAsWindows_retu } } + @Test public void testIsAVX512SPRSupportedBySystem_platformIsLinuxAllAVX512SPRFlagsPresent_returnsTrue() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); @@ -229,6 +253,7 @@ public void testIsAVX512SPRSupportedBySystem_platformIsLinuxAllAVX512SPRFlagsPre } } + @Test public void testIsAVX512SPRSupportedBySystem_platformIsLinuxSomeAVX512SPRFlagsPresent_returnsFalse() { try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { mockedPlatform.when(Platform::isIntel).thenReturn(true); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java 
index 0c6a816e5a..01417d6375 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -17,22 +17,27 @@ import org.opensearch.cluster.health.ClusterHealthStatus; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.common.featureflags.KNNFeatureFlags; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.plugin.stats.KNNStats; import org.opensearch.knn.plugin.stats.StatNames; -import org.opensearch.core.rest.RestStatus; import java.io.IOException; -import java.util.*; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.opensearch.knn.TestUtils.KNN_VECTOR; import static org.opensearch.knn.TestUtils.PROPERTIES; @@ -48,8 +53,9 @@ import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; -import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static 
org.opensearch.knn.plugin.stats.StatNames.REMOTE_VECTOR_INDEX_BUILD_STATS; /** * Integration tests to check the correctness of RestKNNStatsHandler @@ -86,6 +92,8 @@ public void setup() { * @throws IOException throws IOException */ public void testCorrectStatsReturned() throws Exception { + // Enable flag to get all stats in KNNStats returned + updateClusterSettings(KNNFeatureFlags.KNN_REMOTE_VECTOR_BUILD_SETTING.getKey(), true); Response response = getKnnStats(Collections.emptyList(), Collections.emptyList()); String responseBody = EntityUtils.toString(response.getEntity()); Map clusterStats = parseClusterStatsResponse(responseBody); @@ -94,6 +102,16 @@ public void testCorrectStatsReturned() throws Exception { assertEquals(knnStats.getNodeStats().keySet(), nodeStats.get(0).keySet()); } + /** + * Test checks that handler correctly omits stats based on feature flag + */ + public void testFeatureFlagOmittingStats() throws Exception { + updateClusterSettings(KNNFeatureFlags.KNN_REMOTE_VECTOR_BUILD_SETTING.getKey(), false); + Response response = getKnnStats(Collections.emptyList(), Collections.emptyList()); + String responseBody = EntityUtils.toString(response.getEntity()); + assertFalse(responseBody.contains(REMOTE_VECTOR_INDEX_BUILD_STATS.getName())); + } + /** * Test checks that handler correctly returns value for select metrics * diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java index 99e847eeab..c571053e0a 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java @@ -55,6 +55,7 @@ private void expectThrowsExceptionWithKNNFieldWithBinaryDataType(Class clazz) th } @SneakyThrows + @SuppressWarnings("unchecked") public void testL2_whenValid_thenSucceed() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new 
ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); @@ -66,7 +67,12 @@ public void testL2_whenValid_thenSucceed() { getMappingConfigForMethodMapping(knnMethodContext, 3) ); KNNScoringSpace.L2 l2 = new KNNScoringSpace.L2(arrayListQueryObject, fieldType); - assertEquals(1F, l2.getScoringMethod().apply(arrayFloat, arrayFloat), 0.1F); + float[] processedFloatQuery = (float[]) l2.getProcessedQuery(arrayListQueryObject, fieldType); + assertEquals( + 1F, + ((BiFunction) l2.getScoringMethod(processedFloatQuery)).apply(arrayFloat, arrayFloat), + 0.1F + ); } @SneakyThrows @@ -75,6 +81,7 @@ public void testL2_whenInvalidType_thenException() { expectThrowsExceptionWithKNNFieldWithBinaryDataType(KNNScoringSpace.L2.class); } + @SuppressWarnings("unchecked") public void testCosineSimilarity_whenValid_thenSucceed() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new ArrayList<>(Arrays.asList(2.0, 4.0, 6.0)); @@ -87,9 +94,10 @@ public void testCosineSimilarity_whenValid_thenSucceed() { getMappingConfigForMethodMapping(knnMethodContext, 3) ); KNNScoringSpace.CosineSimilarity cosineSimilarity = new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType); + float[] processedFloatQuery = (float[]) cosineSimilarity.getProcessedQuery(arrayListQueryObject, fieldType); assertEquals( VectorSimilarityFunction.COSINE.compare(arrayFloat2, arrayFloat), - cosineSimilarity.getScoringMethod().apply(arrayFloat2, arrayFloat), + ((BiFunction) cosineSimilarity.getScoringMethod(processedFloatQuery)).apply(arrayFloat2, arrayFloat), 0.1F ); @@ -131,6 +139,7 @@ public void testCosineSimilarity_whenInvalidType_thenException() { expectThrowsExceptionWithKNNFieldWithBinaryDataType(KNNScoringSpace.CosineSimilarity.class); } + @SuppressWarnings("unchecked") public void testInnerProd_whenValid_thenSucceed() { float[] arrayFloat_case1 = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject_case1 = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); @@ -145,23 +154,45 @@ 
public void testInnerProd_whenValid_thenSucceed() { ); KNNScoringSpace.InnerProd innerProd = new KNNScoringSpace.InnerProd(arrayListQueryObject_case1, fieldType); - assertEquals(7.0F, innerProd.getScoringMethod().apply(arrayFloat_case1, arrayFloat2_case1), 0.001F); + float[] processedFloatQuery_case1 = (float[]) innerProd.getProcessedQuery(arrayListQueryObject_case1, fieldType); + assertEquals( + 7.0F, + ((BiFunction) innerProd.getScoringMethod(processedFloatQuery_case1)).apply( + arrayFloat_case1, + arrayFloat2_case1 + ), + 0.001F + ); float[] arrayFloat_case2 = new float[] { 100_000.0f, 200_000.0f, 300_000.0f }; List arrayListQueryObject_case2 = new ArrayList<>(Arrays.asList(100_000.0, 200_000.0, 300_000.0)); float[] arrayFloat2_case2 = new float[] { -100_000.0f, -200_000.0f, -300_000.0f }; innerProd = new KNNScoringSpace.InnerProd(arrayListQueryObject_case2, fieldType); - - assertEquals(7.142857143E-12F, innerProd.getScoringMethod().apply(arrayFloat_case2, arrayFloat2_case2), 1.0E-11F); + float[] processedFloatQuery_case2 = (float[]) innerProd.getProcessedQuery(arrayListQueryObject_case2, fieldType); + assertEquals( + 7.142857143E-12F, + ((BiFunction) innerProd.getScoringMethod(processedFloatQuery_case2)).apply( + arrayFloat_case2, + arrayFloat2_case2 + ), + 1.0E-11F + ); float[] arrayFloat_case3 = new float[] { 100_000.0f, 200_000.0f, 300_000.0f }; List arrayListQueryObject_case3 = new ArrayList<>(Arrays.asList(100_000.0, 200_000.0, 300_000.0)); float[] arrayFloat2_case3 = new float[] { 100_000.0f, 200_000.0f, 300_000.0f }; innerProd = new KNNScoringSpace.InnerProd(arrayListQueryObject_case3, fieldType); - - assertEquals(140_000_000_001F, innerProd.getScoringMethod().apply(arrayFloat_case3, arrayFloat2_case3), 0.01F); + float[] processedFloatQuery_case3 = (float[]) innerProd.getProcessedQuery(arrayListQueryObject_case3, fieldType); + assertEquals( + 140_000_000_001F, + ((BiFunction) innerProd.getScoringMethod(processedFloatQuery_case3)).apply( + 
arrayFloat_case3, + arrayFloat2_case3 + ), + 0.01F + ); } @SneakyThrows @@ -205,6 +236,7 @@ public void testHammingBit_Base64() { ); } + @SuppressWarnings("unchecked") public void testHamming_whenKNNFieldType_thenSucceed() { List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); @@ -216,9 +248,13 @@ public void testHamming_whenKNNFieldType_thenSucceed() { ); KNNScoringSpace.Hamming hamming = new KNNScoringSpace.Hamming(arrayListQueryObject, fieldType); - - float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; - assertEquals(1F, hamming.getScoringMethod().apply(arrayFloat, arrayFloat), 0.1F); + byte[] processedByteQuery = (byte[]) hamming.getProcessedQuery(arrayListQueryObject, fieldType); + byte[] arrayByte = new byte[] { 1, 2, 3 }; + assertEquals( + 1F, + ((BiFunction) hamming.getScoringMethod(processedByteQuery)).apply(arrayByte, arrayByte), + 0.1F + ); } public void testHamming_whenNonBinaryVectorDataType_thenException() { diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java index 2374e4f7bb..575e40145d 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java @@ -76,6 +76,12 @@ public void testParseKNNVectorQuery() { expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3, VectorDataType.FLOAT)); } + public void testConvertVectorToByteArray() { + byte[] arrayByte = new byte[] { 1, 2, 3 }; + List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); + assertArrayEquals(arrayByte, KNNScoringSpaceUtil.parseToByteArray(arrayListQueryObject, 3, VectorDataType.BINARY)); + } + public void testIsBinaryVectorDataType_whenBinary_thenReturnTrue() { KNNVectorFieldType fieldType = 
mock(KNNVectorFieldType.class); when(fieldType.getVectorDataType()).thenReturn(VectorDataType.BINARY); diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 67ffd3e857..77397e208d 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -314,11 +314,10 @@ public void testHamming_whenKNNVectorScriptDocValuesOfBinary_thenSuccess() { byte[] b1 = { 1, 16, -128 }; // 0000 0001, 0001 0000, 1000 0000 byte[] b2 = { 2, 17, -1 }; // 0000 0010, 0001 0001, 1111 1111 float[] f1 = { 1, 16, -128 }; // 0000 0001, 0001 0000, 1000 0000 - float[] f2 = { 2, 17, -1 }; // 0000 0010, 0001 0001, 1111 1111 List queryVector = Arrays.asList(f1[0], f1[1], f1[2]); - KNNVectorScriptDocValues docValues = mock(KNNVectorScriptDocValues.class); + KNNVectorScriptDocValues docValues = mock(KNNVectorScriptDocValues.class); when(docValues.getVectorDataType()).thenReturn(VectorDataType.BINARY); - when(docValues.getValue()).thenReturn(f2); + when(docValues.getValue()).thenReturn(b2); assertEquals(KNNScoringUtil.calculateHammingBit(b1, b2), KNNScoringUtil.hamming(queryVector, docValues), 0.01f); } diff --git a/src/test/java/org/opensearch/knn/profiler/SegmentProfilerStateTests.java b/src/test/java/org/opensearch/knn/profiler/SegmentProfilerStateTests.java new file mode 100644 index 0000000000..1367d8d41a --- /dev/null +++ b/src/test/java/org/opensearch/knn/profiler/SegmentProfilerStateTests.java @@ -0,0 +1,165 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.profiler; + +import org.apache.commons.math3.stat.descriptive.SummaryStatistics; +import org.apache.lucene.search.DocIdSetIterator; +import org.junit.Before; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import 
org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class SegmentProfilerStateTests extends OpenSearchTestCase { + + private KNNVectorValues mockVectorValues; + private Supplier> mockSupplier; + + @Before + public void setUp() throws Exception { + super.setUp(); + mockVectorValues = (KNNVectorValues) mock(KNNVectorValues.class); + mockSupplier = () -> mockVectorValues; + } + + public void testConstructor() { + List statistics = new ArrayList<>(); + statistics.add(new SummaryStatistics()); + + SegmentProfilerState state = new SegmentProfilerState(statistics, 0); + assertEquals(statistics, state.getStatistics()); + } + + public void testProfileVectorsWithNullVectorValues() throws IOException { + Supplier> nullSupplier = () -> null; + SegmentProfilerState state = SegmentProfilerState.profileVectors(nullSupplier); + + assertTrue(state.getStatistics().isEmpty()); + } + + public void testProfileVectorsWithNoDocuments() throws IOException { + when(mockVectorValues.docId()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + + SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier); + assertTrue(state.getStatistics().isEmpty()); + } + + public void testProfileVectorsWithSingleFloatVector() throws IOException { + float[] vector = new float[] { 1.0f, 2.0f, 3.0f }; + + when(mockVectorValues.docId()).thenReturn(0); + when(mockVectorValues.dimension()).thenReturn(3); + when(mockVectorValues.getVector()).thenReturn(vector); + when(mockVectorValues.nextDoc()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + + SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier); + + assertEquals(3, state.getStatistics().size()); + assertEquals(1.0, state.getStatistics().get(0).getMean(), 0.001); + assertEquals(2.0, state.getStatistics().get(1).getMean(), 
0.001); + assertEquals(3.0, state.getStatistics().get(2).getMean(), 0.001); + } + + public void testProfileVectorsWithSingleByteVector() throws IOException { + byte[] vector = new byte[] { 1, 2, 3 }; + + when(mockVectorValues.docId()).thenReturn(0); + when(mockVectorValues.dimension()).thenReturn(3); + when(mockVectorValues.getVector()).thenReturn(vector); + when(mockVectorValues.nextDoc()).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + + SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier); + + assertEquals(3, state.getStatistics().size()); + assertEquals(1.0, state.getStatistics().get(0).getMean(), 0.001); + assertEquals(2.0, state.getStatistics().get(1).getMean(), 0.001); + assertEquals(3.0, state.getStatistics().get(2).getMean(), 0.001); + } + + public void testProfileVectorsWithMultipleFloatVectors() throws IOException { + float[] vector1 = new float[] { 1.0f, 2.0f }; + float[] vector2 = new float[] { 3.0f, 4.0f }; + + when(mockVectorValues.docId()).thenReturn(0); + when(mockVectorValues.dimension()).thenReturn(2); + when(mockVectorValues.getVector()).thenReturn(vector1).thenReturn(vector2); + when(mockVectorValues.nextDoc()).thenReturn(1).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + + SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier); + + assertEquals(2, state.getStatistics().size()); + assertEquals(2.0, state.getStatistics().get(0).getMean(), 0.001); + assertEquals(3.0, state.getStatistics().get(1).getMean(), 0.001); + } + + public void testProfileVectorsWithMultipleByteVectors() throws IOException { + byte[] vector1 = new byte[] { 1, 2 }; + byte[] vector2 = new byte[] { 3, 4 }; + + when(mockVectorValues.docId()).thenReturn(0); + when(mockVectorValues.dimension()).thenReturn(2); + when(mockVectorValues.getVector()).thenReturn(vector1).thenReturn(vector2); + when(mockVectorValues.nextDoc()).thenReturn(1).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + + SegmentProfilerState state = 
SegmentProfilerState.profileVectors(mockSupplier); + + assertEquals(2, state.getStatistics().size()); + assertEquals(2.0, state.getStatistics().get(0).getMean(), 0.001); + assertEquals(3.0, state.getStatistics().get(1).getMean(), 0.001); + } + + public void testProfileVectorsStatisticalValues() throws IOException { + float[] vector1 = new float[] { 1.0f, 2.0f }; + float[] vector2 = new float[] { 3.0f, 4.0f }; + float[] vector3 = new float[] { 5.0f, 6.0f }; + + when(mockVectorValues.docId()).thenReturn(0); + when(mockVectorValues.dimension()).thenReturn(2); + when(mockVectorValues.getVector()).thenReturn(vector1).thenReturn(vector2).thenReturn(vector3); + when(mockVectorValues.nextDoc()).thenReturn(1).thenReturn(2).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + + SegmentProfilerState state = SegmentProfilerState.profileVectors(mockSupplier); + + assertEquals(3.0, state.getStatistics().get(0).getMean(), 0.001); + assertEquals(2.0, state.getStatistics().get(0).getStandardDeviation(), 0.001); + assertEquals(1.0, state.getStatistics().get(0).getMin(), 0.001); + assertEquals(5.0, state.getStatistics().get(0).getMax(), 0.001); + + assertEquals(4.0, state.getStatistics().get(1).getMean(), 0.001); + assertEquals(2.0, state.getStatistics().get(1).getStandardDeviation(), 0.001); + assertEquals(2.0, state.getStatistics().get(1).getMin(), 0.001); + assertEquals(6.0, state.getStatistics().get(1).getMax(), 0.001); + } + + public void testProfileVectorsWithByteStatisticalValues() throws IOException { + byte[] vector1 = new byte[] { 1, 2 }; + byte[] vector2 = new byte[] { 3, 4 }; + byte[] vector3 = new byte[] { 5, 6 }; + + when(mockVectorValues.docId()).thenReturn(0); + when(mockVectorValues.dimension()).thenReturn(2); + when(mockVectorValues.getVector()).thenReturn(vector1).thenReturn(vector2).thenReturn(vector3); + when(mockVectorValues.nextDoc()).thenReturn(1).thenReturn(2).thenReturn(DocIdSetIterator.NO_MORE_DOCS); + + SegmentProfilerState state = 
SegmentProfilerState.profileVectors(mockSupplier); + + assertEquals(3.0, state.getStatistics().get(0).getMean(), 0.001); + assertEquals(2.0, state.getStatistics().get(0).getStandardDeviation(), 0.001); + assertEquals(1.0, state.getStatistics().get(0).getMin(), 0.001); + assertEquals(5.0, state.getStatistics().get(0).getMax(), 0.001); + + assertEquals(4.0, state.getStatistics().get(1).getMean(), 0.001); + assertEquals(2.0, state.getStatistics().get(1).getStandardDeviation(), 0.001); + assertEquals(2.0, state.getStatistics().get(1).getMin(), 0.001); + assertEquals(6.0, state.getStatistics().get(1).getMax(), 0.001); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java index 14e55e627d..a07f44d152 100644 --- a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheManagerTests.java @@ -6,22 +6,103 @@ package org.opensearch.knn.quantization.models.quantizationState; import lombok.SneakyThrows; +import org.junit.After; +import org.junit.Before; import org.mockito.MockedStatic; import org.mockito.Mockito; +import org.opensearch.common.settings.Settings; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.codec.KNN990Codec.KNN990QuantizationStateReader; +import org.opensearch.threadpool.ThreadPool; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; +import static org.mockito.Mockito.when; public class QuantizationStateCacheManagerTests extends KNNTestCase { + 
private ThreadPool threadPool; + + @Before + public void setThreadPool() { + threadPool = new ThreadPool(Settings.builder().put("node.name", "QuantizationStateCacheTests").build()); + QuantizationStateCache.setThreadPool(threadPool); + QuantizationStateCache.getInstance().rebuildCache(); + } + + @After + public void terminateThreadPool() { + terminate(threadPool); + } + + @SneakyThrows + public void testConcurrentLoad() { + // Get manager and clean it. + final QuantizationStateCacheManager manager = QuantizationStateCacheManager.getInstance(); + manager.rebuildCache(); + + // Mock read config + final QuantizationStateReadConfig readConfig = mock(QuantizationStateReadConfig.class); + when(readConfig.getCacheKey()).thenReturn("cache_key"); + + // Add state first. + final QuantizationState quantizationState = mock(QuantizationState.class); + when(quantizationState.toByteArray()).thenReturn(new byte[32]); + try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNN990QuantizationStateReader.class)) { + // Mock static + mockedStaticReader.when(() -> KNN990QuantizationStateReader.read(readConfig)).thenReturn(quantizationState); + + // Add state + manager.getQuantizationState(readConfig); + } + + // Set up thread executors + final int threadCount = 10; + final int tries = 100; + final ExecutorService executorService = Executors.newFixedThreadPool(threadCount); + final CountDownLatch latch = new CountDownLatch(threadCount); + + // Try to get in parallel + for (int i = 0; i < threadCount; i++) { + executorService.submit(() -> { + try { + for (int k = 0; k < tries; k++) { + // Since we already added state at the beginning, even multiple threads try to load, + // the retrieved one should be the one that we added. 
+ final QuantizationState acquired = manager.getQuantizationState(readConfig); + assertEquals(quantizationState, acquired); + } + } catch (Exception e) { + fail(e.getMessage()); + } finally { + latch.countDown(); + } + }); + } + + // Wait for all threads to finish + latch.await(); + executorService.shutdown(); + } + @SneakyThrows public void testRebuildCache() { try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { - QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + // Mocking state cache singleton + QuantizationStateCache quantizationStateCache = mock(QuantizationStateCache.class); mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); + + // Mocking it to do nothing when `rebuildCache` Mockito.doNothing().when(quantizationStateCache).rebuildCache(); QuantizationStateCacheManager.getInstance().rebuildCache(); + + // Verify rebuildCache is called exactly once Mockito.verify(quantizationStateCache, times(1)).rebuildCache(); } } @@ -29,22 +110,30 @@ public void testRebuildCache() { @SneakyThrows public void testGetQuantizationState() { try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { - QuantizationStateReadConfig quantizationStateReadConfig = Mockito.mock(QuantizationStateReadConfig.class); + // Mocking read config with cache key + QuantizationStateReadConfig quantizationStateReadConfig = mock(QuantizationStateReadConfig.class); String cacheKey = "test-key"; - Mockito.when(quantizationStateReadConfig.getCacheKey()).thenReturn(cacheKey); - QuantizationState quantizationState = Mockito.mock(QuantizationState.class); - QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + when(quantizationStateReadConfig.getCacheKey()).thenReturn(cacheKey); + + // Mocking quantization state + QuantizationState quantizationState = mock(QuantizationState.class); + QuantizationStateCache 
quantizationStateCache = mock(QuantizationStateCache.class); mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); - Mockito.doNothing().when(quantizationStateCache).addQuantizationState(cacheKey, quantizationState); + when(quantizationStateCache.getQuantizationState(any(), any())).thenReturn(quantizationState); + + // Validate `getQuantizationState` of `quantizationStateCache` was called. try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNN990QuantizationStateReader.class)) { mockedStaticReader.when(() -> KNN990QuantizationStateReader.read(quantizationStateReadConfig)) .thenReturn(quantizationState); QuantizationStateCacheManager.getInstance().getQuantizationState(quantizationStateReadConfig); - Mockito.verify(quantizationStateCache, times(1)).addQuantizationState(cacheKey, quantizationState); + Mockito.verify(quantizationStateCache, times(1)).getQuantizationState(eq(cacheKey), any()); } - Mockito.when(quantizationStateCache.getQuantizationState(cacheKey)).thenReturn(quantizationState); + + // Validate `getQuantizationState` was called AGAIN. + // But this time, we don't need to invoke `read` as we have a value loaded already. 
+ when(quantizationStateCache.getQuantizationState(any(), any())).thenReturn(quantizationState); QuantizationStateCacheManager.getInstance().getQuantizationState(quantizationStateReadConfig); - Mockito.verify(quantizationStateCache, times(1)).addQuantizationState(cacheKey, quantizationState); + Mockito.verify(quantizationStateCache, times(2)).getQuantizationState(eq(cacheKey), any()); } } @@ -52,7 +141,7 @@ public void testGetQuantizationState() { public void testEvict() { try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { String field = "test-field"; - QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + QuantizationStateCache quantizationStateCache = mock(QuantizationStateCache.class); mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); Mockito.doNothing().when(quantizationStateCache).evict(field); QuantizationStateCacheManager.getInstance().evict(field); @@ -60,24 +149,11 @@ public void testEvict() { } } - @SneakyThrows - public void testAddQuantizationState() { - try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { - String field = "test-field"; - QuantizationState quantizationState = Mockito.mock(QuantizationState.class); - QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); - mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); - Mockito.doNothing().when(quantizationStateCache).addQuantizationState(field, quantizationState); - QuantizationStateCacheManager.getInstance().addQuantizationState(field, quantizationState); - Mockito.verify(quantizationStateCache, times(1)).addQuantizationState(field, quantizationState); - } - } - @SneakyThrows public void testSetMaxCacheSizeInKB() { try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { long maxCacheSizeInKB = 1024; - QuantizationStateCache 
quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + QuantizationStateCache quantizationStateCache = mock(QuantizationStateCache.class); mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); Mockito.doNothing().when(quantizationStateCache).setMaxCacheSizeInKB(maxCacheSizeInKB); QuantizationStateCacheManager.getInstance().setMaxCacheSizeInKB(1024); @@ -88,7 +164,7 @@ public void testSetMaxCacheSizeInKB() { @SneakyThrows public void testClear() { try (MockedStatic mockedStaticCache = Mockito.mockStatic(QuantizationStateCache.class)) { - QuantizationStateCache quantizationStateCache = Mockito.mock(QuantizationStateCache.class); + QuantizationStateCache quantizationStateCache = mock(QuantizationStateCache.class); mockedStaticCache.when(QuantizationStateCache::getInstance).thenReturn(quantizationStateCache); Mockito.doNothing().when(quantizationStateCache).clear(); QuantizationStateCacheManager.getInstance().clear(); diff --git a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java index 87cb57cdcb..e76a99806b 100644 --- a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java +++ b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java @@ -9,7 +9,6 @@ import lombok.SneakyThrows; import org.junit.After; import org.junit.Before; -import org.opensearch.transport.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -19,6 +18,7 @@ import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.ThreadPool; +import 
org.opensearch.transport.client.Client; import java.io.IOException; import java.util.concurrent.CountDownLatch; @@ -39,6 +39,7 @@ public class QuantizationStateCacheTests extends KNNTestCase { public void setThreadPool() { threadPool = new ThreadPool(Settings.builder().put("node.name", "QuantizationStateCacheTests").build()); QuantizationStateCache.setThreadPool(threadPool); + QuantizationStateCache.getInstance().rebuildCache(); } @After @@ -46,14 +47,77 @@ public void terminateThreadPool() { terminate(threadPool); } + @SneakyThrows + public void testConcurrentLoadWhenValueExists() { + // Set up thread executors + final int threadCount = 10; + ExecutorService executorService = Executors.newFixedThreadPool(threadCount); + final CountDownLatch latch = new CountDownLatch(threadCount); + + // Prepare quantization state + final String fieldName = "multiThreadField"; + final QuantizationState state = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ONE_BIT), + new float[] { 1.2f, 2.3f, 3.4f } + ); + + // Configure settings + final String cacheSize = "10%"; + final TimeValue expiry = TimeValue.timeValueMinutes(30); + final Settings settings = Settings.builder() + .put(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize) + .put(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING.getKey(), expiry) + .build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) + ); + + // Mocking ClusterService + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + // Apply settings + final QuantizationStateCache cache = QuantizationStateCache.getInstance(); + clusterService.getClusterSettings().applySettings(settings); + + // Add the state first + QuantizationState retrievedState = 
cache.getQuantizationState(fieldName, () -> state); + assertEquals(state, retrievedState); + + // Add state from multiple threads + for (int i = 0; i < threadCount; i++) { + executorService.submit(() -> { + try { + // Since we already added state at the beginning, even multiple threads try to load, + // the retrieved one should be the one that we added. + final QuantizationState acquired = cache.getQuantizationState( + fieldName, + () -> new OneBitScalarQuantizationState(new ScalarQuantizationParams(ONE_BIT), new float[] { 1.2f, 2.3f, 3.4f }) + ); + assertEquals(state, acquired); + } finally { + latch.countDown(); + } + }); + } + + // Wait for all threads to finish + latch.await(); + executorService.shutdown(); + } + @SneakyThrows public void testSingleThreadedAddAndRetrieve() { + // Prepare state String fieldName = "singleThreadField"; QuantizationState state = new OneBitScalarQuantizationState( new ScalarQuantizationParams(ONE_BIT), new float[] { 1.2f, 2.3f, 3.4f } ); + // Configure settings with 10% String cacheSize = "10%"; TimeValue expiry = TimeValue.timeValueMinutes(30); @@ -65,35 +129,40 @@ public void testSingleThreadedAddAndRetrieve() { settings, ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) ); + + // Mocking ClusterService ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); - QuantizationStateCache cache = QuantizationStateCache.getInstance(); + // Apply the configured setting + final QuantizationStateCache cache = QuantizationStateCache.getInstance(); clusterService.getClusterSettings().applySettings(settings); - // Add state - cache.addQuantizationState(fieldName, state); - - QuantizationState retrievedState = cache.getQuantizationState(fieldName); + // Try to get a state and validate + final QuantizationState retrievedState = 
cache.getQuantizationState(fieldName, () -> state); assertNotNull("State should be retrieved successfully", retrievedState); assertSame("Retrieved state should be the same instance as the one added", state, retrievedState); } @SneakyThrows public void testMultiThreadedAddAndRetrieve() { - int threadCount = 10; + // Set up thread executors + final int threadCount = 10; ExecutorService executorService = Executors.newFixedThreadPool(threadCount); - CountDownLatch latch = new CountDownLatch(threadCount); - String fieldName = "multiThreadField"; - QuantizationState state = new OneBitScalarQuantizationState( + final CountDownLatch latch = new CountDownLatch(threadCount); + + // Prepare quantization state + final String fieldName = "multiThreadField"; + final QuantizationState state = new OneBitScalarQuantizationState( new ScalarQuantizationParams(ONE_BIT), new float[] { 1.2f, 2.3f, 3.4f } ); - String cacheSize = "10%"; - TimeValue expiry = TimeValue.timeValueMinutes(30); - Settings settings = Settings.builder() + // Configure settings + final String cacheSize = "10%"; + final TimeValue expiry = TimeValue.timeValueMinutes(30); + final Settings settings = Settings.builder() .put(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize) .put(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING.getKey(), expiry) .build(); @@ -101,18 +170,24 @@ public void testMultiThreadedAddAndRetrieve() { settings, ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) ); + + // Mocking ClusterService ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); - QuantizationStateCache cache = QuantizationStateCache.getInstance(); + // Apply settings + final QuantizationStateCache cache = QuantizationStateCache.getInstance(); clusterService.getClusterSettings().applySettings(settings); // Add state 
from multiple threads + final int tries = 100; for (int i = 0; i < threadCount; i++) { executorService.submit(() -> { try { - cache.addQuantizationState(fieldName, state); + for (int k = 0; k < tries; k++) { + cache.getQuantizationState(fieldName, () -> state); + } } finally { latch.countDown(); } @@ -123,16 +198,20 @@ public void testMultiThreadedAddAndRetrieve() { latch.await(); executorService.shutdown(); - QuantizationState retrievedState = cache.getQuantizationState(fieldName); + // Validate retrieved state + QuantizationState retrievedState = cache.getQuantizationState(fieldName, () -> state); assertNotNull("State should be retrieved successfully", retrievedState); assertSame("Retrieved state should be the same instance as the one added", state, retrievedState); } @SneakyThrows public void testMultiThreadedEvict() { - int threadCount = 10; + // Set up threads + final int threadCount = 10; ExecutorService executorService = Executors.newFixedThreadPool(threadCount); CountDownLatch latch = new CountDownLatch(threadCount); + + // Prepare quantization state String fieldName = "multiThreadEvictField"; QuantizationState state = new OneBitScalarQuantizationState( new ScalarQuantizationParams(ONE_BIT), @@ -141,10 +220,13 @@ public void testMultiThreadedEvict() { String cacheSize = "10%"; TimeValue expiry = TimeValue.timeValueMinutes(30); + // Configure settings Settings settings = Settings.builder() .put(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize) .put(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING.getKey(), expiry) .build(); + + // Mocking ClusterService ClusterSettings clusterSettings = new ClusterSettings( settings, ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) @@ -153,11 +235,11 @@ public void testMultiThreadedEvict() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); - QuantizationStateCache 
cache = QuantizationStateCache.getInstance(); - + // Apply settings to ClusterService clusterService.getClusterSettings().applySettings(settings); - cache.addQuantizationState(fieldName, state); + final QuantizationStateCache cache = QuantizationStateCache.getInstance(); + cache.getQuantizationState(fieldName, () -> state); // Evict state from multiple threads for (int i = 0; i < threadCount; i++) { @@ -174,27 +256,34 @@ public void testMultiThreadedEvict() { latch.await(); executorService.shutdown(); - QuantizationState retrievedState = cache.getQuantizationState(fieldName); - assertNull("State should be null", retrievedState); + final QuantizationState mockedState = getMockedState(); + final QuantizationState retrievedState = cache.getQuantizationState(fieldName, () -> mockedState); + assertEquals(mockedState, retrievedState); } @SneakyThrows public void testConcurrentAddAndEvict() { + // Set up thread executors int threadCount = 10; ExecutorService executorService = Executors.newFixedThreadPool(threadCount); CountDownLatch latch = new CountDownLatch(threadCount); - String fieldName = "concurrentAddEvictField"; + + // Prepare quantization state + final String fieldName = "concurrentAddEvictField"; QuantizationState state = new OneBitScalarQuantizationState( new ScalarQuantizationParams(ONE_BIT), new float[] { 1.2f, 2.3f, 3.4f } ); + + // Configure settings String cacheSize = "10%"; TimeValue expiry = TimeValue.timeValueMinutes(30); - Settings settings = Settings.builder() .put(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize) .put(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING.getKey(), expiry) .build(); + + // Mocking ClusterService ClusterSettings clusterSettings = new ClusterSettings( settings, ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) @@ -203,15 +292,16 @@ public void testConcurrentAddAndEvict() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); 
when(clusterService.getSettings()).thenReturn(settings); - QuantizationStateCache cache = QuantizationStateCache.getInstance(); + // Apply settings clusterService.getClusterSettings().applySettings(settings); // Concurrently add and evict state from multiple threads + final QuantizationStateCache cache = QuantizationStateCache.getInstance(); for (int i = 0; i < threadCount; i++) { if (i % 2 == 0) { executorService.submit(() -> { try { - cache.addQuantizationState(fieldName, state); + cache.getQuantizationState(fieldName, () -> state); } finally { latch.countDown(); } @@ -225,7 +315,6 @@ public void testConcurrentAddAndEvict() { } }); } - } // Wait for all threads to finish @@ -233,23 +322,28 @@ public void testConcurrentAddAndEvict() { executorService.shutdown(); // Since operations are concurrent, we can't be sure of the final state, but we can assert that the cache handles it gracefully - QuantizationState retrievedState = cache.getQuantizationState(fieldName); - assertTrue("Final state should be either null or the added state", retrievedState == null || retrievedState == state); + QuantizationState mockedState = getMockedState(); + QuantizationState retrievedState = cache.getQuantizationState(fieldName, () -> mockedState); + assertTrue("Final state should be either new one or the added state", retrievedState == mockedState || retrievedState == state); } @SneakyThrows public void testMultipleThreadedCacheClear() { - int threadCount = 10; + // Set up thread executors + final int threadCount = 10; ExecutorService executorService = Executors.newFixedThreadPool(threadCount); CountDownLatch latch = new CountDownLatch(threadCount); - String fieldName = "multiThreadField"; + + // Prepare quantization state + final String fieldName = "multiThreadField"; QuantizationState state = new OneBitScalarQuantizationState( new ScalarQuantizationParams(ONE_BIT), new float[] { 1.2f, 2.3f, 3.4f } ); + + // Configure settings String cacheSize = "10%"; TimeValue expiry = 
TimeValue.timeValueMinutes(30); - Settings settings = Settings.builder() .put(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize) .put(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING.getKey(), expiry) @@ -258,13 +352,16 @@ public void testMultipleThreadedCacheClear() { settings, ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) ); + + // Mocking ClusterService ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); - QuantizationStateCache cache = QuantizationStateCache.getInstance(); + // Apply settings + final QuantizationStateCache cache = QuantizationStateCache.getInstance(); clusterService.getClusterSettings().applySettings(settings); - cache.addQuantizationState(fieldName, state); + cache.getQuantizationState(fieldName, () -> state); // Clear cache from multiple threads for (int i = 0; i < threadCount; i++) { @@ -281,23 +378,29 @@ public void testMultipleThreadedCacheClear() { latch.await(); executorService.shutdown(); - QuantizationState retrievedState = cache.getQuantizationState(fieldName); - assertNull("State should be null", retrievedState); + // Validate there's no state, and it should be the one we just added. 
+ QuantizationState mockedState = getMockedState(); + QuantizationState retrievedState = cache.getQuantizationState(fieldName, () -> mockedState); + assertEquals(mockedState, retrievedState); } @SneakyThrows public void testRebuild() { - int threadCount = 10; + // Set up thread executors + final int threadCount = 10; ExecutorService executorService = Executors.newFixedThreadPool(threadCount); CountDownLatch latch = new CountDownLatch(threadCount); + + // Prepare quantization state String fieldName = "rebuildField"; QuantizationState state = new OneBitScalarQuantizationState( new ScalarQuantizationParams(ONE_BIT), new float[] { 1.2f, 2.3f, 3.4f } ); + + // Configure settings String cacheSize = "10%"; TimeValue expiry = TimeValue.timeValueMinutes(30); - Settings settings = Settings.builder() .put(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING.getKey(), cacheSize) .put(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING.getKey(), expiry) @@ -306,12 +409,15 @@ public void testRebuild() { settings, ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) ); + + // Mocking ClusterService ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); - QuantizationStateCache cache = QuantizationStateCache.getInstance(); - cache.addQuantizationState(fieldName, state); + // Apply settings + final QuantizationStateCache cache = QuantizationStateCache.getInstance(); + cache.getQuantizationState(fieldName, () -> state); // Rebuild cache from multiple threads for (int i = 0; i < threadCount; i++) { @@ -328,41 +434,49 @@ public void testRebuild() { latch.await(); executorService.shutdown(); - QuantizationState retrievedState = cache.getQuantizationState(fieldName); - assertNull("State should be null", retrievedState); + // Validate there's no state, and it should be the one we just added. 
+ QuantizationState mockedState = getMockedState(); + QuantizationState retrievedState = cache.getQuantizationState(fieldName, () -> mockedState); + assertEquals(mockedState, retrievedState); } @SneakyThrows public void testRebuildOnCacheSizeSettingsChange() { + // Set up thread executors int threadCount = 10; ExecutorService executorService = Executors.newFixedThreadPool(threadCount); CountDownLatch latch = new CountDownLatch(threadCount); + + // Prepare quantization state String fieldName = "rebuildField"; QuantizationState state = new OneBitScalarQuantizationState( new ScalarQuantizationParams(ONE_BIT), new float[] { 1.2f, 2.3f, 3.4f } ); + // Configure settings Settings settings = Settings.builder().build(); ClusterSettings clusterSettings = new ClusterSettings( settings, ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) ); + + // Mocking ClusterService ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); + // Initialize KNNSettings Client client = mock(Client.class); - KNNSettings.state().initialize(client, clusterService); - QuantizationStateCache cache = QuantizationStateCache.getInstance(); - cache.rebuildCache(); + // Rebuild and add the state + final QuantizationStateCache cache = QuantizationStateCache.getInstance(); long maxCacheSizeInKB = cache.getMaxCacheSizeInKB(); - cache.addQuantizationState(fieldName, state); + cache.getQuantizationState(fieldName, () -> state); + // Prepare a new setting String newCacheSize = "10%"; - Settings newSettings = Settings.builder().put(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING.getKey(), newCacheSize).build(); // Rebuild cache from multiple threads @@ -380,40 +494,49 @@ public void testRebuildOnCacheSizeSettingsChange() { latch.await(); executorService.shutdown(); - QuantizationState retrievedState = 
cache.getQuantizationState(fieldName); - assertNull("State should be null", retrievedState); + // Validate there's no state and KB threshold value. + QuantizationState mockedState = getMockedState(); + QuantizationState retrievedState = cache.getQuantizationState(fieldName, () -> mockedState); + assertEquals(mockedState, retrievedState); assertNotEquals(maxCacheSizeInKB, cache.getMaxCacheSizeInKB()); } @SneakyThrows public void testRebuildOnTimeExpirySettingsChange() { - int threadCount = 10; + // Set up thread executors + final int threadCount = 10; ExecutorService executorService = Executors.newFixedThreadPool(threadCount); CountDownLatch latch = new CountDownLatch(threadCount); + + // Prepare quantization state String fieldName = "rebuildField"; QuantizationState state = new OneBitScalarQuantizationState( new ScalarQuantizationParams(ONE_BIT), new float[] { 1.2f, 2.3f, 3.4f } ); + // Configure settings Settings settings = Settings.builder().build(); ClusterSettings clusterSettings = new ClusterSettings( settings, ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) ); + + // Mocking ClusterService ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); + // Initialize KNNSettings Client client = mock(Client.class); - KNNSettings.state().initialize(client, clusterService); - QuantizationStateCache cache = QuantizationStateCache.getInstance(); - cache.addQuantizationState(fieldName, state); + // Add a new state + final QuantizationStateCache cache = QuantizationStateCache.getInstance(); + cache.getQuantizationState(fieldName, () -> state); + // Prepare a new settings TimeValue newExpiry = TimeValue.timeValueMinutes(30); - Settings newSettings = Settings.builder().put(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING.getKey(), newExpiry).build(); // Rebuild cache from 
multiple threads @@ -431,44 +554,118 @@ public void testRebuildOnTimeExpirySettingsChange() { latch.await(); executorService.shutdown(); - QuantizationState retrievedState = cache.getQuantizationState(fieldName); - assertNull("State should be null", retrievedState); + // Validate there was no state in it. + QuantizationState mockedState = getMockedState(); + QuantizationState retrievedState = cache.getQuantizationState(fieldName, () -> mockedState); + assertEquals(mockedState, retrievedState); + } + + public void testCacheEvictionToSize() throws IOException { + // Adding 4K + 100 bytes as meta info (e.g. length vint encoding etc) + final int arrayLength = 1024; + + // Prepare state1 ~ roughly 4,100 bytes + float[] meanThresholds1 = new float[arrayLength]; + for (int i = 0; i < arrayLength; i++) { + meanThresholds1[i] = i; + } + + // Configure settings + Settings settings = Settings.builder().build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) + ); + + // Mocking ClusterService + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterService.getSettings()).thenReturn(settings); + + // Build cache + final String fieldName = "evictionField"; + // Setting 1KB as a threshold. As a result, expected the first one added will be evicted right away. + final long cacheSizeKB = 1; + final QuantizationStateCache cache = QuantizationStateCache.getInstance(); + cache.setMaxCacheSizeInKB(cacheSizeKB); + cache.rebuildCache(); // Need to rebuild to update size threshold. 
+ + // Try to add the first state + final QuantizationState state = new OneBitScalarQuantizationState(new ScalarQuantizationParams(ONE_BIT), meanThresholds1); + QuantizationState retrievedState = cache.getQuantizationState(fieldName, () -> state); + assertEquals(state, retrievedState); + + // Try again + final QuantizationState state2 = new OneBitScalarQuantizationState(new ScalarQuantizationParams(ONE_BIT), meanThresholds1); + retrievedState = cache.getQuantizationState(fieldName, () -> state2); + assertEquals(state2, retrievedState); + + // Close cache + cache.clear(); + cache.close(); + + // Validate whether states were evicted due to size. + assertNotNull(cache.getEvictedDueToSizeAt()); } public void testCacheEvictionDueToSize() throws IOException { - String fieldName = "evictionField"; - // States have size of slightly over 500 bytes so that adding two will reach the max size of 1 kb for the cache - int arrayLength = 112; - float[] arr = new float[arrayLength]; - float[] arr2 = new float[arrayLength]; + // Adding 4K + 100 bytes as meta info (e.g. 
length vint encoding etc) + final int arrayLength = 1024; + + // Prepare state1 ~ roughly 4,100 bytes + float[] meanThresholds1 = new float[arrayLength]; + for (int i = 0; i < arrayLength; i++) { + meanThresholds1[i] = i; + } + QuantizationState state1 = new OneBitScalarQuantizationState(new ScalarQuantizationParams(ONE_BIT), meanThresholds1); + + // Prepare state2 ~ roughly 4,100 bytes + float[] meanThresholds2 = new float[arrayLength]; for (int i = 0; i < arrayLength; i++) { - arr[i] = i; - arr[i] = i + 1; + meanThresholds2[i] = i + 1; } - QuantizationState state = new OneBitScalarQuantizationState(new ScalarQuantizationParams(ONE_BIT), arr); - QuantizationState state2 = new OneBitScalarQuantizationState(new ScalarQuantizationParams(ONE_BIT), arr2); - long cacheSize = 1; + QuantizationState state2 = new OneBitScalarQuantizationState(new ScalarQuantizationParams(ONE_BIT), meanThresholds2); + + // Configure settings Settings settings = Settings.builder().build(); ClusterSettings clusterSettings = new ClusterSettings( settings, ImmutableSet.of(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING) ); + + // Mocking ClusterService ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(clusterService.getSettings()).thenReturn(settings); - QuantizationStateCache cache = new QuantizationStateCache(); - cache.setMaxCacheSizeInKB(cacheSize); - cache.rebuildCache(); - cache.addQuantizationState(fieldName, state); - cache.addQuantizationState(fieldName, state2); + // Build cache + final String fieldName = "evictionField"; + final String fieldName2 = "evictionField2"; + // Setting 7KB as a threshold. As the weight of each one si roughly 4,100 bytes + // Thus, setting 7KB so that it can evict the first one added when the second state is added. 
+ final long cacheSizeKB = 7; + final QuantizationStateCache cache = QuantizationStateCache.getInstance(); + cache.setMaxCacheSizeInKB(cacheSizeKB); + cache.rebuildCache(); // Need to rebuild to update size threshold. + + // Try to add the first state + QuantizationState retrievedState = cache.getQuantizationState(fieldName, () -> state1); + assertEquals(state1, retrievedState); + + // Try to add the second state + retrievedState = cache.getQuantizationState(fieldName2, () -> state2); + assertEquals(state2, retrievedState); + + // Close cache cache.clear(); cache.close(); + + // Validate whether states were evicted due to size. assertNotNull(cache.getEvictedDueToSizeAt()); } public void testMaintenanceScheduled() throws Exception { - QuantizationStateCache quantizationStateCache = new QuantizationStateCache(); + final QuantizationStateCache quantizationStateCache = QuantizationStateCache.getInstance(); Scheduler.Cancellable maintenanceTask = quantizationStateCache.getMaintenanceTask(); assertNotNull(maintenanceTask); @@ -478,7 +675,7 @@ public void testMaintenanceScheduled() throws Exception { } public void testMaintenanceWithRebuild() throws Exception { - QuantizationStateCache quantizationStateCache = new QuantizationStateCache(); + final QuantizationStateCache quantizationStateCache = QuantizationStateCache.getInstance(); Scheduler.Cancellable task1 = quantizationStateCache.getMaintenanceTask(); assertNotNull(task1); @@ -489,4 +686,11 @@ public void testMaintenanceWithRebuild() throws Exception { assertNotNull(task2); quantizationStateCache.close(); } + + @SneakyThrows + private static QuantizationState getMockedState() { + QuantizationState mockedState = mock(QuantizationState.class); + when(mockedState.toByteArray()).thenReturn(new byte[32]); + return mockedState; + } } diff --git a/src/test/resources/data/remoteindexbuild/faiss_hnsw_cagra_nested_float_1000_vectors_128_dims.bin 
b/src/test/resources/data/remoteindexbuild/faiss_hnsw_cagra_nested_float_1000_vectors_128_dims.bin new file mode 100644 index 0000000000..6f3f734f7e Binary files /dev/null and b/src/test/resources/data/remoteindexbuild/faiss_hnsw_cagra_nested_float_1000_vectors_128_dims.bin differ diff --git a/src/testFixtures/java/org/opensearch/knn/DerivedSourceTestCase.java b/src/testFixtures/java/org/opensearch/knn/DerivedSourceTestCase.java index 67a9a79ec7..dee67ee53b 100644 --- a/src/testFixtures/java/org/opensearch/knn/DerivedSourceTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/DerivedSourceTestCase.java @@ -16,11 +16,13 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.VectorDataType; +import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Random; +import java.util.function.Supplier; public class DerivedSourceTestCase extends KNNRestTestCase { @@ -33,6 +35,11 @@ public class DerivedSourceTestCase extends KNNRestTestCase { new Pair<>("d2d-", false) ); + private static final int MIN_DIMENSION = 4; + private static final int MAX_DIMENSION = 32; + private static final int MIN_DOCS = 50; + private static final int MAX_DOCS = 200; + /** * Testing flat, single field base case with index configuration. The test will automatically skip adding fields for * random documents to ensure it works robustly. 
To ensure correctness, we repeat same operations against an @@ -84,33 +91,54 @@ public class DerivedSourceTestCase extends KNNRestTestCase { * } * } */ - protected List getFlatIndexContexts(String testSuitePrefix, boolean addRandom) { + protected List getFlatIndexContexts(String testSuitePrefix, boolean addRandom, boolean addNull) { List indexConfigContexts = new ArrayList<>(); + long consistentRandomSeed = random().nextLong(); for (Pair index : INDEX_PREFIX_TO_ENABLED) { + Supplier dimensionSupplier = randomIntegerSupplier(consistentRandomSeed, MIN_DIMENSION, MAX_DIMENSION); + Supplier binaryDimensionSupplier = randomIntegerSupplier(consistentRandomSeed, MIN_DIMENSION, MAX_DIMENSION, 8); + Supplier randomDocCountSupplier = randomIntegerSupplier(consistentRandomSeed, MIN_DOCS, MAX_DOCS); DerivedSourceUtils.IndexConfigContext indexConfigContext = DerivedSourceUtils.IndexConfigContext.builder() .indexName(getIndexName(testSuitePrefix, index.getFirst(), addRandom)) + .docCount(randomDocCountSupplier.get()) .derivedEnabled(index.getSecond()) - .random(new Random(1)) + .random(new Random(consistentRandomSeed)) .fields( List.of( - DerivedSourceUtils.KNNVectorFieldTypeContext.builder().fieldPath("test_float_vector").build(), - DerivedSourceUtils.KNNVectorFieldTypeContext.builder().fieldPath("update_float_vector").isUpdate(true).build(), + DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) + .nullProb(addNull ? DerivedSourceUtils.DEFAULT_NULL_PROB : 0) + .fieldPath("test_float_vector") + .build(), + DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) + .nullProb(addNull ? DerivedSourceUtils.DEFAULT_NULL_PROB : 0) + .fieldPath("update_float_vector") + .isUpdate(true) + .build(), DerivedSourceUtils.KNNVectorFieldTypeContext.builder() .fieldPath("test_byte_vector") .vectorDataType(VectorDataType.BYTE) + .nullProb(addNull ? 
DerivedSourceUtils.DEFAULT_NULL_PROB : 0) .build(), DerivedSourceUtils.KNNVectorFieldTypeContext.builder() .fieldPath("update_byte_vector") .vectorDataType(VectorDataType.BYTE) + .dimension(dimensionSupplier.get()) + .nullProb(addNull ? DerivedSourceUtils.DEFAULT_NULL_PROB : 0) .isUpdate(true) .build(), DerivedSourceUtils.KNNVectorFieldTypeContext.builder() .fieldPath("test_binary_vector") .vectorDataType(VectorDataType.BINARY) + .dimension(binaryDimensionSupplier.get()) + .nullProb(addNull ? DerivedSourceUtils.DEFAULT_NULL_PROB : 0) .build(), DerivedSourceUtils.KNNVectorFieldTypeContext.builder() .fieldPath("update_binary_vector") .vectorDataType(VectorDataType.BINARY) + .dimension(binaryDimensionSupplier.get()) + .nullProb(addNull ? DerivedSourceUtils.DEFAULT_NULL_PROB : 0) .isUpdate(true) .build(), DerivedSourceUtils.TextFieldType.builder().fieldPath("test-text").build(), @@ -137,11 +165,11 @@ protected List getFlatIndexContexts(Strin * "properties" : { * "test_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 63 * }, * "update_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 34 * } * } * }, @@ -154,11 +182,11 @@ protected List getFlatIndexContexts(Strin * }, * "test_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 41 * }, * "update_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 8 * } * } * }, @@ -167,11 +195,11 @@ protected List getFlatIndexContexts(Strin * }, * "test_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 45 * }, * "update_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 7 * } * } * }, @@ -183,11 +211,11 @@ protected List getFlatIndexContexts(Strin * }, * "test_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 10 * }, * "update_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 51 * } * } * } @@ -196,19 +224,27 @@ protected List 
getFlatIndexContexts(Strin */ protected List getObjectIndexContexts(String testSuitePrefix, boolean addRandom) { List indexConfigContexts = new ArrayList<>(); + long consistentRandomSeed = random().nextLong(); for (Pair index : INDEX_PREFIX_TO_ENABLED) { + Supplier dimensionSupplier = randomIntegerSupplier(consistentRandomSeed, MIN_DIMENSION, MAX_DIMENSION); + Supplier randomDocCountSupplier = randomIntegerSupplier(consistentRandomSeed, MIN_DOCS, MAX_DOCS); DerivedSourceUtils.IndexConfigContext indexConfigContext = DerivedSourceUtils.IndexConfigContext.builder() .indexName(getIndexName(testSuitePrefix, index.getFirst(), addRandom)) + .docCount(randomDocCountSupplier.get()) .derivedEnabled(index.getSecond()) - .random(new Random(1)) + .random(new Random(consistentRandomSeed)) .fields( List.of( DerivedSourceUtils.ObjectFieldContext.builder() .fieldPath("path_1") .children( List.of( - DerivedSourceUtils.KNNVectorFieldTypeContext.builder().fieldPath("path_1.test_vector").build(), DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) + .fieldPath("path_1.test_vector") + .build(), + DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) .fieldPath("path_1.update_vector") .isUpdate(true) .build() @@ -220,8 +256,12 @@ protected List getObjectIndexContexts(Str .children( List.of( DerivedSourceUtils.TextFieldType.builder().fieldPath("path_2.test-text").build(), - DerivedSourceUtils.KNNVectorFieldTypeContext.builder().fieldPath("path_2.test_vector").build(), DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) + .fieldPath("path_2.test_vector") + .build(), + DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) .fieldPath("path_2.update_vector") .isUpdate(true) .build(), @@ -230,9 +270,11 @@ protected List getObjectIndexContexts(Str .children( List.of( DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + 
.dimension(dimensionSupplier.get()) .fieldPath("path_2.path_3.test_vector") .build(), DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) .fieldPath("path_2.path_3.update_vector") .isUpdate(true) .build(), @@ -243,8 +285,15 @@ protected List getObjectIndexContexts(Str ) ) .build(), - DerivedSourceUtils.KNNVectorFieldTypeContext.builder().fieldPath("test_vector").build(), - DerivedSourceUtils.KNNVectorFieldTypeContext.builder().fieldPath("update_vector").isUpdate(true).build(), + DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) + .fieldPath("test_vector") + .build(), + DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) + .fieldPath("update_vector") + .isUpdate(true) + .build(), DerivedSourceUtils.TextFieldType.builder().fieldPath("test-text").build(), DerivedSourceUtils.IntFieldType.builder().fieldPath("test-int").build() @@ -270,13 +319,24 @@ protected List getObjectIndexContexts(Str * "nested_1" : { * "type" : "nested", * "properties" : { + * "object_1" : { + * "properties" : { + * "test-int" : { + * "type" : "integer" + * }, + * "test_vector" : { + * "type" : "knn_vector", + * "dimension" : 64 + * } + * } + * }, * "test_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 9 * }, * "update_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 4 * } * } * }, @@ -291,11 +351,11 @@ protected List getObjectIndexContexts(Str * }, * "test_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 27 * }, * "update_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 14 * } * } * }, @@ -304,11 +364,28 @@ protected List getObjectIndexContexts(Str * }, * "test_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 57 * }, * "update_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 10 + * } + * } + * }, + * "object_1" : 
{ + * "properties" : { + * "nested_1" : { + * "type" : "nested", + * "properties" : { + * "test_vector" : { + * "type" : "knn_vector", + * "dimension" : 30 + * } + * } + * }, + * "test_vector" : { + * "type" : "knn_vector", + * "dimension" : 51 * } * } * }, @@ -320,11 +397,11 @@ protected List getObjectIndexContexts(Str * }, * "test_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 63 * }, * "update_vector" : { * "type" : "knn_vector", - * "dimension" : 16 + * "dimension" : 4 * } * } * } @@ -332,21 +409,63 @@ protected List getObjectIndexContexts(Str */ protected List getNestedIndexContexts(String testSuitePrefix, boolean addRandom) { List indexConfigContexts = new ArrayList<>(); + long consistentRandomSeed = random().nextLong(); for (Pair index : INDEX_PREFIX_TO_ENABLED) { + Supplier dimensionSupplier = randomIntegerSupplier(consistentRandomSeed, MIN_DIMENSION, MAX_DIMENSION); + Supplier randomDocCountSupplier = randomIntegerSupplier(consistentRandomSeed, MIN_DOCS, MAX_DOCS); DerivedSourceUtils.IndexConfigContext indexConfigContext = DerivedSourceUtils.IndexConfigContext.builder() .indexName(getIndexName(testSuitePrefix, index.getFirst(), addRandom)) + .docCount(randomDocCountSupplier.get()) .derivedEnabled(index.getSecond()) - .random(new Random(1)) + .random(new Random(consistentRandomSeed)) .fields( List.of( + DerivedSourceUtils.ObjectFieldContext.builder() + .fieldPath("object_1") + .children( + List.of( + DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) + .fieldPath("object_1.test_vector") + .build(), + DerivedSourceUtils.NestedFieldContext.builder() + .fieldPath("object_1.nested_1") + .children( + List.of( + DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) + .fieldPath("object_1.nested_1.test_vector") + .build() + ) + ) + .build() + ) + ) + .build(), DerivedSourceUtils.NestedFieldContext.builder() .fieldPath("nested_1") .children( List.of( - 
DerivedSourceUtils.KNNVectorFieldTypeContext.builder().fieldPath("nested_1.test_vector").build(), DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) + .fieldPath("nested_1.test_vector") + .build(), + DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) .fieldPath("nested_1.update_vector") .isUpdate(true) + .build(), + DerivedSourceUtils.ObjectFieldContext.builder() + .fieldPath("nested_1.object_1") + .children( + List.of( + DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) + .fieldPath("nested_1.object_1.test_vector") + .build(), + DerivedSourceUtils.IntFieldType.builder().fieldPath("nested_1.object_1.test-int").build() + ) + ) .build() ) ) @@ -356,19 +475,25 @@ protected List getNestedIndexContexts(Str .children( List.of( DerivedSourceUtils.TextFieldType.builder().fieldPath("nested_2.test-text").build(), - DerivedSourceUtils.KNNVectorFieldTypeContext.builder().fieldPath("nested_2.test_vector").build(), + DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) + .fieldPath("nested_2.test_vector") + .build(), DerivedSourceUtils.KNNVectorFieldTypeContext.builder() .fieldPath("nested_2.update_vector") .isUpdate(true) + .dimension(dimensionSupplier.get()) .build(), DerivedSourceUtils.NestedFieldContext.builder() .fieldPath("nested_2.nested_3") .children( List.of( DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) .fieldPath("nested_2.nested_3.test_vector") .build(), DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) .fieldPath("nested_2.nested_3.update_vector") .isUpdate(true) .build(), @@ -379,8 +504,15 @@ protected List getNestedIndexContexts(Str ) ) .build(), - DerivedSourceUtils.KNNVectorFieldTypeContext.builder().fieldPath("test_vector").build(), - 
DerivedSourceUtils.KNNVectorFieldTypeContext.builder().fieldPath("update_vector").isUpdate(true).build(), + DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) + .fieldPath("test_vector") + .build(), + DerivedSourceUtils.KNNVectorFieldTypeContext.builder() + .dimension(dimensionSupplier.get()) + .fieldPath("update_vector") + .isUpdate(true) + .build(), DerivedSourceUtils.TextFieldType.builder().fieldPath("test-text").build(), DerivedSourceUtils.IntFieldType.builder().fieldPath("test-int").build() ) @@ -409,8 +541,11 @@ protected void prepareOriginalIndices(List indexConf ); } + @SneakyThrows + protected void testSnapshotRestore( + String repository, + String snapshot, + List indexConfigContexts + ) { + DerivedSourceUtils.IndexConfigContext derivedSourceEnabledContext = indexConfigContexts.get(0); + DerivedSourceUtils.IndexConfigContext derivedSourceDisabledContext = indexConfigContexts.get(1); + DerivedSourceUtils.IndexConfigContext reindexFromEnabledToEnabledContext = indexConfigContexts.get(2); + DerivedSourceUtils.IndexConfigContext reindexFromEnabledToDisabledContext = indexConfigContexts.get(3); + DerivedSourceUtils.IndexConfigContext reindexFromDisabledToEnabledContext = indexConfigContexts.get(4); + DerivedSourceUtils.IndexConfigContext reindexFromDisabledToDisabledContext = indexConfigContexts.get(5); + + String originalIndexNameDerivedSourceEnabled = derivedSourceEnabledContext.indexName; + String originalIndexNameDerivedSourceDisabled = derivedSourceDisabledContext.indexName; + String reindexFromEnabledToEnabledIndexName = reindexFromEnabledToEnabledContext.indexName; + String reindexFromEnabledToDisabledIndexName = reindexFromEnabledToDisabledContext.indexName; + String reindexFromDisabledToEnabledIndexName = reindexFromDisabledToEnabledContext.indexName; + String reindexFromDisabledToDisabledIndexName = reindexFromDisabledToDisabledContext.indexName; + + createSnapshot(repository, snapshot, true); + + 
deleteIndex(originalIndexNameDerivedSourceEnabled); + deleteIndex(originalIndexNameDerivedSourceDisabled); + deleteIndex(reindexFromEnabledToEnabledIndexName); + deleteIndex(reindexFromEnabledToDisabledIndexName); + deleteIndex(reindexFromDisabledToEnabledIndexName); + deleteIndex(reindexFromDisabledToDisabledIndexName); + + String restoreSuffix = "-restored"; + restoreSnapshot( + restoreSuffix, + List.of( + originalIndexNameDerivedSourceEnabled, + originalIndexNameDerivedSourceDisabled, + reindexFromEnabledToEnabledIndexName, + reindexFromEnabledToDisabledIndexName, + reindexFromDisabledToEnabledIndexName, + reindexFromDisabledToDisabledIndexName + ), + repository, + snapshot, + true + ); + + originalIndexNameDerivedSourceEnabled += restoreSuffix; + originalIndexNameDerivedSourceDisabled += restoreSuffix; + reindexFromEnabledToEnabledIndexName += restoreSuffix; + reindexFromEnabledToDisabledIndexName += restoreSuffix; + reindexFromDisabledToEnabledIndexName += restoreSuffix; + reindexFromDisabledToDisabledIndexName += restoreSuffix; + + assertIndexBigger(originalIndexNameDerivedSourceDisabled, originalIndexNameDerivedSourceEnabled); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, reindexFromEnabledToEnabledIndexName); + assertIndexBigger(originalIndexNameDerivedSourceDisabled, reindexFromDisabledToEnabledIndexName); + assertIndexBigger(reindexFromEnabledToDisabledIndexName, originalIndexNameDerivedSourceEnabled); + assertIndexBigger(reindexFromDisabledToDisabledIndexName, originalIndexNameDerivedSourceEnabled); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromEnabledToEnabledIndexName + ); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromDisabledToEnabledIndexName + ); + assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromEnabledToDisabledIndexName + ); + 
assertDocsMatch( + derivedSourceDisabledContext.docCount, + originalIndexNameDerivedSourceDisabled, + reindexFromDisabledToDisabledIndexName + ); + } + @SneakyThrows protected void assertIndexBigger(String expectedBiggerIndex, String expectedSmallerIndex) { if (isExhaustive()) { @@ -718,4 +932,26 @@ protected String getIndexName(String testPrefix, String indexPrefix, boolean add } return indexName.toLowerCase(Locale.ROOT); } + + protected Supplier randomIntegerSupplier(long randomSeed, int min, int max) { + return randomIntegerSupplier(randomSeed, min, max, 1); + } + + protected Supplier randomIntegerSupplier(long randomSeed, int min, int max, int multipleOf) { + Random random = new Random(randomSeed); + return () -> { + // Calculate how many multiples fit within the range + int adjustedMin = (min + multipleOf - 1) / multipleOf * multipleOf; + int adjustedMax = max / multipleOf * multipleOf; + + // Generate a random number within the adjusted range + int randomMultiple = random.nextInt(adjustedMin / multipleOf, (adjustedMax / multipleOf) + 1) * multipleOf; + + return randomMultiple; + }; + } + + protected void validateDerivedSetting(String indexName, boolean expectedValue) throws IOException { + assertEquals(expectedValue, Boolean.parseBoolean(getIndexSettingByName(indexName, "index.knn.derived_source.enabled", true))); + } } diff --git a/src/testFixtures/java/org/opensearch/knn/DerivedSourceUtils.java b/src/testFixtures/java/org/opensearch/knn/DerivedSourceUtils.java index d33fdb1039..173dbd5628 100644 --- a/src/testFixtures/java/org/opensearch/knn/DerivedSourceUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/DerivedSourceUtils.java @@ -33,6 +33,8 @@ public class DerivedSourceUtils { public static final int TEST_DIMENSION = 16; protected static final int DOCS = 500; + public static final float DEFAULT_NULL_PROB = 0.03f; + protected static final Settings DERIVED_ENABLED_SETTINGS = Settings.builder() .put( "number_of_shards", @@ -95,9 +97,7 @@ public static 
class IndexConfigContext { public Settings settings = null; public void init() { - if (random == null) { - random = new Random(1); - } + assert random != null; for (FieldContext context : fields) { context.init(random); } @@ -173,15 +173,17 @@ public static abstract class FieldContext { @Builder.Default public float skipProb = 0.1f; @Builder.Default + public float nullProb = DEFAULT_NULL_PROB; + @Builder.Default public boolean isUpdate = false; abstract XContentBuilder buildMapping(XContentBuilder builder) throws IOException; XContentBuilder buildDoc(XContentBuilder builder) throws IOException { - return buildDoc(builder, skipProb); + return buildDoc(builder, skipProb, nullProb); } - abstract XContentBuilder buildDoc(XContentBuilder builder, float skipProb) throws IOException; + abstract XContentBuilder buildDoc(XContentBuilder builder, float skipProb, float nullProb) throws IOException; abstract XContentBuilder partialUpdate(XContentBuilder builder) throws IOException; @@ -200,6 +202,10 @@ protected boolean shouldSkip(float skipProb) { return isUpdate == false && random.nextFloat() < skipProb; } + protected boolean shouldNull(float nullProb) { + return random.nextFloat() < nullProb; + } + String updateSourceString() throws IOException { return ""; } @@ -232,10 +238,10 @@ XContentBuilder buildMapping(XContentBuilder builder) throws IOException { } @Override - XContentBuilder buildDoc(XContentBuilder builder, float skipProb) throws IOException { + XContentBuilder buildDoc(XContentBuilder builder, float skipProb, float nullProb) throws IOException { builder.startObject(getFieldName()); for (FieldContext child : children) { - child.buildDoc(builder, skipProb); + child.buildDoc(builder, skipProb, nullProb); } builder.endObject(); return builder; @@ -281,7 +287,7 @@ XContentBuilder buildMapping(XContentBuilder builder) throws IOException { } @Override - XContentBuilder buildDoc(XContentBuilder builder, float skipProb) throws IOException { + XContentBuilder 
buildDoc(XContentBuilder builder, float skipProb, float nullProb) throws IOException { if (shouldSkip(skipProb)) { return builder; } @@ -290,7 +296,7 @@ XContentBuilder buildDoc(XContentBuilder builder, float skipProb) throws IOExcep if (docCount == 1) { builder.startObject(getFieldName()); for (FieldContext child : children) { - child.buildDoc(builder, skipProb); + child.buildDoc(builder, skipProb, nullProb); } builder.endObject(); return builder; @@ -300,7 +306,7 @@ XContentBuilder buildDoc(XContentBuilder builder, float skipProb) throws IOExcep for (int i = 0; i < docCount; i++) { builder.startObject(); for (FieldContext child : children) { - child.buildDoc(builder, skipProb); + child.buildDoc(builder, skipProb, nullProb); } builder.endObject(); } @@ -320,11 +326,12 @@ public abstract static class LeafFieldContext extends FieldContext { public Supplier valueSupplier; @Override - XContentBuilder buildDoc(XContentBuilder builder, float skipProb) throws IOException { + XContentBuilder buildDoc(XContentBuilder builder, float skipProb, float nullProb) throws IOException { if (shouldSkip(skipProb)) { return builder; } - return builder.field(getFieldName(), valueSupplier.get()); + Object value = shouldNull(nullProb) ? 
null : valueSupplier.get(); + return builder.field(getFieldName(), value); } public String updateSourceString() { diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 6a0c11fc4e..d312755f9e 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -16,6 +16,7 @@ import org.apache.commons.lang.StringUtils; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.net.URIBuilder; +import org.hamcrest.Matchers; import org.junit.AfterClass; import org.junit.Before; import org.opensearch.Version; @@ -24,6 +25,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.DeprecationHandler; @@ -2360,4 +2362,30 @@ protected void setupSnapshotRestore(String index, String snapshot, String reposi createSnapshot(repository, snapshot, true); } + protected static void restoreSnapshot( + String restoreIndexSuffix, + List indices, + String repository, + String snapshot, + boolean waitForCompletion + ) throws IOException { + // valid restore + XContentBuilder restoreCommand = JsonXContent.contentBuilder().startObject(); + restoreCommand.field("indices", String.join(",", indices)); + restoreCommand.field("rename_pattern", "(.+)"); + restoreCommand.field("rename_replacement", "$1" + restoreIndexSuffix); + restoreCommand.endObject(); + + Request restoreRequest = new Request("POST", "/_snapshot/" + repository + "/" + snapshot + "/_restore"); + restoreRequest.addParameter("wait_for_completion", "true"); + restoreRequest.setJsonEntity(restoreCommand.toString()); + + final Response restoreResponse = 
client().performRequest(restoreRequest); + assertThat( + "Failed to restore snapshot [" + snapshot + "] from repository [" + repository + "]: " + String.valueOf(restoreResponse), + restoreResponse.getStatusLine().getStatusCode(), + Matchers.equalTo(RestStatus.OK.getStatus()) + ); + } + }