Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit bca3abe

Browse files
committedJun 6, 2024·
adding pretrained model templates
Signed-off-by: Amit Galitzky <amgalitz@amazon.com>
1 parent 13b32f1 commit bca3abe

20 files changed

+414
-41
lines changed
 

‎CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
1717
### Enhancements
1818
- Add Workflow Step for Reindex from source index to destination ([#718](https://github.com/opensearch-project/flow-framework/pull/718))
1919
- Add param to delete workflow API to clear status even if resources exist ([#719](https://github.com/opensearch-project/flow-framework/pull/719))
20+
- Add additional default use cases ([#731](https://github.com/opensearch-project/flow-framework/pull/731))
2021
### Bug Fixes
2122
- Add user mapping to Workflow State index ([#705](https://github.com/opensearch-project/flow-framework/pull/705))
2223

‎build.gradle

+2
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ dependencies {
181181

182182
// ZipArchive dependencies used for integration tests
183183
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}"
184+
zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${opensearch_build}"
185+
zipArchive group: 'org.opensearch.plugin', name:'neural-search', version: "${opensearch_build}"
184186
secureIntegTestPluginArchive group: 'org.opensearch.plugin', name:'opensearch-security', version: "${opensearch_build}"
185187

186188
configurations.all {

‎src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java

+15
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,21 @@ public enum DefaultUseCases {
132132
"defaults/conversational-search-defaults.json",
133133
"substitutionTemplates/conversational-search-with-cohere-model-template.json",
134134
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
135+
),
136+
/** defaults file and substitution ready template for semantic search with a local pretrained model*/
137+
SEMANTIC_SEARCH_WITH_LOCAL_MODEL(
138+
"semantic_search_with_local_model",
139+
"defaults/semantic-search-with-local-model-defaults.json",
140+
"substitutionTemplates/semantic-search-with-local-model-template.json",
141+
Collections.emptyList()
142+
143+
),
144+
/** defaults file and substitution ready template for hybrid search with a local pretrained model*/
145+
HYBRID_SEARCH_WITH_LOCAL_MODEL(
146+
"hybrid_search_with_local_model",
147+
"defaults/hybrid-search-with-local-model-defaults.json",
148+
"substitutionTemplates/hybrid-search-with-local-model-template.json",
149+
Collections.emptyList()
135150
);
136151

137152
private final String useCaseName;

‎src/main/resources/defaults/hybrid-search-defaults.json

+1-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,5 @@
1414
"text_embedding.field_map.output.dimension": "1024",
1515
"create_search_pipeline.pipeline_id": "nlp-search-pipeline",
1616
"normalization-processor.normalization.technique": "min_max",
17-
"normalization-processor.combination.technique": "arithmetic_mean",
18-
"normalization-processor.combination.parameters.weights": "[0.3, 0.7]"
17+
"normalization-processor.combination.technique": "arithmetic_mean"
1918
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
{
2+
"template.name": "hybrid-search",
3+
"template.description": "Setting up hybrid search, ingest pipeline and index",
4+
"register_local_pretrained_model.name": "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b",
5+
"register_local_pretrained_model.description": "This is a sentence transformer model",
6+
"register_local_pretrained_model.model_format": "TORCH_SCRIPT",
7+
"register_local_pretrained_model.deploy": "true",
8+
"register_local_pretrained_model.version": "1.0.2",
9+
"create_ingest_pipeline.pipeline_id": "nlp-ingest-pipeline",
10+
"create_ingest_pipeline.description": "A text embedding pipeline",
11+
"create_ingest_pipeline.model_id": "123",
12+
"text_embedding.field_map.input": "passage_text",
13+
"text_embedding.field_map.output": "passage_embedding",
14+
"create_index.name": "my-nlp-index",
15+
"create_index.settings.number_of_shards": "2",
16+
"create_index.mappings.method.engine": "lucene",
17+
"create_index.mappings.method.space_type": "l2",
18+
"create_index.mappings.method.name": "hnsw",
19+
"text_embedding.field_map.output.dimension": "768",
20+
"create_search_pipeline.pipeline_id": "nlp-search-pipeline",
21+
"normalization-processor.normalization.technique": "min_max",
22+
"normalization-processor.combination.technique": "arithmetic_mean"
23+
}

‎src/main/resources/defaults/multi-modal-search-defaults.json

+3-1
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,7 @@
1111
"create_index.settings.number_of_shards": "2",
1212
"text_image_embedding.field_map.output.dimension": "1024",
1313
"create_index.mappings.method.engine": "lucene",
14-
"create_index.mappings.method.name": "hnsw"
14+
"create_index.mappings.method.name": "hnsw",
15+
"text_image_embedding.field_map.image.type": "text",
16+
"text_image_embedding.field_map.text.type": "text"
1517
}

‎src/main/resources/defaults/multimodal-search-bedrock-titan-defaults.json

+3-1
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,7 @@
2424
"create_index.settings.number_of_shards": "2",
2525
"text_image_embedding.field_map.output.dimension": "1024",
2626
"create_index.mappings.method.engine": "lucene",
27-
"create_index.mappings.method.name": "hnsw"
27+
"create_index.mappings.method.name": "hnsw",
28+
"text_image_embedding.field_map.image.type": "text",
29+
"text_image_embedding.field_map.text.type": "text"
2830
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
{
2+
"template.name": "semantic search with local pretrained model",
3+
"template.description": "Setting up semantic search, with a local pretrained embedding model",
4+
"register_local_pretrained_model.name": "huggingface/sentence-transformers/msmarco-distilbert-base-tas-b",
5+
"register_local_pretrained_model.description": "This is a sentence transformer model",
6+
"register_local_pretrained_model.model_format": "TORCH_SCRIPT",
7+
"register_local_pretrained_model.deploy": "true",
8+
"register_local_pretrained_model.version": "1.0.2",
9+
"create_ingest_pipeline.pipeline_id": "nlp-ingest-pipeline",
10+
"create_ingest_pipeline.description": "A text embedding pipeline",
11+
"text_embedding.field_map.input": "passage_text",
12+
"text_embedding.field_map.output": "passage_embedding",
13+
"create_index.name": "my-nlp-index",
14+
"create_index.settings.number_of_shards": "2",
15+
"create_index.mappings.method.engine": "lucene",
16+
"create_index.mappings.method.space_type": "l2",
17+
"create_index.mappings.method.name": "hnsw",
18+
"text_embedding.field_map.output.dimension": "768",
19+
"create_search_pipeline.pipeline_id": "default_model_pipeline"
20+
}

‎src/main/resources/substitutionTemplates/hybrid-search-template.json

+1-7
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,6 @@
5050
"mappings": {
5151
"_doc": {
5252
"properties": {
53-
"id": {
54-
"type": "text"
55-
},
5653
"${{text_embedding.field_map.output}}": {
5754
"type": "knn_vector",
5855
"dimension": "${{text_embedding.field_map.output.dimension}}",
@@ -86,10 +83,7 @@
8683
"technique": "${{normalization-processor.normalization.technique}}"
8784
},
8885
"combination": {
89-
"technique": "${{normalization-processor.combination.technique}}",
90-
"parameters": {
91-
"weights": "${{normalization-processor.combination.parameters.weights}}"
92-
}
86+
"technique": "${{normalization-processor.combination.technique}}"
9387
}
9488
}
9589
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
{
2+
"name": "${{template.name}}",
3+
"description": "${{template.description}}",
4+
"use_case": "HYBRID_SEARCH",
5+
"version": {
6+
"template": "1.0.0",
7+
"compatibility": [
8+
"2.12.0",
9+
"3.0.0"
10+
]
11+
},
12+
"workflows": {
13+
"provision": {
14+
"nodes": [
15+
{
16+
"id": "register_local_pretrained_model",
17+
"type": "register_local_pretrained_model",
18+
"user_inputs": {
19+
"name": "${{register_local_pretrained_model.name}}",
20+
"version": "${{register_local_pretrained_model.version}}",
21+
"description": "${{register_local_pretrained_model.description}}",
22+
"model_format": "${{register_local_pretrained_model.model_format}}",
23+
"deploy": true
24+
}
25+
},
26+
{
27+
"id": "create_ingest_pipeline",
28+
"type": "create_ingest_pipeline",
29+
"previous_node_inputs": {
30+
"register_local_pretrained_model": "model_id"
31+
},
32+
"user_inputs": {
33+
"pipeline_id": "${{create_ingest_pipeline.pipeline_id}}",
34+
"configurations": {
35+
"description": "${{create_ingest_pipeline.description}}",
36+
"processors": [
37+
{
38+
"text_embedding": {
39+
"model_id": "${{register_local_pretrained_model.model_id}}",
40+
"field_map": {
41+
"${{text_embedding.field_map.input}}": "${{text_embedding.field_map.output}}"
42+
}
43+
}
44+
}
45+
]
46+
}
47+
}
48+
},
49+
{
50+
"id": "create_index",
51+
"type": "create_index",
52+
"previous_node_inputs": {
53+
"create_ingest_pipeline": "pipeline_id"
54+
},
55+
"user_inputs": {
56+
"index_name": "${{create_index.name}}",
57+
"configurations": {
58+
"settings": {
59+
"index.knn": true,
60+
"default_pipeline": "${{create_ingest_pipeline.pipeline_id}}",
61+
"number_of_shards": "${{create_index.settings.number_of_shards}}",
62+
"index.search.default_pipeline": "${{create_search_pipeline.pipeline_id}}"
63+
},
64+
"mappings": {
65+
"properties": {
66+
"${{text_embedding.field_map.output}}": {
67+
"type": "knn_vector",
68+
"dimension": "${{text_embedding.field_map.output.dimension}}",
69+
"method": {
70+
"engine": "${{create_index.mappings.method.engine}}",
71+
"space_type": "${{create_index.mappings.method.space_type}}",
72+
"name": "${{create_index.mappings.method.name}}",
73+
"parameters": {}
74+
}
75+
},
76+
"${{text_embedding.field_map.input}}": {
77+
"type": "text"
78+
}
79+
}
80+
}
81+
}
82+
}
83+
},
84+
{
85+
"id": "create_search_pipeline",
86+
"type": "create_search_pipeline",
87+
"user_inputs": {
88+
"pipeline_id": "${{create_search_pipeline.pipeline_id}}",
89+
"configurations": {
90+
"description": "Post processor for hybrid search",
91+
"phase_results_processors": [
92+
{
93+
"normalization-processor": {
94+
"normalization": {
95+
"technique": "${{normalization-processor.normalization.technique}}"
96+
},
97+
"combination": {
98+
"technique": "${{normalization-processor.combination.technique}}"
99+
}
100+
}
101+
}
102+
]
103+
}
104+
}
105+
}
106+
]
107+
}
108+
}
109+
}

‎src/main/resources/substitutionTemplates/multi-modal-search-template.json

+2-5
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@
5151
"mappings": {
5252
"_doc": {
5353
"properties": {
54-
"id": {
55-
"type": "text"
56-
},
5754
"${{text_image_embedding.embedding}}": {
5855
"type": "knn_vector",
5956
"dimension": "${{text_image_embedding.field_map.output.dimension}}",
@@ -64,10 +61,10 @@
6461
}
6562
},
6663
"${{text_image_embedding.field_map.text}}": {
67-
"type": "text"
64+
"type": "${{text_image_embedding.field_map.text.type}}"
6865
},
6966
"${{text_image_embedding.field_map.image}}": {
70-
"type": "binary"
67+
"type": "${{text_image_embedding.field_map.image.type}}"
7168
}
7269
}
7370
}

‎src/main/resources/substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json

+2-5
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,6 @@
101101
"mappings": {
102102
"_doc": {
103103
"properties": {
104-
"id": {
105-
"type": "text"
106-
},
107104
"${{text_image_embedding.embedding}}": {
108105
"type": "knn_vector",
109106
"dimension": "${{text_image_embedding.field_map.output.dimension}}",
@@ -114,10 +111,10 @@
114111
}
115112
},
116113
"${{text_image_embedding.field_map.text}}": {
117-
"type": "text"
114+
"type": "${{text_image_embedding.field_map.text.type}}"
118115
},
119116
"${{text_image_embedding.field_map.image}}": {
120-
"type": "binary"
117+
"type": "${{text_image_embedding.field_map.image.type}}"
121118
}
122119
}
123120
}

‎src/main/resources/substitutionTemplates/neural-sparse-local-biencoder-template.json

-3
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,6 @@
6161
"mappings": {
6262
"_doc": {
6363
"properties": {
64-
"id": {
65-
"type": "text"
66-
},
6764
"${{create_ingest_pipeline.text_embedding.field_map.output}}": {
6865
"type": "rank_features"
6966
},

‎src/main/resources/substitutionTemplates/semantic-search-template.json

-3
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,6 @@
4949
"mappings": {
5050
"_doc": {
5151
"properties": {
52-
"id": {
53-
"type": "text"
54-
},
5552
"${{text_embedding.field_map.output}}": {
5653
"type": "knn_vector",
5754
"dimension": "${{text_embedding.field_map.output.dimension}}",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
{
2+
"name": "${{template.name}}",
3+
"description": "${{template.description}}",
4+
"use_case": "SEMANTIC_SEARCH",
5+
"version": {
6+
"template": "1.0.0",
7+
"compatibility": [
8+
"2.12.0",
9+
"3.0.0"
10+
]
11+
},
12+
"workflows": {
13+
"provision": {
14+
"nodes": [
15+
{
16+
"id": "register_local_pretrained_model",
17+
"type": "register_local_pretrained_model",
18+
"user_inputs": {
19+
"name": "${{register_local_pretrained_model.name}}",
20+
"version": "${{register_local_pretrained_model.version}}",
21+
"description": "${{register_local_pretrained_model.description}}",
22+
"model_format": "${{register_local_pretrained_model.model_format}}",
23+
"deploy": true
24+
}
25+
},
26+
{
27+
"id": "create_ingest_pipeline",
28+
"type": "create_ingest_pipeline",
29+
"previous_node_inputs": {
30+
"register_local_pretrained_model": "model_id"
31+
},
32+
"user_inputs": {
33+
"pipeline_id": "${{create_ingest_pipeline.pipeline_id}}",
34+
"configurations": {
35+
"description": "${{create_ingest_pipeline.description}}",
36+
"processors": [
37+
{
38+
"text_embedding": {
39+
"model_id": "${{register_local_pretrained_model.model_id}}",
40+
"field_map": {
41+
"${{text_embedding.field_map.input}}": "${{text_embedding.field_map.output}}"
42+
}
43+
}
44+
}
45+
]
46+
}
47+
}
48+
},
49+
{
50+
"id": "create_index",
51+
"type": "create_index",
52+
"previous_node_inputs": {
53+
"create_ingest_pipeline": "pipeline_id"
54+
},
55+
"user_inputs": {
56+
"index_name": "${{create_index.name}}",
57+
"configurations": {
58+
"settings": {
59+
"index.knn": true,
60+
"default_pipeline": "${{create_ingest_pipeline.pipeline_id}}",
61+
"number_of_shards": "${{create_index.settings.number_of_shards}}"
62+
},
63+
"mappings": {
64+
"properties": {
65+
"${{text_embedding.field_map.output}}": {
66+
"type": "knn_vector",
67+
"dimension": "${{text_embedding.field_map.output.dimension}}",
68+
"method": {
69+
"engine": "${{create_index.mappings.method.engine}}",
70+
"space_type": "${{create_index.mappings.method.space_type}}",
71+
"name": "${{create_index.mappings.method.name}}",
72+
"parameters": {}
73+
}
74+
},
75+
"${{text_embedding.field_map.input}}": {
76+
"type": "text"
77+
}
78+
}
79+
}
80+
}
81+
}
82+
}
83+
]
84+
}
85+
}
86+
}

‎src/main/resources/substitutionTemplates/semantic-search-with-model-and-query-enricher-template.json

-3
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,6 @@
9999
"mappings": {
100100
"_doc": {
101101
"properties": {
102-
"id": {
103-
"type": "text"
104-
},
105102
"${{text_embedding.field_map.output}}": {
106103
"type": "knn_vector",
107104
"dimension": "${{text_embedding.field_map.output.dimension}}",

‎src/main/resources/substitutionTemplates/semantic-search-with-model-template.json

-3
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,6 @@
9898
"mappings": {
9999
"_doc": {
100100
"properties": {
101-
"id": {
102-
"type": "text"
103-
},
104101
"${{text_embedding.field_map.output}}": {
105102
"type": "knn_vector",
106103
"dimension": "${{text_embedding.field_map.output.dimension}}",

‎src/main/resources/substitutionTemplates/semantic-search-with-query-enricher-template.json

-3
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@
6767
"mappings": {
6868
"_doc": {
6969
"properties": {
70-
"id": {
71-
"type": "text"
72-
},
7370
"${{text_embedding.field_map.output}}": {
7471
"type": "knn_vector",
7572
"dimension": "${{text_embedding.field_map.output.dimension}}",

‎src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java

+70-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.opensearch.flowframework.model.State;
4949
import org.opensearch.flowframework.model.Template;
5050
import org.opensearch.flowframework.model.WorkflowState;
51+
import org.opensearch.flowframework.util.ParseUtils;
5152
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
5253
import org.opensearch.test.rest.OpenSearchRestTestCase;
5354
import org.junit.After;
@@ -350,7 +351,7 @@ protected Response createWorkflow(RestClient client, Template template) throws E
350351
* @throws Exception if the request fails
351352
* @return a rest response
352353
*/
353-
protected Response createWorkflowWithUseCase(RestClient client, String useCase, List<String> params) throws Exception {
354+
protected Response createWorkflowWithUseCaseWithNoValidation(RestClient client, String useCase, List<String> params) throws Exception {
354355

355356
StringBuilder sb = new StringBuilder();
356357
for (String param : params) {
@@ -370,6 +371,28 @@ protected Response createWorkflowWithUseCase(RestClient client, String useCase,
370371
);
371372
}
372373

374+
/**
375+
* Helper method to invoke the create workflow API with a use case and also the provision param as true
376+
* @param client the rest client
377+
* @param useCase the usecase to create
378+
* @param defaults the defaults to override given through the request payload
379+
* @throws Exception if the request fails
380+
* @return a rest response
381+
*/
382+
protected Response createAndProvisionWorkflowWithUseCaseWithContent(RestClient client, String useCase, Map<String, Object> defaults)
383+
throws Exception {
384+
String payload = ParseUtils.parseArbitraryStringToObjectMapToString(defaults);
385+
386+
return TestHelpers.makeRequest(
387+
client,
388+
"POST",
389+
WORKFLOW_URI + "?provision=true&use_case=" + useCase,
390+
Collections.emptyMap(),
391+
payload,
392+
null
393+
);
394+
}
395+
373396
/**
374397
* Helper method to invoke the Create Workflow Rest Action with provision
375398
* @param client the rest client
@@ -742,6 +765,52 @@ protected GetPipelineResponse getPipelines(String pipelineId) throws IOException
742765
}
743766
}
744767

768+
protected void ingestSingleDoc(String payload, String indexName) throws IOException {
769+
try {
770+
TestHelpers.makeRequest(
771+
client(),
772+
"PUT",
773+
indexName + "/_doc/1",
774+
null,
775+
payload,
776+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
777+
);
778+
} catch (Exception e) {
779+
throw new RuntimeException(e);
780+
}
781+
}
782+
783+
protected SearchResponse neuralSearchRequest(String indexName, String modelId) throws IOException {
784+
String searchRequest =
785+
"{\"_source\":{\"excludes\":[\"passage_embedding\"]},\"query\":{\"neural\":{\"passage_embedding\":{\"query_text\":\"world\",\"k\":5,\"model_id\":\""
786+
+ modelId
787+
+ "\"}}}}";
788+
try {
789+
Response restSearchResponse = TestHelpers.makeRequest(
790+
client(),
791+
"POST",
792+
indexName + "/_search",
793+
null,
794+
searchRequest,
795+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
796+
);
797+
// Parse entity content into SearchResponse
798+
MediaType mediaType = MediaType.fromMediaType(restSearchResponse.getEntity().getContentType());
799+
try (
800+
XContentParser parser = mediaType.xContent()
801+
.createParser(
802+
NamedXContentRegistry.EMPTY,
803+
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
804+
restSearchResponse.getEntity().getContent()
805+
)
806+
) {
807+
return SearchResponse.fromXContent(parser);
808+
}
809+
} catch (Exception e) {
810+
throw new RuntimeException(e);
811+
}
812+
}
813+
745814
@SuppressWarnings("unchecked")
746815
protected List<String> catPlugins() throws IOException {
747816
Response response = TestHelpers.makeRequest(

‎src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java

+76-4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import java.time.Instant;
3535
import java.util.Collections;
3636
import java.util.EnumSet;
37+
import java.util.HashMap;
3738
import java.util.HashSet;
3839
import java.util.List;
3940
import java.util.Map;
@@ -429,7 +430,11 @@ public void testCreateAndProvisionIngestAndSearchPipeline() throws Exception {
429430
public void testDefaultCohereUseCase() throws Exception {
430431

431432
// Hit Create Workflow API with original template
432-
Response response = createWorkflowWithUseCase(client(), "cohere_embedding_model_deploy", List.of(CREATE_CONNECTOR_CREDENTIAL_KEY));
433+
Response response = createWorkflowWithUseCaseWithNoValidation(
434+
client(),
435+
"cohere_embedding_model_deploy",
436+
List.of(CREATE_CONNECTOR_CREDENTIAL_KEY)
437+
);
433438
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));
434439

435440
Map<String, Object> responseMap = entityAsMap(response);
@@ -468,15 +473,19 @@ public void testDefaultSemanticSearchUseCaseWithFailureExpected() throws Excepti
468473
// Hit Create Workflow API with original template without required params
469474
ResponseException exception = expectThrows(
470475
ResponseException.class,
471-
() -> createWorkflowWithUseCase(client(), "semantic_search", Collections.emptyList())
476+
() -> createWorkflowWithUseCaseWithNoValidation(client(), "semantic_search", Collections.emptyList())
472477
);
473478
assertTrue(
474479
exception.getMessage()
475480
.contains("Missing the following required parameters for use case [semantic_search] : [create_ingest_pipeline.model_id]")
476481
);
477482

478483
// Pass in required params
479-
Response response = createWorkflowWithUseCase(client(), "semantic_search", List.of(CREATE_INGEST_PIPELINE_MODEL_ID));
484+
Response response = createWorkflowWithUseCaseWithNoValidation(
485+
client(),
486+
"semantic_search",
487+
List.of(CREATE_INGEST_PIPELINE_MODEL_ID)
488+
);
480489
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));
481490

482491
Map<String, Object> responseMap = entityAsMap(response);
@@ -502,7 +511,7 @@ public void testAllDefaultUseCasesCreation() throws Exception {
502511
.collect(Collectors.toSet());
503512

504513
for (String useCaseName : allUseCaseNames) {
505-
Response response = createWorkflowWithUseCase(
514+
Response response = createWorkflowWithUseCaseWithNoValidation(
506515
client(),
507516
useCaseName,
508517
DefaultUseCases.getRequiredParamsByUseCaseName(useCaseName)
@@ -514,4 +523,67 @@ public void testAllDefaultUseCasesCreation() throws Exception {
514523
getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED);
515524
}
516525
}
526+
527+
public void testSemanticSearchWithLocalModelEndToEnd() throws Exception {
528+
529+
Map<String, Object> defaults = new HashMap<>();
530+
defaults.put("register_local_pretrained_model.name", "huggingface/sentence-transformers/all-MiniLM-L6-v2");
531+
defaults.put("register_local_pretrained_model.version", "1.0.1");
532+
defaults.put("text_embedding.field_map.output.dimension", 384);
533+
534+
Response response = createAndProvisionWorkflowWithUseCaseWithContent(client(), "semantic_search_with_local_model", defaults);
535+
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));
536+
537+
Map<String, Object> responseMap = entityAsMap(response);
538+
String workflowId = (String) responseMap.get(WORKFLOW_ID);
539+
getAndAssertWorkflowStatus(client(), workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS);
540+
541+
// Wait until provisioning has completed successfully before attempting to retrieve created resources
542+
List<ResourceCreated> resourcesCreated = getResourcesCreated(client(), workflowId, 45);
543+
544+
// This template should create 4 resources, registered model_id, deployed model_id, ingest pipeline, and index name
545+
assertEquals(4, resourcesCreated.size());
546+
String modelId = resourcesCreated.get(1).resourceId();
547+
String indexName = resourcesCreated.get(3).resourceId();
548+
549+
// Short wait before ingesting data
550+
Thread.sleep(30000);
551+
552+
String docContent = "{\"passage_text\": \"Hello planet\"\n}";
553+
ingestSingleDoc(docContent, indexName);
554+
// Short wait before neural search
555+
Thread.sleep(500);
556+
SearchResponse neuralSearchResponse = neuralSearchRequest(indexName, modelId);
557+
assertEquals(neuralSearchResponse.getHits().getHits().length, 1);
558+
Thread.sleep(500);
559+
deleteIndex(indexName);
560+
561+
// Hit Deprovision API
562+
// By design, this may not completely deprovision the first time if it takes >2s to process removals
563+
Response deprovisionResponse = deprovisionWorkflow(client(), workflowId);
564+
try {
565+
assertBusy(
566+
() -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); },
567+
30,
568+
TimeUnit.SECONDS
569+
);
570+
} catch (ComparisonFailure e) {
571+
// 202 return if still processing
572+
assertEquals(RestStatus.ACCEPTED, TestHelpers.restStatus(deprovisionResponse));
573+
}
574+
if (TestHelpers.restStatus(deprovisionResponse) == RestStatus.ACCEPTED) {
575+
// Short wait before we try again
576+
Thread.sleep(10000);
577+
deprovisionResponse = deprovisionWorkflow(client(), workflowId);
578+
assertBusy(
579+
() -> { getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); },
580+
30,
581+
TimeUnit.SECONDS
582+
);
583+
}
584+
assertEquals(RestStatus.OK, TestHelpers.restStatus(deprovisionResponse));
585+
// Hit Delete API
586+
Response deleteResponse = deleteWorkflow(client(), workflowId);
587+
assertEquals(RestStatus.OK, TestHelpers.restStatus(deleteResponse));
588+
}
517589
}

0 commit comments

Comments
 (0)
Please sign in to comment.