|  | 
| 30 | 30 | import org.opensearch.index.codec.CodecServiceFactory; | 
| 31 | 31 | import org.opensearch.index.engine.EngineFactory; | 
| 32 | 32 | import org.opensearch.index.mapper.Mapper; | 
|  | 33 | +import org.opensearch.index.query.QueryBuilder; | 
| 33 | 34 | import org.opensearch.index.shard.IndexSettingProvider; | 
| 34 | 35 | import org.opensearch.indices.SystemIndexDescriptor; | 
| 35 | 36 | import org.opensearch.knn.index.KNNCircuitBreaker; | 
|  | 
| 88 | 89 | import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction; | 
| 89 | 90 | import org.opensearch.knn.profile.query.KNNMetrics; | 
| 90 | 91 | import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache; | 
|  | 92 | +import org.opensearch.knn.search.extension.MMRSearchExtBuilder; | 
|  | 93 | + | 
|  | 94 | +import org.opensearch.knn.search.processor.mmr.MMRKnnQueryTransformer; | 
|  | 95 | +import org.opensearch.knn.search.processor.mmr.MMROverSampleProcessor; | 
|  | 96 | +import org.opensearch.knn.search.processor.mmr.MMRQueryTransformer; | 
|  | 97 | +import org.opensearch.knn.search.processor.mmr.MMRRerankProcessor; | 
| 91 | 98 | import org.opensearch.knn.training.TrainingJobClusterStateListener; | 
| 92 | 99 | import org.opensearch.knn.training.TrainingJobRunner; | 
| 93 | 100 | import org.opensearch.knn.training.VectorReader; | 
|  | 
| 100 | 107 | import org.opensearch.plugins.Plugin; | 
| 101 | 108 | import org.opensearch.plugins.ReloadablePlugin; | 
| 102 | 109 | import org.opensearch.plugins.ScriptPlugin; | 
|  | 110 | +import org.opensearch.plugins.SearchPipelinePlugin; | 
| 103 | 111 | import org.opensearch.plugins.SearchPlugin; | 
| 104 | 112 | import org.opensearch.plugins.SystemIndexPlugin; | 
| 105 | 113 | import org.opensearch.remoteindexbuild.client.RemoteIndexHTTPClient; | 
|  | 
| 109 | 117 | import org.opensearch.script.ScriptContext; | 
| 110 | 118 | import org.opensearch.script.ScriptEngine; | 
| 111 | 119 | import org.opensearch.script.ScriptService; | 
|  | 120 | +import org.opensearch.search.SearchExtBuilder; | 
| 112 | 121 | import org.opensearch.search.deciders.ConcurrentSearchRequestDecider; | 
|  | 122 | +import org.opensearch.search.pipeline.SearchRequestProcessor; | 
|  | 123 | +import org.opensearch.search.pipeline.SearchResponseProcessor; | 
|  | 124 | +import org.opensearch.search.pipeline.SystemGeneratedProcessor; | 
| 113 | 125 | import org.opensearch.threadpool.ExecutorBuilder; | 
| 114 | 126 | import org.opensearch.threadpool.FixedExecutorBuilder; | 
| 115 | 127 | import org.opensearch.threadpool.ThreadPool; | 
|  | 
| 120 | 132 | import java.util.Arrays; | 
| 121 | 133 | import java.util.Collection; | 
| 122 | 134 | import java.util.Collections; | 
|  | 135 | +import java.util.HashMap; | 
| 123 | 136 | import java.util.List; | 
|  | 137 | +import java.util.Locale; | 
| 124 | 138 | import java.util.Map; | 
| 125 | 139 | import java.util.Optional; | 
| 126 | 140 | import java.util.concurrent.ForkJoinPool; | 
| @@ -173,14 +187,16 @@ public class KNNPlugin extends Plugin | 
| 173 | 187 |         ScriptPlugin, | 
| 174 | 188 |         ExtensiblePlugin, | 
| 175 | 189 |         SystemIndexPlugin, | 
| 176 |  | -        ReloadablePlugin { | 
|  | 190 | +        ReloadablePlugin, | 
|  | 191 | +        SearchPipelinePlugin { | 
| 177 | 192 | 
 | 
| 178 | 193 |     public static final String LEGACY_KNN_BASE_URI = "/_opendistro/_knn"; | 
| 179 | 194 |     public static final String KNN_BASE_URI = "/_plugins/_knn"; | 
| 180 | 195 | 
 | 
| 181 | 196 |     private KNNStats knnStats; | 
| 182 | 197 |     private ClusterService clusterService; | 
| 183 | 198 |     private Supplier<RepositoriesService> repositoriesServiceSupplier; | 
|  | 199 | +    private final Map<String, MMRQueryTransformer<? extends QueryBuilder>> mmrQueryTransformers = new HashMap<>(); | 
| 184 | 200 | 
 | 
| 185 | 201 |     static { | 
| 186 | 202 |         ForkJoinPool.commonPool().execute(() -> { | 
| @@ -237,7 +253,7 @@ public Collection<Object> createComponents( | 
| 237 | 253 |         NativeMemoryLoadStrategy.TrainingLoadStrategy.initialize(vectorReader); | 
| 238 | 254 | 
 | 
| 239 | 255 |         KNNSettings.state().initialize(client, clusterService); | 
| 240 |  | -        KNNClusterUtil.instance().initialize(clusterService); | 
|  | 256 | +        KNNClusterUtil.instance().initialize(clusterService, indexNameExpressionResolver); | 
| 241 | 257 |         ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings()); | 
| 242 | 258 |         ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); | 
| 243 | 259 |         TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance()); | 
| @@ -356,6 +372,27 @@ public void onIndexModule(IndexModule indexModule) { | 
| 356 | 372 |         } | 
| 357 | 373 |     } | 
| 358 | 374 | 
 | 
|  | 375 | +    @Override | 
|  | 376 | +    public void loadExtensions(ExtensionLoader loader) { | 
|  | 377 | +        // knn plugin cannot extend itself so we have to manually load the transformer implemented in knn plugin | 
|  | 378 | +        mmrQueryTransformers.put(KNNQueryBuilder.NAME, new MMRKnnQueryTransformer()); | 
|  | 379 | +        for (MMRQueryTransformer<?> transformer : loader.loadExtensions(MMRQueryTransformer.class)) { | 
|  | 380 | +            String queryName = transformer.getQueryName(); | 
|  | 381 | +            if (mmrQueryTransformers.containsKey(queryName)) { | 
|  | 382 | +                throw new IllegalStateException( | 
|  | 383 | +                    String.format( | 
|  | 384 | +                        Locale.ROOT, | 
|  | 385 | +                        "Already load the MMR query transformer %s for %s query. Cannot load another transformer %s for it.", | 
|  | 386 | +                        mmrQueryTransformers.get(queryName).getClass().getName(), | 
|  | 387 | +                        queryName, | 
|  | 388 | +                        transformer.getClass().getName() | 
|  | 389 | +                    ) | 
|  | 390 | +                ); | 
|  | 391 | +            } | 
|  | 392 | +            mmrQueryTransformers.put(queryName, transformer); | 
|  | 393 | +        } | 
|  | 394 | +    } | 
|  | 395 | + | 
| 359 | 396 |     /** | 
| 360 | 397 |      * Sample knn custom script | 
| 361 | 398 |      * | 
| @@ -447,4 +484,26 @@ public void reload(Settings settings) { | 
| 447 | 484 |         SecureString password = KNNSettings.KNN_REMOTE_BUILD_SERVER_PASSWORD_SETTING.get(settings); | 
| 448 | 485 |         RemoteIndexHTTPClient.reloadAuthHeader(username, password); | 
| 449 | 486 |     } | 
|  | 487 | + | 
|  | 488 | +    @Override | 
|  | 489 | +    public List<SearchExtSpec<?>> getSearchExts() { | 
|  | 490 | +        return List.of(new SearchExtSpec<SearchExtBuilder>(MMRSearchExtBuilder.NAME, MMRSearchExtBuilder::new, MMRSearchExtBuilder::parse)); | 
|  | 491 | +    } | 
|  | 492 | + | 
|  | 493 | +    @Override | 
|  | 494 | +    public Map<String, SystemGeneratedProcessor.SystemGeneratedFactory<SearchRequestProcessor>> getSystemGeneratedRequestProcessors( | 
|  | 495 | +        Parameters parameters | 
|  | 496 | +    ) { | 
|  | 497 | +        return Map.of( | 
|  | 498 | +            MMROverSampleProcessor.MMROverSampleProcessorFactory.TYPE, | 
|  | 499 | +            new MMROverSampleProcessor.MMROverSampleProcessorFactory(parameters.client, mmrQueryTransformers) | 
|  | 500 | +        ); | 
|  | 501 | +    } | 
|  | 502 | + | 
|  | 503 | +    @Override | 
|  | 504 | +    public Map<String, SystemGeneratedProcessor.SystemGeneratedFactory<SearchResponseProcessor>> getSystemGeneratedResponseProcessors( | 
|  | 505 | +        Parameters parameters | 
|  | 506 | +    ) { | 
|  | 507 | +        return Map.of(MMRRerankProcessor.MMRRerankProcessorFactory.TYPE, new MMRRerankProcessor.MMRRerankProcessorFactory()); | 
|  | 508 | +    } | 
| 450 | 509 | } | 
0 commit comments