Skip to content

Commit a5fcbde

Browse files
opensearch-trigger-bot[bot] and github-actions[bot] authored Mar 17, 2024
[Backport 2.x] Adding default use cases (#587)
Adding default use cases (#583) * initial default use case addition * adding IT and UT * addresing comments and adding more tests * addressing more comments and adding more UT * addressed more comments and more UT --------- (cherry picked from commit b148eb5) Signed-off-by: Amit Galitzky <amgalitz@amazon.com> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

23 files changed

+1058
-21
lines changed
 

‎CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
1717
- Added create ingest pipeline step ([#558](https://github.com/opensearch-project/flow-framework/pull/558))
1818
- Added create search pipeline step ([#569](https://github.com/opensearch-project/flow-framework/pull/569))
1919
- Added create index step ([#574](https://github.com/opensearch-project/flow-framework/pull/574))
20+
- Added default use cases ([#583](https://github.com/opensearch-project/flow-framework/pull/583))
2021

2122
### Enhancements
2223
- Substitute REST path or body parameters in Workflow Steps ([#525](https://github.com/opensearch-project/flow-framework/pull/525))

‎build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ dependencies {
179179

180180
// ZipArchive dependencies used for integration tests
181181
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}"
182+
182183
secureIntegTestPluginArchive group: 'org.opensearch.plugin', name:'opensearch-security', version: "${opensearch_build}"
183184

184185
configurations.all {

‎src/main/java/org/opensearch/flowframework/common/CommonValue.java

+2
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ private CommonValue() {}
7272
public static final String PROVISION_WORKFLOW = "provision";
7373
/** The field name for workflow steps. This field represents the name of the workflow steps to be fetched. */
7474
public static final String WORKFLOW_STEP = "workflow_step";
75+
/** The param name for default use case, used by the create workflow API */
76+
public static final String USE_CASE = "use_case";
7577

7678
/*
7779
* Constants associated with plugin configuration
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* The OpenSearch Contributors require contributions made to
6+
* this file be licensed under the Apache-2.0 license or a
7+
* compatible open source license.
8+
*/
9+
package org.opensearch.flowframework.common;
10+
11+
import org.apache.logging.log4j.LogManager;
12+
import org.apache.logging.log4j.Logger;
13+
import org.opensearch.core.rest.RestStatus;
14+
import org.opensearch.flowframework.exception.FlowFrameworkException;
15+
16+
/**
17+
* Enum encapsulating the different default use cases and templates we have stored
18+
*/
19+
public enum DefaultUseCases {
20+
21+
/** defaults file and substitution ready template for OpenAI embedding model */
22+
OPEN_AI_EMBEDDING_MODEL_DEPLOY(
23+
"open_ai_embedding_model_deploy",
24+
"defaults/open-ai-embedding-defaults.json",
25+
"substitutionTemplates/deploy-remote-model-template.json"
26+
),
27+
/** defaults file and substitution ready template for cohere embedding model */
28+
COHERE_EMBEDDING_MODEL_DEPLOY(
29+
"cohere-embedding_model_deploy",
30+
"defaults/cohere-embedding-defaults.json",
31+
"substitutionTemplates/deploy-remote-model-template-extra-params.json"
32+
),
33+
/** defaults file and substitution ready template for local neural sparse model and ingest pipeline*/
34+
LOCAL_NEURAL_SPARSE_SEARCH(
35+
"local_neural_sparse_search",
36+
"defaults/local-sparse-search-defaults.json",
37+
"substitutionTemplates/neural-sparse-local-template.json"
38+
);
39+
40+
private final String useCaseName;
41+
private final String defaultsFile;
42+
private final String substitutionReadyFile;
43+
private static final Logger logger = LogManager.getLogger(DefaultUseCases.class);
44+
45+
DefaultUseCases(String useCaseName, String defaultsFile, String substitutionReadyFile) {
46+
this.useCaseName = useCaseName;
47+
this.defaultsFile = defaultsFile;
48+
this.substitutionReadyFile = substitutionReadyFile;
49+
}
50+
51+
/**
52+
* Returns the useCaseName for the given enum Constant
53+
* @return the useCaseName of this use case.
54+
*/
55+
public String getUseCaseName() {
56+
return useCaseName;
57+
}
58+
59+
/**
60+
* Returns the defaultsFile for the given enum Constant
61+
* @return the defaultsFile of this for the given useCase.
62+
*/
63+
public String getDefaultsFile() {
64+
return defaultsFile;
65+
}
66+
67+
/**
68+
* Returns the substitutionReadyFile for the given enum Constant
69+
* @return the substitutionReadyFile of the given useCase
70+
*/
71+
public String getSubstitutionReadyFile() {
72+
return substitutionReadyFile;
73+
}
74+
75+
/**
76+
* Gets the defaultsFile based on the given use case.
77+
* @param useCaseName name of the given use case
78+
* @return the defaultsFile for that usecase
79+
* @throws FlowFrameworkException if the use case doesn't exist in enum
80+
*/
81+
public static String getDefaultsFileByUseCaseName(String useCaseName) throws FlowFrameworkException {
82+
if (useCaseName != null && !useCaseName.isEmpty()) {
83+
for (DefaultUseCases usecase : values()) {
84+
if (useCaseName.equals(usecase.getUseCaseName())) {
85+
return usecase.getDefaultsFile();
86+
}
87+
}
88+
}
89+
logger.error("Unable to find defaults file for use case: {}", useCaseName);
90+
throw new FlowFrameworkException("Unable to find defaults file for use case: " + useCaseName, RestStatus.BAD_REQUEST);
91+
}
92+
93+
/**
94+
* Gets the substitutionReadyFile based on the given use case
95+
* @param useCaseName name of the given use case
96+
* @return the substitutionReadyFile which has the template
97+
* @throws FlowFrameworkException if the use case doesn't exist in enum
98+
*/
99+
public static String getSubstitutionReadyFileByUseCaseName(String useCaseName) throws FlowFrameworkException {
100+
if (useCaseName != null && !useCaseName.isEmpty()) {
101+
for (DefaultUseCases useCase : values()) {
102+
if (useCase.getUseCaseName().equals(useCaseName)) {
103+
return useCase.getSubstitutionReadyFile();
104+
}
105+
}
106+
}
107+
logger.error("Unable to find substitution ready file for use case: {}", useCaseName);
108+
throw new FlowFrameworkException("Unable to find substitution ready file for use case: " + useCaseName, RestStatus.BAD_REQUEST);
109+
}
110+
}

‎src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java

+65-5
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,19 @@
1717
import org.opensearch.core.xcontent.ToXContent;
1818
import org.opensearch.core.xcontent.XContentBuilder;
1919
import org.opensearch.core.xcontent.XContentParser;
20+
import org.opensearch.flowframework.common.DefaultUseCases;
2021
import org.opensearch.flowframework.common.FlowFrameworkSettings;
2122
import org.opensearch.flowframework.exception.FlowFrameworkException;
2223
import org.opensearch.flowframework.model.Template;
2324
import org.opensearch.flowframework.transport.CreateWorkflowAction;
2425
import org.opensearch.flowframework.transport.WorkflowRequest;
26+
import org.opensearch.flowframework.util.ParseUtils;
2527
import org.opensearch.rest.BaseRestHandler;
2628
import org.opensearch.rest.BytesRestResponse;
2729
import org.opensearch.rest.RestRequest;
2830

2931
import java.io.IOException;
32+
import java.util.Collections;
3033
import java.util.List;
3134
import java.util.Locale;
3235
import java.util.Map;
@@ -35,6 +38,7 @@
3538

3639
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
3740
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
41+
import static org.opensearch.flowframework.common.CommonValue.USE_CASE;
3842
import static org.opensearch.flowframework.common.CommonValue.VALIDATION;
3943
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID;
4044
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
@@ -78,6 +82,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
7882
String workflowId = request.param(WORKFLOW_ID);
7983
String[] validation = request.paramAsStringArray(VALIDATION, new String[] { "all" });
8084
boolean provision = request.paramAsBoolean(PROVISION_WORKFLOW, false);
85+
String useCase = request.param(USE_CASE);
8186
// If provisioning, consume all other params and pass to provision transport action
8287
Map<String, String> params = provision
8388
? request.params()
@@ -112,11 +117,63 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
112117
);
113118
}
114119
try {
115-
XContentParser parser = request.contentParser();
116-
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
117-
Template template = Template.parse(parser);
118120

119-
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, validation, provision, params);
121+
Template template;
122+
Map<String, String> useCaseDefaultsMap = Collections.emptyMap();
123+
if (useCase != null) {
124+
String useCaseTemplateFileInStringFormat = ParseUtils.resourceToString(
125+
"/" + DefaultUseCases.getSubstitutionReadyFileByUseCaseName(useCase)
126+
);
127+
String defaultsFilePath = DefaultUseCases.getDefaultsFileByUseCaseName(useCase);
128+
useCaseDefaultsMap = ParseUtils.parseJsonFileToStringToStringMap("/" + defaultsFilePath);
129+
130+
if (request.hasContent()) {
131+
try {
132+
XContentParser parser = request.contentParser();
133+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
134+
Map<String, String> userDefaults = ParseUtils.parseStringToStringMap(parser);
135+
// updates the default params with anything user has given that matches
136+
for (Map.Entry<String, String> userDefaultsEntry : userDefaults.entrySet()) {
137+
String key = userDefaultsEntry.getKey();
138+
String value = userDefaultsEntry.getValue();
139+
if (useCaseDefaultsMap.containsKey(key)) {
140+
useCaseDefaultsMap.put(key, value);
141+
}
142+
}
143+
} catch (Exception ex) {
144+
RestStatus status = ex instanceof IOException ? RestStatus.BAD_REQUEST : ExceptionsHelper.status(ex);
145+
String errorMessage = "failure parsing request body when a use case is given";
146+
logger.error(errorMessage, ex);
147+
throw new FlowFrameworkException(errorMessage, status);
148+
}
149+
150+
}
151+
152+
useCaseTemplateFileInStringFormat = (String) ParseUtils.conditionallySubstitute(
153+
useCaseTemplateFileInStringFormat,
154+
null,
155+
useCaseDefaultsMap
156+
);
157+
158+
XContentParser parserTestJson = ParseUtils.jsonToParser(useCaseTemplateFileInStringFormat);
159+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parserTestJson.currentToken(), parserTestJson);
160+
template = Template.parse(parserTestJson);
161+
162+
} else {
163+
XContentParser parser = request.contentParser();
164+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
165+
template = Template.parse(parser);
166+
}
167+
168+
WorkflowRequest workflowRequest = new WorkflowRequest(
169+
workflowId,
170+
template,
171+
validation,
172+
provision,
173+
params,
174+
useCase,
175+
useCaseDefaultsMap
176+
);
120177

121178
return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> {
122179
XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS);
@@ -134,11 +191,14 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
134191
channel.sendResponse(new BytesRestResponse(ExceptionsHelper.status(e), errorMessage));
135192
}
136193
}));
194+
137195
} catch (FlowFrameworkException e) {
196+
logger.error("failed to prepare rest request", e);
138197
return channel -> channel.sendResponse(
139198
new BytesRestResponse(e.getRestStatus(), e.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS))
140199
);
141-
} catch (IOException e) {
200+
} catch (Exception e) {
201+
logger.error("failed to prepare rest request", e);
142202
FlowFrameworkException ex = new FlowFrameworkException(
143203
"IOException: template content invalid for specified Content-Type.",
144204
RestStatus.BAD_REQUEST

‎src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java

+46-3
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,23 @@ public class WorkflowRequest extends ActionRequest {
4949
*/
5050
private Map<String, String> params;
5151

52+
/**
53+
* use case flag
54+
*/
55+
private String useCase;
56+
57+
/**
58+
* Default params map from use case
59+
*/
60+
private Map<String, String> defaultParams;
61+
5262
/**
5363
* Instantiates a new WorkflowRequest, set validation to all, no provisioning
5464
* @param workflowId the documentId of the workflow
5565
* @param template the use case template which describes the workflow
5666
*/
5767
public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) {
58-
this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap());
68+
this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), null, Collections.emptyMap());
5969
}
6070

6171
/**
@@ -65,7 +75,18 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template)
6575
* @param params The parameters from the REST path
6676
*/
6777
public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, Map<String, String> params) {
68-
this(workflowId, template, new String[] { "all" }, true, params);
78+
this(workflowId, template, new String[] { "all" }, true, params, null, Collections.emptyMap());
79+
}
80+
81+
/**
82+
* Instantiates a new WorkflowRequest with params map, set validation to all, provisioning to true
83+
* @param workflowId the documentId of the workflow
84+
* @param template the use case template which describes the workflow
85+
* @param useCase the default use case give by user
86+
* @param defaultParams The parameters from the REST body when a use case is given
87+
*/
88+
public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, String useCase, Map<String, String> defaultParams) {
89+
this(workflowId, template, new String[] { "all" }, false, Collections.emptyMap(), useCase, defaultParams);
6990
}
7091

7192
/**
@@ -75,13 +96,17 @@ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template,
7596
* @param validation flag to indicate if validation is necessary
7697
* @param provision flag to indicate if provision is necessary
7798
* @param params map of REST path params. If provision is false, must be an empty map.
99+
* @param useCase default use case given
100+
* @param defaultParams the params to be used in the substitution based on the default use case.
78101
*/
79102
public WorkflowRequest(
80103
@Nullable String workflowId,
81104
@Nullable Template template,
82105
String[] validation,
83106
boolean provision,
84-
Map<String, String> params
107+
Map<String, String> params,
108+
String useCase,
109+
Map<String, String> defaultParams
85110
) {
86111
this.workflowId = workflowId;
87112
this.template = template;
@@ -91,6 +116,8 @@ public WorkflowRequest(
91116
throw new IllegalArgumentException("Params may only be included when provisioning.");
92117
}
93118
this.params = params;
119+
this.useCase = useCase;
120+
this.defaultParams = defaultParams;
94121
}
95122

96123
/**
@@ -150,6 +177,22 @@ public Map<String, String> getParams() {
150177
return Map.copyOf(this.params);
151178
}
152179

180+
/**
181+
* Gets the use case
182+
* @return the use case
183+
*/
184+
public String getUseCase() {
185+
return this.useCase;
186+
}
187+
188+
/**
189+
* Gets the params map
190+
* @return the params map
191+
*/
192+
public Map<String, String> getDefaultParams() {
193+
return Map.copyOf(this.defaultParams);
194+
}
195+
153196
@Override
154197
public void writeTo(StreamOutput out) throws IOException {
155198
super.writeTo(out);

‎src/main/java/org/opensearch/flowframework/util/ParseUtils.java

+32-7
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ public class ParseUtils {
6161

6262
private ParseUtils() {}
6363

64+
private static final ObjectMapper mapper = new ObjectMapper();
65+
6466
/**
6567
* Converts a JSON string into an XContentParser
6668
*
@@ -342,11 +344,18 @@ public static Map<String, Object> getInputsFromPreviousSteps(
342344
return inputs;
343345
}
344346

345-
private static Object conditionallySubstitute(Object value, Map<String, WorkflowData> outputs, Map<String, String> params) {
347+
/**
348+
* Executes substitution on the given value by looking at any matching values in either the outputs or params map
349+
* @param value the Object that will have the substitution done on
350+
* @param outputs potential location of values to be substituted in
351+
* @param params potential location of values to be substituted in
352+
* @return the substituted object back
353+
*/
354+
public static Object conditionallySubstitute(Object value, Map<String, WorkflowData> outputs, Map<String, String> params) {
346355
if (value instanceof String) {
347356
Matcher m = SUBSTITUTION_PATTERN.matcher((String) value);
348357
StringBuilder result = new StringBuilder();
349-
while (m.find()) {
358+
while (m.find() && outputs != null) {
350359
// outputs content map contains values for previous node input (e.g: deploy_openai_model.model_id)
351360
// Check first if the substitution is looking for the same key, value pair and if yes
352361
// then replace it with the key value pair in the inputs map
@@ -364,10 +373,15 @@ private static Object conditionallySubstitute(Object value, Map<String, Workflow
364373
m.appendTail(result);
365374
value = result.toString();
366375

367-
// Replace all params if present
368-
for (Entry<String, String> e : params.entrySet()) {
369-
String regex = "\\$\\{\\{\\s*" + Pattern.quote(e.getKey()) + "\\s*\\}\\}";
370-
value = ((String) value).replaceAll(regex, e.getValue());
376+
if (params != null) {
377+
for (Map.Entry<String, String> e : params.entrySet()) {
378+
String regex = "\\$\\{\\{\\s*" + Pattern.quote(e.getKey()) + "\\s*\\}\\}";
379+
String replacement = e.getValue();
380+
381+
// Special handling for JSON strings that contain placeholders (connectors action)
382+
replacement = Matcher.quoteReplacement(replacement.replace("\"", "\\\""));
383+
value = ((String) value).replaceAll(regex, replacement);
384+
}
371385
}
372386
}
373387
return value;
@@ -380,9 +394,20 @@ private static Object conditionallySubstitute(Object value, Map<String, Workflow
380394
* @throws JsonProcessingException JsonProcessingException from Jackson for issues processing map
381395
*/
382396
public static String parseArbitraryStringToObjectMapToString(Map<String, Object> map) throws JsonProcessingException {
383-
ObjectMapper mapper = new ObjectMapper();
384397
// Convert the map to a JSON string
385398
String mappedString = mapper.writeValueAsString(map);
386399
return mappedString;
387400
}
401+
402+
/**
403+
* Generates a String to String map based on a Json File
404+
* @param path file path
405+
* @return instance of the string
406+
* @throws JsonProcessingException JsonProcessingException from Jackson for issues processing map
407+
*/
408+
public static Map<String, String> parseJsonFileToStringToStringMap(String path) throws IOException {
409+
String jsonContent = resourceToString(path);
410+
Map<String, String> mappedJsonFile = mapper.readValue(jsonContent, Map.class);
411+
return mappedJsonFile;
412+
}
388413
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"template.name": "deploy-cohere-model",
3+
"template.description": "deploying cohere embedding model",
4+
"create_connector.name": "cohere-embedding-connector",
5+
"create_connector.description": "The connector to Cohere's public embed API",
6+
"create_connector.protocol": "http",
7+
"create_connector.model": "embed-english-v3.0",
8+
"create_connector.input_type": "search_document",
9+
"create_connector.truncate": "end",
10+
"create_connector.endpoint": "api.cohere.ai",
11+
"create_connector.credential.key": "123",
12+
"create_connector.actions.url": "https://api.cohere.ai/v1/embed",
13+
"create_connector.actions.request_body": "{ \"texts\": ${parameters.texts}, \"truncate\": \"${parameters.truncate}\", \"model\": \"${parameters.model}\", \"input_type\": \"${parameters.input_type}\" }",
14+
"create_connector.actions.pre_process_function": "connector.pre_process.cohere.embedding",
15+
"create_connector.actions.post_process_function": "connector.post_process.cohere.embedding",
16+
"register_remote_model.name": "Cohere english embed model",
17+
"register_remote_model.description": "cohere-embedding-model"
18+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
{
2+
"template.name": "local-model-neural-sparse-search",
3+
"template.description": "setting up neural sparse search with local model",
4+
"register_local_sparse_encoding_model.name": "neural-sparse/opensearch-neural-sparse-tokenizer-v1-v2",
5+
"register_local_sparse_encoding_model.description": "This is a neural sparse tokenizer model: It tokenize input sentence into tokens and assign pre-defined weight from IDF to each. It serves only in query.",
6+
"register_local_sparse_encoding_model.node_timeout": "60s",
7+
"register_local_sparse_encoding_model.model_format": "TORCH_SCRIPT",
8+
"register_local_sparse_encoding_model.function_name": "SPARSE_TOKENIZE",
9+
"register_local_sparse_encoding_model.model_content_hash_value": "b3487da9c58ac90541b720f3b367084f271d280c7f3bdc3e6d9c9a269fb31950",
10+
"register_local_sparse_encoding_model.url": "https://artifacts.opensearch.org/models/ml-models/amazon/neural-sparse/opensearch-neural-sparse-tokenizer-v1/1.0.0/torch_script/opensearch-neural-sparse-tokenizer-v1-1.0.0.zip",
11+
"register_local_sparse_encoding_model.deploy": "true",
12+
"create_ingest_pipeline.pipeline_id": "nlp-ingest-pipeline-sparse",
13+
"create_ingest_pipeline.description": "A sparse encoding ingest pipeline",
14+
"create_ingest_pipeline.text_embedding.field_map.input": "passage_text",
15+
"create_ingest_pipeline.text_embedding.field_map.output": "passage_embedding",
16+
"create_index.name": "my-nlp-index"
17+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"open_ai_embedding_deploy": {
3+
"template.name": "deploy-openai-model",
4+
"template.description": "deploying openAI embedding model",
5+
"create_connector.name": "OpenAI-embedding-connector",
6+
"create_connector.description": "Connector to public OpenAI model",
7+
"create_connector.protocol": "http",
8+
"create_connector.model": "text-embedding-ada-002",
9+
"create_connector.endpoint": "api.openai.com",
10+
"create_connector.credential.key": "123",
11+
"create_connector.actions.url": "https://api.openai.com/v1/embeddings",
12+
"create_connector.actions.request_body": "{ \"input\": ${parameters.input}, \"model\": \"${parameters.model}\" }",
13+
"create_connector.actions.pre_process_function": "connector.pre_process.openai.embedding",
14+
"create_connector.actions.post_process_function": "connector.post_process.openai.embedding",
15+
"register_remote_model_1.name": "OpenAI embedding model",
16+
"register_remote_model_1.description": "openai-embedding-model"
17+
}
18+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
{
2+
"name": "${{template.name}}",
3+
"description": "${{template.description}}",
4+
"use_case": "DEPLOY_MODEL",
5+
"version": {
6+
"template": "1.0.0",
7+
"compatibility": [
8+
"2.12.0",
9+
"3.0.0"
10+
]
11+
},
12+
"workflows": {
13+
"provision": {
14+
"nodes": [
15+
{
16+
"id": "create_connector",
17+
"type": "create_connector",
18+
"user_inputs": {
19+
"name": "${{create_connector_1}}",
20+
"description": "${{create_connector_1.description}}",
21+
"version": "1",
22+
"protocol": "${{create_connector_1.protocol}}",
23+
"parameters": {
24+
"endpoint": "${{create_connector_1.endpoint}}",
25+
"model": "${{create_connector_1.model}}"
26+
},
27+
"credential": {
28+
"key": "${{create_connector_1.credential.key}}"
29+
},
30+
"actions": [
31+
{
32+
"action_type": "predict",
33+
"method": "POST",
34+
"url": "https://api.openai.com/v1/embeddings",
35+
"headers": {
36+
"Authorization": "Bearer ${credential.openAI_key}"
37+
},
38+
"request_body": "{ \"input\": ${parameters.input}, \"model\": \"${parameters.model}\" }",
39+
"pre_process_function": "connector.pre_process.openai.embedding",
40+
"post_process_function": "connector.post_process.openai.embedding"
41+
}
42+
]
43+
}
44+
},
45+
{
46+
"id": "register_model",
47+
"type": "register_remote_model",
48+
"previous_node_inputs": {
49+
"create_connector_step_1": "parameters"
50+
},
51+
"user_inputs": {
52+
"name": "${register_remote_model.name}",
53+
"function_name": "remote",
54+
"description": "${register_remote_model.description}"
55+
}
56+
},
57+
{
58+
"id": "deploy_model",
59+
"type": "deploy_model",
60+
"previous_node_inputs": {
61+
"register_model_1": "model_id"
62+
}
63+
}
64+
],
65+
"edges": [
66+
{
67+
"source": "create_connector",
68+
"dest": "register_model"
69+
},
70+
{
71+
"source": "register_model",
72+
"dest": "deploy_model"
73+
}
74+
]
75+
}
76+
}
77+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
{
2+
"deploy-remote-model-defaults": [
3+
{
4+
"openai_embedding_deploy": {
5+
"template.name": "deploy-openai-model",
6+
"template.description": "deploying openAI embedding model",
7+
"create_connector_1.name": "OpenAI-embedding-connector",
8+
"create_connector_1.description": "Connector to public AI model service for GPT 3.5",
9+
"create_connector_1.protocol": "http",
10+
"create_connector_1.model": "gpt-3.5-turbo",
11+
"create_connector_1.endpoint": "api.openai.com",
12+
"create_connector_1.credential.key": "123",
13+
"create_connector_1.request_body": "{ \"input\": ${parameters.input}, \"model\": \"${parameters.model}\" }",
14+
"create_connector_1.pre_process_function": "connector.pre_process.openai.embedding",
15+
"create_connector_1.post_process_function": "connector.post_process.openai.embedding",
16+
"register_remote_model_1.name": "test-description"
17+
}
18+
},
19+
{
20+
"cohere_embedding_deploy": {
21+
"template.name": "deploy-cohere-embedding-model",
22+
"template.description": "deploying cohere embedding model",
23+
"create_connector_1.name": "cohere-embedding-connector",
24+
"create_connector_1.description": "Connector to public AI model service for GPT 3.5",
25+
"create_connector_1.protocol": "http",
26+
"create_connector_1.model": "gpt-3.5-turbo",
27+
"create_connector_1.endpoint": "api.openai.com",
28+
"create_connector_1.credential.key": "123",
29+
"create_connector_1.request_body": "{ \"input\": ${parameters.input}, \"model\": \"${parameters.model}\" }",
30+
"create_connector_1.pre_process_function": "connector.pre_process.openai.embedding",
31+
"create_connector_1.post_process_function": "connector.post_process.openai.embedding",
32+
"register_remote_model_1.name": "test-description"
33+
}
34+
}
35+
]
36+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
{
2+
"name": "${{template.name}}",
3+
"description": "${{template.description}}",
4+
"use_case": "DEPLOY_MODEL",
5+
"version": {
6+
"template": "1.0.0",
7+
"compatibility": [
8+
"2.12.0",
9+
"3.0.0"
10+
]
11+
},
12+
"workflows": {
13+
"provision": {
14+
"nodes": [
15+
{
16+
"id": "create_connector",
17+
"type": "create_connector",
18+
"user_inputs": {
19+
"name": "${{create_connector_1}}",
20+
"description": "${{create_connector_1.description}}",
21+
"version": "1",
22+
"protocol": "${{create_connector_1.protocol}}",
23+
"parameters": {
24+
"endpoint": "${{create_connector_1.endpoint}}",
25+
"model": "${{create_connector_1.model}}",
26+
"input_type": "search_document",
27+
"truncate": "END"
28+
},
29+
"credential": {
30+
"key": "${{create_connector_1.credential.key}}"
31+
},
32+
"actions": [
33+
{
34+
"action_type": "predict",
35+
"method": "POST",
36+
"url": "${{create_connector.actions.url}}",
37+
"headers": {
38+
"Authorization": "Bearer ${credential.key}",
39+
"Request-Source": "unspecified:opensearch"
40+
},
41+
"request_body": "${{create_connector.actions.request_body}}",
42+
"pre_process_function": "${{create_connector.actions.pre_process_function}}",
43+
"post_process_function": "${{create_connector.actions.post_process_function}}"
44+
}
45+
]
46+
}
47+
},
48+
{
49+
"id": "register_model",
50+
"type": "register_remote_model",
51+
"previous_node_inputs": {
52+
"create_connector_step_1": "parameters"
53+
},
54+
"user_inputs": {
55+
"name": "${register_remote_model.name}",
56+
"function_name": "remote",
57+
"description": "${register_remote_model.description}"
58+
}
59+
},
60+
{
61+
"id": "deploy_model",
62+
"type": "deploy_model",
63+
"previous_node_inputs": {
64+
"register_model_1": "model_id"
65+
}
66+
},
67+
{
68+
"id": "create_ingest_pipeline",
69+
"type": "create_ingest_pipeline",
70+
"previous_node_inputs": {
71+
"deploy_openai_model": "model_id"
72+
},
73+
"user_inputs": {
74+
"pipeline_id": "${{create_ingest_pipeline.pipeline_id}}",
75+
"configurations": {
76+
"description": "${{create_ingest_pipeline.description}}",
77+
"processors": [
78+
{
79+
"text_embedding": {
80+
"model_id": "${{deploy_openai_model.model_id}}",
81+
"field_map": {
82+
"${{text_embedding.field_map.input}}": "${{text_embedding.field_map.input}}"
83+
}
84+
}
85+
}
86+
]
87+
}
88+
}
89+
},
90+
{
91+
"id": "create_index",
92+
"type": "create_index",
93+
"previous_node_inputs": {
94+
"create_ingest_pipeline": "pipeline_id"
95+
},
96+
"user_inputs": {
97+
"index_name": "${{create_index.name}}",
98+
"configurations": {
99+
"settings": {
100+
"index": {
101+
"number_of_shards": 2,
102+
"number_of_replicas": 1,
103+
"search.default_pipeline" : "${{create_ingest_pipeline.pipeline_id}}"
104+
}
105+
},
106+
"mappings": {
107+
"_doc": {
108+
"properties": {
109+
"age": {
110+
"type": "integer"
111+
}
112+
}
113+
}
114+
},
115+
"aliases": {
116+
"sample-alias1": {}
117+
}
118+
}
119+
}
120+
}
121+
]
122+
}
123+
}
124+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
{
2+
"name": "${{template.name}}",
3+
"description": "${{template.description}}",
4+
"use_case": "",
5+
"version": {
6+
"template": "1.0.0",
7+
"compatibility": [
8+
"2.12.0",
9+
"3.0.0"
10+
]
11+
},
12+
"workflows": {
13+
"provision": {
14+
"nodes": [
15+
{
16+
"id": "create_connector",
17+
"type": "create_connector",
18+
"user_inputs": {
19+
"name": "${{create_connector.name}}",
20+
"description": "${{create_connector.description}}",
21+
"version": "1",
22+
"protocol": "${{create_connector.protocol}}",
23+
"parameters": {
24+
"endpoint": "${{create_connector.endpoint}}",
25+
"model": "${{create_connector.model}}",
26+
"input_type": "search_document",
27+
"truncate": "END"
28+
},
29+
"credential": {
30+
"key": "${{create_connector.credential.key}}"
31+
},
32+
"actions": [
33+
{
34+
"action_type": "predict",
35+
"method": "POST",
36+
"url": "${{create_connector.actions.url}}",
37+
"headers": {
38+
"Authorization": "Bearer ${credential.key}",
39+
"Request-Source": "unspecified:opensearch"
40+
},
41+
"request_body": "${{create_connector.actions.request_body}}",
42+
"pre_process_function": "${{create_connector.actions.pre_process_function}}",
43+
"post_process_function": "${{create_connector.actions.post_process_function}}"
44+
}
45+
]
46+
}
47+
},
48+
{
49+
"id": "register_model",
50+
"type": "register_remote_model",
51+
"previous_node_inputs": {
52+
"create_connector": "parameters"
53+
},
54+
"user_inputs": {
55+
"name": "${{register_remote_model.name}}",
56+
"function_name": "remote",
57+
"description": "${{register_remote_model.description}}"
58+
}
59+
},
60+
{
61+
"id": "deploy_model",
62+
"type": "deploy_model",
63+
"previous_node_inputs": {
64+
"register_model": "model_id"
65+
}
66+
}
67+
],
68+
"edges": [
69+
{
70+
"source": "create_connector",
71+
"dest": "register_model"
72+
},
73+
{
74+
"source": "register_model",
75+
"dest": "deploy_model"
76+
}
77+
]
78+
}
79+
}
80+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
{
2+
"name": "${{template.name}}",
3+
"description": "${{template.description}}",
4+
"use_case": "",
5+
"version": {
6+
"template": "1.0.0",
7+
"compatibility": [
8+
"2.12.0",
9+
"3.0.0"
10+
]
11+
},
12+
"workflows": {
13+
"provision": {
14+
"nodes": [
15+
{
16+
"id": "create_connector",
17+
"type": "create_connector",
18+
"user_inputs": {
19+
"name": "${{create_connector.name}}",
20+
"description": "${{create_connector.description}}",
21+
"version": "1",
22+
"protocol": "${{create_connector.protocol}}",
23+
"parameters": {
24+
"endpoint": "${{create_connector.endpoint}}",
25+
"model": "${{create_connector.model}}"
26+
},
27+
"credential": {
28+
"key": "${{create_connector.credential.key}}"
29+
},
30+
"actions": [
31+
{
32+
"action_type": "predict",
33+
"method": "POST",
34+
"url": "${{create_connector.actions.url}}",
35+
"headers": {
36+
"Authorization": "Bearer ${credential.key}"
37+
},
38+
"request_body": "${{create_connector.actions.request_body}}",
39+
"pre_process_function": "${{create_connector.actions.pre_process_function}}",
40+
"post_process_function": "${{create_connector.actions.post_process_function}}"
41+
}
42+
]
43+
}
44+
},
45+
{
46+
"id": "register_model",
47+
"type": "register_remote_model",
48+
"previous_node_inputs": {
49+
"create_connector": "parameters"
50+
},
51+
"user_inputs": {
52+
"name": "${{register_remote_model.name}}",
53+
"function_name": "remote",
54+
"description": "${{register_remote_model.description}}"
55+
}
56+
},
57+
{
58+
"id": "deploy_model",
59+
"type": "deploy_model",
60+
"previous_node_inputs": {
61+
"register_model": "model_id"
62+
}
63+
}
64+
],
65+
"edges": [
66+
{
67+
"source": "create_connector",
68+
"dest": "register_model"
69+
},
70+
{
71+
"source": "register_model",
72+
"dest": "deploy_model"
73+
}
74+
]
75+
}
76+
}
77+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
{
2+
"name": "${{template.name}}",
3+
"description": "${{template.description}}",
4+
"use_case": "",
5+
"version": {
6+
"template": "1.0.0",
7+
"compatibility": [
8+
"2.12.0",
9+
"3.0.0"
10+
]
11+
},
12+
"workflows": {
13+
"provision": {
14+
"nodes": [
15+
{
16+
"id": "register_local_sparse_encoding_model",
17+
"type": "register_local_sparse_encoding_model",
18+
"user_inputs": {
19+
"node_timeout": "60s",
20+
"name": "neural-sparse/opensearch-neural-sparse-tokenizer-v1-v2",
21+
"version": "1.0.0",
22+
"description": "This is a neural sparse tokenizer model: It tokenize input sentence into tokens and assign pre-defined weight from IDF to each. It serves only in query.",
23+
"model_format": "TORCH_SCRIPT",
24+
"function_name": "SPARSE_TOKENIZE",
25+
"model_content_hash_value": "b3487da9c58ac90541b720f3b367084f271d280c7f3bdc3e6d9c9a269fb31950",
26+
"url": "https://artifacts.opensearch.org/models/ml-models/amazon/neural-sparse/opensearch-neural-sparse-tokenizer-v1/1.0.0/torch_script/opensearch-neural-sparse-tokenizer-v1-1.0.0.zip",
27+
"deploy": true
28+
}
29+
},
30+
{
31+
"id": "create_ingest_pipeline",
32+
"type": "create_ingest_pipeline",
33+
"previous_node_inputs": {
34+
"register_local_sparse_encoding_model": "model_id"
35+
},
36+
"user_inputs": {
37+
"pipeline_id": "${{create_ingest_pipeline.pipeline_id}}",
38+
"configurations": {
39+
"description": "${{create_ingest_pipeline.description}}",
40+
"processors": [
41+
{
42+
"sparse_encoding": {
43+
"model_id": "${{register_local_sparse_encoding_model.model_id}}",
44+
"field_map": {
45+
"${{create_ingest_pipeline.text_embedding.field_map.input}}": "${{create_ingest_pipeline.text_embedding.field_map.output}}"
46+
}
47+
}
48+
}
49+
]
50+
}
51+
}
52+
},
53+
{
54+
"id": "create_index",
55+
"type": "create_index",
56+
"previous_node_inputs": {
57+
"create_ingest_pipeline": "pipeline_id"
58+
},
59+
"user_inputs": {
60+
"index_name": "${{create_index.name}}",
61+
"configurations": {
62+
"settings": {
63+
"default_pipeline": "${{create_ingest_pipeline.pipeline_id}}"
64+
},
65+
"mappings": {
66+
"_doc": {
67+
"properties": {
68+
"id": {
69+
"type": "text"
70+
},
71+
"${{create_ingest_pipeline.text_embedding.field_map.output}}": {
72+
"type": "rank_features"
73+
},
74+
"${{create_ingest_pipeline.text_embedding.field_map.input}}": {
75+
"type": "text"
76+
}
77+
}
78+
}
79+
}
80+
}
81+
}
82+
}
83+
]
84+
}
85+
}
86+
}

‎src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java

+18
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,24 @@ protected Response createWorkflow(RestClient client, Template template) throws E
329329
return TestHelpers.makeRequest(client, "POST", WORKFLOW_URI + "?validation=off", Collections.emptyMap(), template.toJson(), null);
330330
}
331331

332+
/**
333+
* Helper method to invoke the Create Workflow Rest Action without validation
334+
* @param client the rest client
335+
* @param useCase the usecase to create
336+
* @throws Exception if the request fails
337+
* @return a rest response
338+
*/
339+
protected Response createWorkflowWithUseCase(RestClient client, String useCase) throws Exception {
340+
return TestHelpers.makeRequest(
341+
client,
342+
"POST",
343+
WORKFLOW_URI + "?validation=off&use_case=" + useCase,
344+
Collections.emptyMap(),
345+
"{}",
346+
null
347+
);
348+
}
349+
332350
/**
333351
* Helper method to invoke the Create Workflow Rest Action with provision
334352
* @param client the rest client
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* The OpenSearch Contributors require contributions made to
6+
* this file be licensed under the Apache-2.0 license or a
7+
* compatible open source license.
8+
*/
9+
package org.opensearch.flowframework.common;
10+
11+
import org.opensearch.flowframework.exception.FlowFrameworkException;
12+
import org.opensearch.test.OpenSearchTestCase;
13+
14+
public class DefaultUseCasesTests extends OpenSearchTestCase {
15+
16+
@Override
17+
public void setUp() throws Exception {
18+
super.setUp();
19+
}
20+
21+
public void testGetDefaultsFileByValidUseCaseName() throws FlowFrameworkException {
22+
String defaultsFile = DefaultUseCases.getDefaultsFileByUseCaseName("open_ai_embedding_model_deploy");
23+
assertEquals("defaults/open-ai-embedding-defaults.json", defaultsFile);
24+
}
25+
26+
public void testGetDefaultsFileByInvalidUseCaseName() throws FlowFrameworkException {
27+
FlowFrameworkException e = assertThrows(
28+
FlowFrameworkException.class,
29+
() -> DefaultUseCases.getDefaultsFileByUseCaseName("invalid_use_case")
30+
);
31+
}
32+
33+
public void testGetSubstitutionTemplateByValidUseCaseName() throws FlowFrameworkException {
34+
String templateFile = DefaultUseCases.getSubstitutionReadyFileByUseCaseName("open_ai_embedding_model_deploy");
35+
assertEquals("substitutionTemplates/deploy-remote-model-template.json", templateFile);
36+
}
37+
38+
public void testGetSubstitutionTemplateByInvalidUseCaseName() throws FlowFrameworkException {
39+
FlowFrameworkException e = assertThrows(
40+
FlowFrameworkException.class,
41+
() -> DefaultUseCases.getSubstitutionReadyFileByUseCaseName("invalid_use_case")
42+
);
43+
}
44+
}

‎src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java

+41
Original file line numberDiff line numberDiff line change
@@ -399,4 +399,45 @@ public void testCreateAndProvisionIngestAndSearchPipeline() throws Exception {
399399

400400
}
401401

402+
public void testDefaultCohereUseCase() throws Exception {
403+
404+
// Using a 3 step template to create a connector, register remote model and deploy model
405+
Template template = TestHelpers.createTemplateFromFile("ingest-search-pipeline-template.json");
406+
407+
// Hit Create Workflow API with original template
408+
Response response = createWorkflowWithUseCase(client(), "cohere-embedding_model_deploy");
409+
assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response));
410+
411+
Map<String, Object> responseMap = entityAsMap(response);
412+
String workflowId = (String) responseMap.get(WORKFLOW_ID);
413+
getAndAssertWorkflowStatus(client(), workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED);
414+
415+
// Ensure Ml config index is initialized as creating a connector requires this, then hit Provision API and assert status
416+
if (!indexExistsWithAdminClient(".plugins-ml-config")) {
417+
assertBusy(() -> assertTrue(indexExistsWithAdminClient(".plugins-ml-config")), 40, TimeUnit.SECONDS);
418+
response = provisionWorkflow(client(), workflowId);
419+
} else {
420+
response = provisionWorkflow(client(), workflowId);
421+
}
422+
423+
assertEquals(RestStatus.OK, TestHelpers.restStatus(response));
424+
getAndAssertWorkflowStatus(client(), workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS);
425+
426+
// Wait until provisioning has completed successfully before attempting to retrieve created resources
427+
List<ResourceCreated> resourcesCreated = getResourcesCreated(client(), workflowId, 30);
428+
429+
List<String> expectedStepNames = List.of("create_connector", "register_remote_model", "deploy_model");
430+
431+
List workflowStepNames = resourcesCreated.stream()
432+
.peek(resourceCreated -> assertNotNull(resourceCreated.resourceId()))
433+
.map(ResourceCreated::workflowStepName)
434+
.collect(Collectors.toList());
435+
for (String expectedName : expectedStepNames) {
436+
assertTrue(workflowStepNames.contains(expectedName));
437+
}
438+
439+
// This template should create 3 resources: connector_id, registered model_id and deployed model_id
440+
assertEquals(3, resourcesCreated.size());
441+
}
442+
402443
}

‎src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java

+36
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.opensearch.core.rest.RestStatus;
1616
import org.opensearch.core.xcontent.MediaTypeRegistry;
1717
import org.opensearch.flowframework.TestHelpers;
18+
import org.opensearch.flowframework.common.DefaultUseCases;
1819
import org.opensearch.flowframework.common.FlowFrameworkSettings;
1920
import org.opensearch.flowframework.model.Template;
2021
import org.opensearch.flowframework.model.Workflow;
@@ -34,6 +35,7 @@
3435
import java.util.Map;
3536

3637
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
38+
import static org.opensearch.flowframework.common.CommonValue.USE_CASE;
3739
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI;
3840
import static org.mockito.ArgumentMatchers.any;
3941
import static org.mockito.Mockito.doAnswer;
@@ -134,6 +136,40 @@ public void testCreateWorkflowRequestWithParamsButNoProvision() throws Exception
134136
);
135137
}
136138

139+
public void testCreateWorkflowRequestWithUseCaseButNoProvision() throws Exception {
140+
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
141+
.withPath(this.createWorkflowPath)
142+
.withParams(Map.of(USE_CASE, DefaultUseCases.COHERE_EMBEDDING_MODEL_DEPLOY.getUseCaseName()))
143+
.withContent(new BytesArray(""), MediaTypeRegistry.JSON)
144+
.build();
145+
FakeRestChannel channel = new FakeRestChannel(request, false, 1);
146+
doAnswer(invocation -> {
147+
ActionListener<WorkflowResponse> actionListener = invocation.getArgument(2);
148+
actionListener.onResponse(new WorkflowResponse("id-123"));
149+
return null;
150+
}).when(nodeClient).execute(any(), any(WorkflowRequest.class), any());
151+
createWorkflowRestAction.handleRequest(request, channel, nodeClient);
152+
assertEquals(RestStatus.CREATED, channel.capturedResponse().status());
153+
assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123"));
154+
}
155+
156+
public void testCreateWorkflowRequestWithUseCaseAndContent() throws Exception {
157+
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
158+
.withPath(this.createWorkflowPath)
159+
.withParams(Map.of(USE_CASE, DefaultUseCases.COHERE_EMBEDDING_MODEL_DEPLOY.getUseCaseName()))
160+
.withContent(new BytesArray("{\"key\":\"step\"}"), MediaTypeRegistry.JSON)
161+
.build();
162+
FakeRestChannel channel = new FakeRestChannel(request, false, 1);
163+
doAnswer(invocation -> {
164+
ActionListener<WorkflowResponse> actionListener = invocation.getArgument(2);
165+
actionListener.onResponse(new WorkflowResponse("id-123"));
166+
return null;
167+
}).when(nodeClient).execute(any(), any(WorkflowRequest.class), any());
168+
createWorkflowRestAction.handleRequest(request, channel, nodeClient);
169+
assertEquals(RestStatus.CREATED, channel.capturedResponse().status());
170+
assertTrue(channel.capturedResponse().content().utf8ToString().contains("id-123"));
171+
}
172+
137173
public void testInvalidCreateWorkflowRequest() throws Exception {
138174
RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
139175
.withPath(this.createWorkflowPath)

‎src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java

+37-5
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ public void testMaxWorkflow() {
211211

212212
@SuppressWarnings("unchecked")
213213
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
214-
WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap());
214+
WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap(), null, Collections.emptyMap());
215215

216216
doAnswer(invocation -> {
217217
ActionListener<SearchResponse> searchListener = invocation.getArgument(1);
@@ -248,7 +248,15 @@ public void onFailure(Exception e) {
248248
public void testFailedToCreateNewWorkflow() {
249249
@SuppressWarnings("unchecked")
250250
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
251-
WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap());
251+
WorkflowRequest workflowRequest = new WorkflowRequest(
252+
null,
253+
template,
254+
new String[] { "off" },
255+
false,
256+
Collections.emptyMap(),
257+
null,
258+
Collections.emptyMap()
259+
);
252260

253261
// Bypass checkMaxWorkflows and force onResponse
254262
doAnswer(invocation -> {
@@ -279,7 +287,15 @@ public void testFailedToCreateNewWorkflow() {
279287
public void testCreateNewWorkflow() {
280288
@SuppressWarnings("unchecked")
281289
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
282-
WorkflowRequest workflowRequest = new WorkflowRequest(null, template, new String[] { "off" }, false, Collections.emptyMap());
290+
WorkflowRequest workflowRequest = new WorkflowRequest(
291+
null,
292+
template,
293+
new String[] { "off" },
294+
false,
295+
Collections.emptyMap(),
296+
null,
297+
Collections.emptyMap()
298+
);
283299

284300
// Bypass checkMaxWorkflows and force onResponse
285301
doAnswer(invocation -> {
@@ -410,7 +426,15 @@ public void testCreateWorkflow_withValidation_withProvision_Success() throws Exc
410426
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
411427

412428
doNothing().when(workflowProcessSorter).validate(any(), any());
413-
WorkflowRequest workflowRequest = new WorkflowRequest(null, validTemplate, new String[] { "all" }, true, Collections.emptyMap());
429+
WorkflowRequest workflowRequest = new WorkflowRequest(
430+
null,
431+
validTemplate,
432+
new String[] { "all" },
433+
true,
434+
Collections.emptyMap(),
435+
null,
436+
Collections.emptyMap()
437+
);
414438

415439
// Bypass checkMaxWorkflows and force onResponse
416440
doAnswer(invocation -> {
@@ -463,7 +487,15 @@ public void testCreateWorkflow_withValidation_withProvision_FailedProvisioning()
463487
@SuppressWarnings("unchecked")
464488
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
465489
doNothing().when(workflowProcessSorter).validate(any(), any());
466-
WorkflowRequest workflowRequest = new WorkflowRequest(null, validTemplate, new String[] { "all" }, true, Collections.emptyMap());
490+
WorkflowRequest workflowRequest = new WorkflowRequest(
491+
null,
492+
validTemplate,
493+
new String[] { "all" },
494+
true,
495+
Collections.emptyMap(),
496+
null,
497+
Collections.emptyMap()
498+
);
467499

468500
// Bypass checkMaxWorkflows and force onResponse
469501
doAnswer(invocation -> {

‎src/test/java/org/opensearch/flowframework/transport/WorkflowRequestResponseTests.java

+46-1
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,55 @@ public void testWorkflowRequestWithParams() throws IOException {
143143
assertEquals("bar", workflowRequest.getParams().get("foo"));
144144
}
145145

146+
public void testWorkflowRequestWithUseCase() throws IOException {
147+
WorkflowRequest workflowRequest = new WorkflowRequest("123", template, "cohere-embedding_model_deploy", Collections.emptyMap());
148+
assertNotNull(workflowRequest.getWorkflowId());
149+
assertEquals(template, workflowRequest.getTemplate());
150+
assertNull(workflowRequest.validate());
151+
assertFalse(workflowRequest.isProvision());
152+
assertTrue(workflowRequest.getDefaultParams().isEmpty());
153+
assertEquals(workflowRequest.getUseCase(), "cohere-embedding_model_deploy");
154+
155+
BytesStreamOutput out = new BytesStreamOutput();
156+
workflowRequest.writeTo(out);
157+
BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()));
158+
159+
WorkflowRequest streamInputRequest = new WorkflowRequest(in);
160+
161+
assertEquals(workflowRequest.getWorkflowId(), streamInputRequest.getWorkflowId());
162+
assertEquals(workflowRequest.getTemplate().toString(), streamInputRequest.getTemplate().toString());
163+
assertNull(workflowRequest.validate());
164+
assertFalse(workflowRequest.isProvision());
165+
assertTrue(workflowRequest.getDefaultParams().isEmpty());
166+
assertEquals(workflowRequest.getUseCase(), "cohere-embedding_model_deploy");
167+
}
168+
169+
public void testWorkflowRequestWithUseCaseAndParamsInBody() throws IOException {
170+
WorkflowRequest workflowRequest = new WorkflowRequest("123", template, "cohere-embedding_model_deploy", Map.of("step", "model"));
171+
assertNotNull(workflowRequest.getWorkflowId());
172+
assertEquals(template, workflowRequest.getTemplate());
173+
assertNull(workflowRequest.validate());
174+
assertFalse(workflowRequest.isProvision());
175+
assertEquals(workflowRequest.getDefaultParams().get("step"), "model");
176+
177+
BytesStreamOutput out = new BytesStreamOutput();
178+
workflowRequest.writeTo(out);
179+
BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()));
180+
181+
WorkflowRequest streamInputRequest = new WorkflowRequest(in);
182+
183+
assertEquals(workflowRequest.getWorkflowId(), streamInputRequest.getWorkflowId());
184+
assertEquals(workflowRequest.getTemplate().toString(), streamInputRequest.getTemplate().toString());
185+
assertNull(workflowRequest.validate());
186+
assertFalse(workflowRequest.isProvision());
187+
assertEquals(workflowRequest.getDefaultParams().get("step"), "model");
188+
189+
}
190+
146191
public void testWorkflowRequestWithParamsNoProvision() throws IOException {
147192
IllegalArgumentException ex = assertThrows(
148193
IllegalArgumentException.class,
149-
() -> new WorkflowRequest("123", template, new String[] { "all" }, false, Map.of("foo", "bar"))
194+
() -> new WorkflowRequest("123", template, new String[] { "all" }, false, Map.of("foo", "bar"), null, Collections.emptyMap())
150195
);
151196
assertEquals("Params may only be included when provisioning.", ex.getMessage());
152197
}

‎src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java

+46
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import java.io.IOException;
2121
import java.time.Instant;
22+
import java.util.Collections;
23+
import java.util.HashMap;
2224
import java.util.List;
2325
import java.util.Map;
2426
import java.util.Set;
@@ -88,6 +90,50 @@ public void testParseArbitraryStringToObjectMapToString() throws IOException {
8890
assertEquals("{\"test-1\":{\"test-1\":\"test-1\"}}", parsedMap);
8991
}
9092

93+
public void testConditionallySubstituteWithNoPlaceholders() {
94+
String input = "This string has no placeholders";
95+
Map<String, WorkflowData> outputs = new HashMap<>();
96+
Map<String, String> params = new HashMap<>();
97+
98+
Object result = ParseUtils.conditionallySubstitute(input, outputs, params);
99+
100+
assertEquals("This string has no placeholders", result);
101+
}
102+
103+
public void testConditionallySubstituteWithUnmatchedPlaceholders() {
104+
String input = "This string has unmatched ${{placeholder}}";
105+
Map<String, WorkflowData> outputs = new HashMap<>();
106+
Map<String, String> params = new HashMap<>();
107+
108+
Object result = ParseUtils.conditionallySubstitute(input, outputs, params);
109+
110+
assertEquals("This string has unmatched ${{placeholder}}", result);
111+
}
112+
113+
public void testConditionallySubstituteWithOutputsSubstitution() {
114+
String input = "This string contains ${{node.step}}";
115+
Map<String, WorkflowData> outputs = new HashMap<>();
116+
Map<String, String> params = new HashMap<>();
117+
Map<String, Object> contents = new HashMap<>(Collections.emptyMap());
118+
contents.put("step", "model_id");
119+
WorkflowData data = new WorkflowData(contents, params, "test", "test");
120+
outputs.put("node", data);
121+
Object result = ParseUtils.conditionallySubstitute(input, outputs, params);
122+
assertEquals("This string contains model_id", result);
123+
}
124+
125+
public void testConditionallySubstituteWithParamsSubstitution() {
126+
String input = "This string contains ${{node}}";
127+
Map<String, WorkflowData> outputs = new HashMap<>();
128+
Map<String, String> params = new HashMap<>();
129+
params.put("node", "step");
130+
Map<String, Object> contents = new HashMap<>(Collections.emptyMap());
131+
WorkflowData data = new WorkflowData(contents, params, "test", "test");
132+
outputs.put("node", data);
133+
Object result = ParseUtils.conditionallySubstitute(input, outputs, params);
134+
assertEquals("This string contains step", result);
135+
}
136+
91137
public void testGetInputsFromPreviousSteps() {
92138
WorkflowData currentNodeInputs = new WorkflowData(
93139
Map.ofEntries(

0 commit comments

Comments
 (0)
Please sign in to comment.