Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ All notable changes to this project are documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). See the [CONTRIBUTING guide](./CONTRIBUTING.md#Changelog) for instructions on how to add changelog entries.

## [Unreleased 3.3](https://github.com/opensearch-project/k-NN/compare/main...HEAD)
### Features
* Support native Maximal Marginal Relevance [#2868](https://github.com/opensearch-project/k-NN/pull/2868)
### Maintenance
* Replace commons-lang with org.apache.commons:commons-lang3 [#2863](https://github.com/opensearch-project/k-NN/pull/2863)
* Bump OpenSearch-Protobufs to 0.13.0 [#2833](https://github.com/opensearch-project/k-NN/pull/2833)
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,12 @@ public class KNNConstants {
public static final Integer INDEX_THREAD_QUANTITY_DEFAULT_LARGE = 4;
public static final Integer INDEX_THREAD_QUANTITY_DEFAULT_SMALL = 1;

// String constants for the Maximal Marginal Relevance (MMR) feature
public static final String MMR = "mmr";
public static final String DIVERSITY = "diversity";
public static final String CANDIDATES = "candidates";
public static final String VECTOR_FIELD_PATH = "vector_field_path";
public static final String VECTOR_FIELD_DATA_TYPE = "vector_field_data_type";
public static final String VECTOR_FIELD_SPACE_TYPE = "vector_field_space_type";
// NOTE(review): presumably the key under which MMR rerank state is stored between processors; confirm against usages
public static final String MMR_RERANK_CONTEXT = "mmr.rerank_context";
}
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
public static SpaceType DEFAULT = L2;
public static SpaceType DEFAULT_BINARY = HAMMING;

private static final String[] VALID_VALUES = Arrays.stream(SpaceType.values())
public static final String[] VALID_VALUES = Arrays.stream(SpaceType.values())
.filter(space -> space != SpaceType.UNDEFINED)
.map(SpaceType::getValue)
.collect(Collectors.toList())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.opensearch.knn.index.engine;

import org.apache.logging.log4j.util.Strings;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.mapper.MapperParsingException;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
Expand Down Expand Up @@ -38,15 +37,11 @@ private SpaceTypeResolver() {}
public SpaceType resolveSpaceType(
final KNNMethodContext knnMethodContext,
final String topLevelSpaceTypeString,
final Settings indexSettings,
final VectorDataType vectorDataType
) {
SpaceType methodSpaceType = getSpaceTypeFromMethodContext(knnMethodContext);
SpaceType topLevelSpaceType = getSpaceTypeFromString(topLevelSpaceTypeString);

// If we failed to find space type from both method context and top level
// 1. We try to get it from index setting, which is a relic of legacy.
// 2. Otherwise, we return a default one.
if (isSpaceTypeConfigured(methodSpaceType) == false && isSpaceTypeConfigured(topLevelSpaceType) == false) {
return getDefaultSpaceType(vectorDataType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,6 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
final SpaceType resolvedSpaceType = SpaceTypeResolver.INSTANCE.resolveSpaceType(
builder.originalParameters.getKnnMethodContext(),
builder.topLevelSpaceType.get(),
parserContext.getSettings(),
builder.vectorDataType.get()
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
Expand Down Expand Up @@ -96,6 +97,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> imple
private final String fieldName;
private final float[] vector;
@Getter
@Setter
private Integer k;
@Getter
private Float maxDistance;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,23 @@
package org.opensearch.knn.index.util;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.log4j.Log4j2;
import org.opensearch.Version;
import org.opensearch.action.IndicesRequest;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.index.Index;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import static org.opensearch.search.pipeline.SearchPipelineService.ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING;

/**
* Class abstracts information related to underlying OpenSearch cluster
Expand All @@ -20,6 +33,9 @@ public class KNNClusterUtil {

private ClusterService clusterService;
private static KNNClusterUtil instance;
private IndexNameExpressionResolver indexNameExpressionResolver;
@Getter
private List<String> enabledSystemGeneratedFactories = Collections.emptyList();

/**
* Return instance of the cluster context, must be initialized first for proper usage
Expand All @@ -35,9 +51,17 @@ public static synchronized KNNClusterUtil instance() {
/**
* Initializes instance of cluster context by injecting dependencies.
* Stores both dependencies on this instance, reads the current value of
* ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING from the cluster settings, and registers a
* consumer so that later updates of that setting replace the cached value.
*
* @param clusterService cluster service kept by this instance; also the source of the cluster settings read here
* @param indexNameExpressionResolver resolver kept by this instance for later index-name resolution
*/
public void initialize(final ClusterService clusterService) {
public void initialize(final ClusterService clusterService, final IndexNameExpressionResolver indexNameExpressionResolver) {
this.clusterService = clusterService;
this.indexNameExpressionResolver = indexNameExpressionResolver;
// Seed the cached list with the value currently held by the cluster settings.
this.enabledSystemGeneratedFactories = clusterService.getClusterSettings().get(ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING);
// Keep the cached list in sync with subsequent updates of the setting.
// NOTE(review): the field is reassigned from this callback; confirm whether readers on other threads need it to be volatile.
clusterService.getClusterSettings()
.addSettingsUpdateConsumer(
ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING,
factories -> enabledSystemGeneratedFactories = factories
);
}

/**
Expand All @@ -55,4 +79,16 @@ public Version getClusterMinVersion() {
return Version.CURRENT;
}
}

/**
 * Resolves the index expressions of a request to concrete indices and returns their metadata.
 *
 * <p>The cluster state is read exactly once, so index resolution and the metadata lookup operate
 * on the same snapshot. Reading it separately for each step could pair indices resolved from one
 * state with metadata from a newer state in which an index no longer exists.
 *
 * @param searchRequest request whose indices should be resolved; must not be null
 * @return IndexMetadata of the indices of the search request, one entry per resolved concrete index
 */
public List<IndexMetadata> getIndexMetadataList(@NonNull final IndicesRequest searchRequest) {
    // Single snapshot shared by the resolver and the metadata lookup below.
    final var clusterState = clusterService.state();
    final Index[] concreteIndices = this.indexNameExpressionResolver.concreteIndices(clusterState, searchRequest);
    return Arrays.stream(concreteIndices)
        .map(concreteIndex -> clusterState.metadata().index(concreteIndex))
        .collect(Collectors.toList());
}
}
63 changes: 61 additions & 2 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.opensearch.index.codec.CodecServiceFactory;
import org.opensearch.index.engine.EngineFactory;
import org.opensearch.index.mapper.Mapper;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.shard.IndexSettingProvider;
import org.opensearch.indices.SystemIndexDescriptor;
import org.opensearch.knn.index.KNNCircuitBreaker;
Expand Down Expand Up @@ -88,6 +89,12 @@
import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction;
import org.opensearch.knn.profile.query.KNNMetrics;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache;
import org.opensearch.knn.search.extension.MMRSearchExtBuilder;

import org.opensearch.knn.search.processor.mmr.MMRKnnQueryTransformer;
import org.opensearch.knn.search.processor.mmr.MMROverSampleProcessor;
import org.opensearch.knn.search.processor.mmr.MMRQueryTransformer;
import org.opensearch.knn.search.processor.mmr.MMRRerankProcessor;
import org.opensearch.knn.training.TrainingJobClusterStateListener;
import org.opensearch.knn.training.TrainingJobRunner;
import org.opensearch.knn.training.VectorReader;
Expand All @@ -100,6 +107,7 @@
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.ReloadablePlugin;
import org.opensearch.plugins.ScriptPlugin;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.plugins.SystemIndexPlugin;
import org.opensearch.remoteindexbuild.client.RemoteIndexHTTPClient;
Expand All @@ -109,7 +117,11 @@
import org.opensearch.script.ScriptContext;
import org.opensearch.script.ScriptEngine;
import org.opensearch.script.ScriptService;
import org.opensearch.search.SearchExtBuilder;
import org.opensearch.search.deciders.ConcurrentSearchRequestDecider;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.pipeline.SystemGeneratedProcessor;
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.FixedExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;
Expand All @@ -120,7 +132,9 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ForkJoinPool;
Expand Down Expand Up @@ -173,14 +187,16 @@ public class KNNPlugin extends Plugin
ScriptPlugin,
ExtensiblePlugin,
SystemIndexPlugin,
ReloadablePlugin {
ReloadablePlugin,
SearchPipelinePlugin {

public static final String LEGACY_KNN_BASE_URI = "/_opendistro/_knn";
public static final String KNN_BASE_URI = "/_plugins/_knn";

private KNNStats knnStats;
private ClusterService clusterService;
private Supplier<RepositoriesService> repositoriesServiceSupplier;
private final Map<String, MMRQueryTransformer<? extends QueryBuilder>> mmrQueryTransformers = new HashMap<>();

static {
ForkJoinPool.commonPool().execute(() -> {
Expand Down Expand Up @@ -237,7 +253,7 @@ public Collection<Object> createComponents(
NativeMemoryLoadStrategy.TrainingLoadStrategy.initialize(vectorReader);

KNNSettings.state().initialize(client, clusterService);
KNNClusterUtil.instance().initialize(clusterService);
KNNClusterUtil.instance().initialize(clusterService, indexNameExpressionResolver);
ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings());
ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance());
Expand Down Expand Up @@ -356,6 +372,27 @@ public void onIndexModule(IndexModule indexModule) {
}
}

/**
 * Registers the MMR query transformers known to this plugin: first the built-in transformer for
 * the k-NN query, then every {@link MMRQueryTransformer} supplied through the extension loader.
 *
 * @param loader loader used to discover {@link MMRQueryTransformer} implementations
 * @throws IllegalStateException if a loaded transformer reports a query name that already has a transformer
 */
@Override
public void loadExtensions(ExtensionLoader loader) {
    // knn plugin cannot extend itself, so its own transformer has to be registered by hand
    mmrQueryTransformers.put(KNNQueryBuilder.NAME, new MMRKnnQueryTransformer());

    for (MMRQueryTransformer<?> candidate : loader.loadExtensions(MMRQueryTransformer.class)) {
        final String queryName = candidate.getQueryName();

        // Common case: nothing registered for this query yet.
        if (mmrQueryTransformers.containsKey(queryName) == false) {
            mmrQueryTransformers.put(queryName, candidate);
            continue;
        }

        // A second transformer for the same query is ambiguous, so fail loudly.
        final String registeredClassName = mmrQueryTransformers.get(queryName).getClass().getName();
        final String candidateClassName = candidate.getClass().getName();
        final String message = String.format(
            Locale.ROOT,
            "Already load the MMR query transformer %s for %s query. Cannot load another transformer %s for it.",
            registeredClassName,
            queryName,
            candidateClassName
        );
        throw new IllegalStateException(message);
    }
}

/**
* Sample knn custom script
*
Expand Down Expand Up @@ -447,4 +484,26 @@ public void reload(Settings settings) {
SecureString password = KNNSettings.KNN_REMOTE_BUILD_SERVER_PASSWORD_SETTING.get(settings);
RemoteIndexHTTPClient.reloadAuthHeader(username, password);
}

/**
 * Lists the search "ext" specs contributed by this plugin.
 *
 * @return an immutable single-element list holding the spec registered under
 *         {@code MMRSearchExtBuilder.NAME}, built from {@code MMRSearchExtBuilder::new}
 *         and {@code MMRSearchExtBuilder::parse}
 */
@Override
public List<SearchExtSpec<?>> getSearchExts() {
    final SearchExtSpec<SearchExtBuilder> mmrExtSpec = new SearchExtSpec<SearchExtBuilder>(
        MMRSearchExtBuilder.NAME,
        MMRSearchExtBuilder::new,
        MMRSearchExtBuilder::parse
    );
    return List.of(mmrExtSpec);
}

/**
 * Supplies the system-generated search request processor factories of this plugin.
 *
 * @param parameters pipeline parameters; its {@code client} is handed to the factory
 * @return an immutable map with one entry, keyed by {@code MMROverSampleProcessorFactory.TYPE}
 */
@Override
public Map<String, SystemGeneratedProcessor.SystemGeneratedFactory<SearchRequestProcessor>> getSystemGeneratedRequestProcessors(
    Parameters parameters
) {
    // The factory receives the transformer registry of this plugin instance.
    final SystemGeneratedProcessor.SystemGeneratedFactory<SearchRequestProcessor> overSampleFactory =
        new MMROverSampleProcessor.MMROverSampleProcessorFactory(parameters.client, mmrQueryTransformers);
    return Map.of(MMROverSampleProcessor.MMROverSampleProcessorFactory.TYPE, overSampleFactory);
}

/**
 * Supplies the system-generated search response processor factories of this plugin.
 *
 * @param parameters pipeline parameters; not read by this implementation
 * @return an immutable map with one entry, keyed by {@code MMRRerankProcessorFactory.TYPE}
 */
@Override
public Map<String, SystemGeneratedProcessor.SystemGeneratedFactory<SearchResponseProcessor>> getSystemGeneratedResponseProcessors(
    Parameters parameters
) {
    final SystemGeneratedProcessor.SystemGeneratedFactory<SearchResponseProcessor> rerankFactory =
        new MMRRerankProcessor.MMRRerankProcessorFactory();
    return Map.of(MMRRerankProcessor.MMRRerankProcessorFactory.TYPE, rerankFactory);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ && ensureSpaceTypeNotSet(topLevelSpaceType)) {
SpaceType resolvedSpaceType = SpaceTypeResolver.INSTANCE.resolveSpaceType(
knnMethodContext,
topLevelSpaceType.getValue(),
null,
vectorDataType
);
setSpaceType(knnMethodContext, resolvedSpaceType);
Expand Down
Loading
Loading