Skip to content

Commit 4b21b63

Browse files
committed
Support native Maximal Marginal Relevance
Signed-off-by: Bo Zhang <[email protected]>
1 parent ba97971 commit 4b21b63

32 files changed

+3535
-55
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ All notable changes to this project are documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). See the [CONTRIBUTING guide](./CONTRIBUTING.md#Changelog) for instructions on how to add changelog entries.
66

77
## [Unreleased 3.3](https://github.com/opensearch-project/k-NN/compare/main...HEAD)
8+
### Features
9+
* Support native Maximal Marginal Relevance [#2868](https://github.com/opensearch-project/k-NN/pull/2868)
810
### Maintenance
911
* Replace commons-lang with org.apache.commons:commons-lang3 [#2863](https://github.com/opensearch-project/k-NN/pull/2863)
1012
* Bump OpenSearch-Protobufs to 0.13.0 [#2833](https://github.com/opensearch-project/k-NN/pull/2833)

src/main/java/org/opensearch/knn/common/KNNConstants.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,4 +190,12 @@ public class KNNConstants {
190190
public static final Integer INDEX_THREAD_QUANTITY_DEFAULT_LARGE = 4;
191191
public static final Integer INDEX_THREAD_QUANTITY_DEFAULT_SMALL = 1;
192192

193+
// mmr
194+
public static final String MMR = "mmr";
195+
public static final String DIVERSITY = "diversity";
196+
public static final String CANDIDATES = "candidates";
197+
public static final String VECTOR_FIELD_PATH = "vector_field_path";
198+
public static final String VECTOR_FIELD_DATA_TYPE = "vector_field_data_type";
199+
public static final String VECTOR_FIELD_SPACE_TYPE = "vector_field_space_type";
200+
public static final String MMR_RERANK_CONTEXT = "mmr.rerank_context";
193201
}

src/main/java/org/opensearch/knn/index/SpaceType.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
188188
public static SpaceType DEFAULT = L2;
189189
public static SpaceType DEFAULT_BINARY = HAMMING;
190190

191-
private static final String[] VALID_VALUES = Arrays.stream(SpaceType.values())
191+
public static final String[] VALID_VALUES = Arrays.stream(SpaceType.values())
192192
.filter(space -> space != SpaceType.UNDEFINED)
193193
.map(SpaceType::getValue)
194194
.collect(Collectors.toList())

src/main/java/org/opensearch/knn/index/engine/SpaceTypeResolver.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package org.opensearch.knn.index.engine;
77

88
import org.apache.logging.log4j.util.Strings;
9-
import org.opensearch.common.settings.Settings;
109
import org.opensearch.index.mapper.MapperParsingException;
1110
import org.opensearch.knn.index.SpaceType;
1211
import org.opensearch.knn.index.VectorDataType;
@@ -38,15 +37,11 @@ private SpaceTypeResolver() {}
3837
public SpaceType resolveSpaceType(
3938
final KNNMethodContext knnMethodContext,
4039
final String topLevelSpaceTypeString,
41-
final Settings indexSettings,
4240
final VectorDataType vectorDataType
4341
) {
4442
SpaceType methodSpaceType = getSpaceTypeFromMethodContext(knnMethodContext);
4543
SpaceType topLevelSpaceType = getSpaceTypeFromString(topLevelSpaceTypeString);
4644

47-
// If we failed to find space type from both method context and top level
48-
// 1. We try to get it from index setting, which is a relic of legacy.
49-
// 2. Otherwise, we return a default one.
5045
if (isSpaceTypeConfigured(methodSpaceType) == false && isSpaceTypeConfigured(topLevelSpaceType) == false) {
5146
return getDefaultSpaceType(vectorDataType);
5247
}

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,6 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
424424
final SpaceType resolvedSpaceType = SpaceTypeResolver.INSTANCE.resolveSpaceType(
425425
builder.originalParameters.getKnnMethodContext(),
426426
builder.topLevelSpaceType.get(),
427-
parserContext.getSettings(),
428427
builder.vectorDataType.get()
429428
);
430429

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import lombok.AccessLevel;
99
import lombok.AllArgsConstructor;
1010
import lombok.Getter;
11+
import lombok.Setter;
1112
import lombok.extern.log4j.Log4j2;
1213
import org.apache.lucene.search.MatchNoDocsQuery;
1314
import org.apache.lucene.search.Query;
@@ -96,6 +97,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> imple
9697
private final String fieldName;
9798
private final float[] vector;
9899
@Getter
100+
@Setter
99101
private int k;
100102
@Getter
101103
private Float maxDistance;

src/main/java/org/opensearch/knn/index/util/KNNClusterUtil.java

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,23 @@
66
package org.opensearch.knn.index.util;
77

88
import lombok.AccessLevel;
9+
import lombok.Getter;
910
import lombok.NoArgsConstructor;
11+
import lombok.NonNull;
1012
import lombok.extern.log4j.Log4j2;
1113
import org.opensearch.Version;
14+
import org.opensearch.action.IndicesRequest;
15+
import org.opensearch.cluster.metadata.IndexMetadata;
16+
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
1217
import org.opensearch.cluster.service.ClusterService;
18+
import org.opensearch.core.index.Index;
19+
20+
import java.util.Arrays;
21+
import java.util.Collections;
22+
import java.util.List;
23+
import java.util.stream.Collectors;
24+
25+
import static org.opensearch.search.pipeline.SearchPipelineService.ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING;
1326

1427
/**
1528
* Class abstracts information related to underlying OpenSearch cluster
@@ -20,6 +33,9 @@ public class KNNClusterUtil {
2033

2134
private ClusterService clusterService;
2235
private static KNNClusterUtil instance;
36+
private IndexNameExpressionResolver indexNameExpressionResolver;
37+
@Getter
38+
private List<String> enabledSystemGeneratedFactories = Collections.emptyList();
2339

2440
/**
2541
* Return instance of the cluster context, must be initialized first for proper usage
@@ -35,9 +51,17 @@ public static synchronized KNNClusterUtil instance() {
3551
/**
3652
* Initializes instance of cluster context by injecting dependencies
3753
* @param clusterService cluster service that supplies the cluster state and the cluster settings
54+
* @param indexNameExpressionResolver resolver used to turn a request's index expressions into concrete indices
3855
*/
39-
public void initialize(final ClusterService clusterService) {
56+
public void initialize(final ClusterService clusterService, final IndexNameExpressionResolver indexNameExpressionResolver) {
4057
this.clusterService = clusterService;
58+
this.indexNameExpressionResolver = indexNameExpressionResolver;
59+
this.enabledSystemGeneratedFactories = clusterService.getClusterSettings().get(ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING);
60+
clusterService.getClusterSettings()
61+
.addSettingsUpdateConsumer(
62+
ENABLED_SYSTEM_GENERATED_FACTORIES_SETTING,
63+
factories -> enabledSystemGeneratedFactories = factories
64+
);
4165
}
4266

4367
/**
@@ -55,4 +79,16 @@ public Version getClusterMinVersion() {
5579
return Version.CURRENT;
5680
}
5781
}
82+
83+
/**
84+
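* Resolves the given request to its concrete indices and looks up the metadata of each one in the cluster state.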
85+
* @param searchRequest the request whose target indices are resolved; must not be null
86+
* @return the IndexMetadata of each concrete index that the request resolves to
87+
*/
88+
public List<IndexMetadata> getIndexMetadataList(@NonNull final IndicesRequest searchRequest) {
89+
final Index[] concreteIndices = this.indexNameExpressionResolver.concreteIndices(clusterService.state(), searchRequest);
90+
return Arrays.stream(concreteIndices)
91+
.map(concreteIndex -> clusterService.state().metadata().index(concreteIndex))
92+
.collect(Collectors.toList());
93+
}
5894
}

src/main/java/org/opensearch/knn/plugin/KNNPlugin.java

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.opensearch.index.codec.CodecServiceFactory;
3131
import org.opensearch.index.engine.EngineFactory;
3232
import org.opensearch.index.mapper.Mapper;
33+
import org.opensearch.index.query.QueryBuilder;
3334
import org.opensearch.index.shard.IndexSettingProvider;
3435
import org.opensearch.indices.SystemIndexDescriptor;
3536
import org.opensearch.knn.index.KNNCircuitBreaker;
@@ -88,6 +89,12 @@
8889
import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction;
8990
import org.opensearch.knn.profile.query.KNNMetrics;
9091
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache;
92+
import org.opensearch.knn.search.extension.MMRSearchExtBuilder;
93+
94+
import org.opensearch.knn.search.processor.mmr.MMRKnnQueryTransformer;
95+
import org.opensearch.knn.search.processor.mmr.MMROverSampleProcessor;
96+
import org.opensearch.knn.search.processor.mmr.MMRQueryTransformer;
97+
import org.opensearch.knn.search.processor.mmr.MMRRerankProcessor;
9198
import org.opensearch.knn.training.TrainingJobClusterStateListener;
9299
import org.opensearch.knn.training.TrainingJobRunner;
93100
import org.opensearch.knn.training.VectorReader;
@@ -100,6 +107,7 @@
100107
import org.opensearch.plugins.Plugin;
101108
import org.opensearch.plugins.ReloadablePlugin;
102109
import org.opensearch.plugins.ScriptPlugin;
110+
import org.opensearch.plugins.SearchPipelinePlugin;
103111
import org.opensearch.plugins.SearchPlugin;
104112
import org.opensearch.plugins.SystemIndexPlugin;
105113
import org.opensearch.remoteindexbuild.client.RemoteIndexHTTPClient;
@@ -109,7 +117,11 @@
109117
import org.opensearch.script.ScriptContext;
110118
import org.opensearch.script.ScriptEngine;
111119
import org.opensearch.script.ScriptService;
120+
import org.opensearch.search.SearchExtBuilder;
112121
import org.opensearch.search.deciders.ConcurrentSearchRequestDecider;
122+
import org.opensearch.search.pipeline.SearchRequestProcessor;
123+
import org.opensearch.search.pipeline.SearchResponseProcessor;
124+
import org.opensearch.search.pipeline.SystemGeneratedProcessor;
113125
import org.opensearch.threadpool.ExecutorBuilder;
114126
import org.opensearch.threadpool.FixedExecutorBuilder;
115127
import org.opensearch.threadpool.ThreadPool;
@@ -120,7 +132,9 @@
120132
import java.util.Arrays;
121133
import java.util.Collection;
122134
import java.util.Collections;
135+
import java.util.HashMap;
123136
import java.util.List;
137+
import java.util.Locale;
124138
import java.util.Map;
125139
import java.util.Optional;
126140
import java.util.concurrent.ForkJoinPool;
@@ -173,14 +187,16 @@ public class KNNPlugin extends Plugin
173187
ScriptPlugin,
174188
ExtensiblePlugin,
175189
SystemIndexPlugin,
176-
ReloadablePlugin {
190+
ReloadablePlugin,
191+
SearchPipelinePlugin {
177192

178193
public static final String LEGACY_KNN_BASE_URI = "/_opendistro/_knn";
179194
public static final String KNN_BASE_URI = "/_plugins/_knn";
180195

181196
private KNNStats knnStats;
182197
private ClusterService clusterService;
183198
private Supplier<RepositoriesService> repositoriesServiceSupplier;
199+
private final Map<String, MMRQueryTransformer<? extends QueryBuilder>> mmrQueryTransformers = new HashMap<>();
184200

185201
static {
186202
ForkJoinPool.commonPool().execute(() -> {
@@ -237,7 +253,7 @@ public Collection<Object> createComponents(
237253
NativeMemoryLoadStrategy.TrainingLoadStrategy.initialize(vectorReader);
238254

239255
KNNSettings.state().initialize(client, clusterService);
240-
KNNClusterUtil.instance().initialize(clusterService);
256+
KNNClusterUtil.instance().initialize(clusterService, indexNameExpressionResolver);
241257
ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings());
242258
ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
243259
TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance());
@@ -356,6 +372,27 @@ public void onIndexModule(IndexModule indexModule) {
356372
}
357373
}
358374

375+
@Override
376+
public void loadExtensions(ExtensionLoader loader) {
377+
// The k-NN plugin cannot extend itself, so the transformer it implements is registered manually here.
378+
mmrQueryTransformers.put(KNNQueryBuilder.NAME, new MMRKnnQueryTransformer());
379+
for (MMRQueryTransformer<?> transformer : loader.loadExtensions(MMRQueryTransformer.class)) {
380+
String queryName = transformer.getQueryName();
381+
if (mmrQueryTransformers.containsKey(queryName)) {
382+
throw new IllegalStateException(
383+
String.format(
384+
Locale.ROOT,
385+
"Already load the MMR query transformer %s for %s query. Cannot load another transformer %s for it.",
386+
mmrQueryTransformers.get(queryName).getClass().getName(),
387+
queryName,
388+
transformer.getClass().getName()
389+
)
390+
);
391+
}
392+
mmrQueryTransformers.put(queryName, transformer);
393+
}
394+
}
395+
359396
/**
360397
* Sample knn custom script
361398
*
@@ -447,4 +484,26 @@ public void reload(Settings settings) {
447484
SecureString password = KNNSettings.KNN_REMOTE_BUILD_SERVER_PASSWORD_SETTING.get(settings);
448485
RemoteIndexHTTPClient.reloadAuthHeader(username, password);
449486
}
487+
488+
@Override
489+
public List<SearchExtSpec<?>> getSearchExts() {
490+
return List.of(new SearchExtSpec<SearchExtBuilder>(MMRSearchExtBuilder.NAME, MMRSearchExtBuilder::new, MMRSearchExtBuilder::parse));
491+
}
492+
493+
@Override
494+
public Map<String, SystemGeneratedProcessor.SystemGeneratedFactory<SearchRequestProcessor>> getSystemGeneratedRequestProcessors(
495+
Parameters parameters
496+
) {
497+
return Map.of(
498+
MMROverSampleProcessor.MMROverSampleProcessorFactory.TYPE,
499+
new MMROverSampleProcessor.MMROverSampleProcessorFactory(parameters.client, mmrQueryTransformers)
500+
);
501+
}
502+
503+
@Override
504+
public Map<String, SystemGeneratedProcessor.SystemGeneratedFactory<SearchResponseProcessor>> getSystemGeneratedResponseProcessors(
505+
Parameters parameters
506+
) {
507+
return Map.of(MMRRerankProcessor.MMRRerankProcessorFactory.TYPE, new MMRRerankProcessor.MMRRerankProcessorFactory());
508+
}
450509
}

src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ && ensureSpaceTypeNotSet(topLevelSpaceType)) {
177177
SpaceType resolvedSpaceType = SpaceTypeResolver.INSTANCE.resolveSpaceType(
178178
knnMethodContext,
179179
topLevelSpaceType.getValue(),
180-
null,
181180
vectorDataType
182181
);
183182
setSpaceType(knnMethodContext, resolvedSpaceType);

0 commit comments

Comments
 (0)