Skip to content

Commit 4e5455c

Browse files
committed
[Enhancement] Enhance validation for create connector API
This change will address the second part of validation "pre and post processing function validation". Partially resolves opensearch-project#2993 Signed-off-by: Abdul Muneer Kolarkunnu <[email protected]>
1 parent 953f16e commit 4e5455c

File tree

2 files changed

+68
-178
lines changed

2 files changed

+68
-178
lines changed

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

+62-147
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,9 @@
66
package org.opensearch.ml.common.connector;
77

88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
9-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_BATCH_JOB_ARN;
10-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_EMBEDDING;
11-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_RERANK;
12-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_V2_EMBEDDING_BINARY;
13-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_V2_EMBEDDING_FLOAT;
14-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_EMBEDDING;
15-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_RERANK;
16-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_BINARY;
17-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_FLOAT32;
18-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_INT8;
19-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_UBINARY;
20-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_UINT8;
219
import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_EMBEDDING;
2210
import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_RERANK;
23-
import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING;
24-
import static org.opensearch.ml.common.connector.MLPreProcessFunction.IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT;
25-
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT;
26-
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT;
2711
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT;
28-
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
29-
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT;
30-
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT;
31-
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT;
3212
import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_DEFAULT_INPUT;
3313

3414
import java.io.IOException;
@@ -227,140 +207,75 @@ public void validatePrePostProcessFunctions(Map<String, String> parameters) {
227207
}
228208

229209
private void validatePreProcessFunctions(String remoteServer) {
230-
if (isInBuiltFunction(preProcessFunction)) {
231-
switch (remoteServer) {
232-
case OPENAI:
233-
if (!TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT.equals(preProcessFunction)) {
234-
throw new IllegalArgumentException(
235-
"LLM service is " + OPENAI + ", so PreProcessFunction should be " + TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT
236-
);
237-
}
238-
break;
239-
case COHERE:
240-
if (!(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT.equals(preProcessFunction)
241-
|| IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT.equals(preProcessFunction)
242-
|| TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT.equals(preProcessFunction))) {
243-
throw new IllegalArgumentException(
244-
"LLM service is "
245-
+ COHERE
246-
+ ", so PreProcessFunction should be "
247-
+ TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT
248-
+ " or "
249-
+ IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT
250-
+ " or "
251-
+ TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT
252-
);
253-
}
254-
break;
255-
case BEDROCK:
256-
if (!(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT.equals(preProcessFunction)
257-
|| TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT.equals(preProcessFunction)
258-
|| TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT.equals(preProcessFunction))) {
259-
throw new IllegalArgumentException(
260-
"LLM service is "
261-
+ BEDROCK
262-
+ ", so PreProcessFunction should be "
263-
+ TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT
264-
+ " or "
265-
+ TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT
266-
+ " or "
267-
+ TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT
268-
);
269-
}
270-
break;
271-
case SAGEMAKER:
272-
if (!(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT.equals(preProcessFunction)
273-
|| TEXT_SIMILARITY_TO_DEFAULT_INPUT.equals(preProcessFunction))) {
274-
throw new IllegalArgumentException(
275-
"LLM service is "
276-
+ SAGEMAKER
277-
+ ", so PreProcessFunction should be "
278-
+ TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT
279-
+ " or "
280-
+ TEXT_SIMILARITY_TO_DEFAULT_INPUT
281-
);
282-
}
283-
}
210+
if (!isInBuiltProcessFunction(preProcessFunction)) {
211+
return;
212+
}
213+
switch (remoteServer) {
214+
case OPENAI:
215+
if (!preProcessFunction.contains(OPENAI)) {
216+
throw new IllegalArgumentException(invalidProcessFuncExcText(OPENAI, "PreProcessFunction"));
217+
}
218+
break;
219+
case COHERE:
220+
if (!preProcessFunction.contains(COHERE)) {
221+
throw new IllegalArgumentException(invalidProcessFuncExcText(COHERE, "PreProcessFunction"));
222+
}
223+
break;
224+
case BEDROCK:
225+
if (!preProcessFunction.contains(BEDROCK)) {
226+
throw new IllegalArgumentException(invalidProcessFuncExcText(BEDROCK, "PreProcessFunction"));
227+
}
228+
break;
229+
case SAGEMAKER:
230+
if (!(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT.equals(preProcessFunction)
231+
|| TEXT_SIMILARITY_TO_DEFAULT_INPUT.equals(preProcessFunction))) {
232+
throw new IllegalArgumentException(
233+
"LLM service is "
234+
+ SAGEMAKER
235+
+ ", so PreProcessFunction should be "
236+
+ TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT
237+
+ " or "
238+
+ TEXT_SIMILARITY_TO_DEFAULT_INPUT
239+
);
240+
}
284241
}
285242
}
286243

287244
private void validatePostProcessFunctions(String remoteServer) {
288-
if (isInBuiltFunction(postProcessFunction)) {
289-
switch (remoteServer) {
290-
case OPENAI:
291-
if (!OPENAI_EMBEDDING.equals(postProcessFunction)) {
292-
throw new IllegalArgumentException(
293-
"LLM service is " + OPENAI + ", so PostProcessFunction should be " + OPENAI_EMBEDDING
294-
);
295-
}
296-
break;
297-
case COHERE:
298-
if (!(COHERE_EMBEDDING.equals(postProcessFunction)
299-
|| COHERE_RERANK.equals(postProcessFunction)
300-
|| COHERE_V2_EMBEDDING_FLOAT32.equals(postProcessFunction)
301-
|| COHERE_V2_EMBEDDING_INT8.equals(postProcessFunction)
302-
|| COHERE_V2_EMBEDDING_UINT8.equals(postProcessFunction)
303-
|| COHERE_V2_EMBEDDING_BINARY.equals(postProcessFunction)
304-
|| COHERE_V2_EMBEDDING_UBINARY.equals(postProcessFunction))) {
305-
throw new IllegalArgumentException(
306-
"LLM service is "
307-
+ COHERE
308-
+ ", so PostProcessFunction should be "
309-
+ COHERE_EMBEDDING
310-
+ " or "
311-
+ COHERE_RERANK
312-
+ " or "
313-
+ COHERE_V2_EMBEDDING_FLOAT32
314-
+ " or "
315-
+ COHERE_V2_EMBEDDING_INT8
316-
+ " or "
317-
+ COHERE_V2_EMBEDDING_UINT8
318-
+ " or "
319-
+ COHERE_V2_EMBEDDING_BINARY
320-
+ " or "
321-
+ COHERE_V2_EMBEDDING_UBINARY
322-
);
323-
}
324-
break;
325-
case BEDROCK:
326-
if (!(BEDROCK_EMBEDDING.equals(postProcessFunction)
327-
|| BEDROCK_BATCH_JOB_ARN.equals(postProcessFunction)
328-
|| BEDROCK_RERANK.equals(postProcessFunction)
329-
|| BEDROCK_V2_EMBEDDING_FLOAT.equals(postProcessFunction)
330-
|| BEDROCK_V2_EMBEDDING_BINARY.equals(postProcessFunction))) {
331-
throw new IllegalArgumentException(
332-
"LLM service is "
333-
+ BEDROCK
334-
+ ", so PostProcessFunction should be "
335-
+ BEDROCK_EMBEDDING
336-
+ " or "
337-
+ BEDROCK_BATCH_JOB_ARN
338-
+ " or "
339-
+ BEDROCK_RERANK
340-
+ " or "
341-
+ BEDROCK_V2_EMBEDDING_FLOAT
342-
+ " or "
343-
+ BEDROCK_V2_EMBEDDING_BINARY
344-
);
345-
}
346-
break;
347-
case SAGEMAKER:
348-
if (!(DEFAULT_EMBEDDING.equals(postProcessFunction) || DEFAULT_RERANK.equals(postProcessFunction))) {
349-
throw new IllegalArgumentException(
350-
"LLM service is "
351-
+ SAGEMAKER
352-
+ ", so PostProcessFunction should be "
353-
+ DEFAULT_EMBEDDING
354-
+ " or "
355-
+ DEFAULT_RERANK
356-
);
357-
}
358-
}
245+
if (!isInBuiltProcessFunction(postProcessFunction)) {
246+
return;
359247
}
248+
switch (remoteServer) {
249+
case OPENAI:
250+
if (!postProcessFunction.contains(OPENAI)) {
251+
throw new IllegalArgumentException(invalidProcessFuncExcText(OPENAI, "PostProcessFunction"));
252+
}
253+
break;
254+
case COHERE:
255+
if (!postProcessFunction.contains(COHERE)) {
256+
throw new IllegalArgumentException(invalidProcessFuncExcText(COHERE, "PostProcessFunction"));
257+
}
258+
break;
259+
case BEDROCK:
260+
if (!postProcessFunction.contains(BEDROCK)) {
261+
throw new IllegalArgumentException(invalidProcessFuncExcText(BEDROCK, "PostProcessFunction"));
262+
}
263+
break;
264+
case SAGEMAKER:
265+
if (!(DEFAULT_EMBEDDING.equals(postProcessFunction) || DEFAULT_RERANK.equals(postProcessFunction))) {
266+
throw new IllegalArgumentException(
267+
"LLM service is " + SAGEMAKER + ", so PostProcessFunction should be " + DEFAULT_EMBEDDING + " or " + DEFAULT_RERANK
268+
);
269+
}
270+
}
271+
}
272+
273+
private String invalidProcessFuncExcText(String remoteServer, String func) {
274+
return "LLM service is " + remoteServer + ", so " + func + " should be " + remoteServer + " " + func;
360275
}
361276

362-
private boolean isInBuiltFunction(String function) {
363-
return (function != null && function.startsWith(INBUILT_FUNC_PREFIX));
277+
private boolean isInBuiltProcessFunction(String processFunction) {
278+
return (processFunction != null && processFunction.startsWith(INBUILT_FUNC_PREFIX));
364279
}
365280

366281
public static String getRemoteServerFromURL(String url) {

common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java

+6-31
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,7 @@ public void openAIConnectorWithWrongInBuiltPrePostProcessFunction() {
136136
OPENAI_EMBEDDING
137137
);
138138
Throwable exception = assertThrows(IllegalArgumentException.class, () -> action1.validatePrePostProcessFunctions(Map.of()));
139-
assertEquals(
140-
"LLM service is openai, so PreProcessFunction should be connector.pre_process.openai.embedding",
141-
exception.getMessage()
142-
);
139+
assertEquals("LLM service is openai, so PreProcessFunction should be openai PreProcessFunction", exception.getMessage());
143140
ConnectorAction action2 = new ConnectorAction(
144141
TEST_ACTION_TYPE,
145142
TEST_METHOD_HTTP,
@@ -150,10 +147,7 @@ public void openAIConnectorWithWrongInBuiltPrePostProcessFunction() {
150147
COHERE_EMBEDDING
151148
);
152149
exception = assertThrows(IllegalArgumentException.class, () -> action2.validatePrePostProcessFunctions(Map.of()));
153-
assertEquals(
154-
"LLM service is openai, so PostProcessFunction should be connector.post_process.openai.embedding",
155-
exception.getMessage()
156-
);
150+
assertEquals("LLM service is openai, so PostProcessFunction should be openai PostProcessFunction", exception.getMessage());
157151
}
158152

159153
@Test
@@ -205,11 +199,7 @@ public void cohereConnectorWithWrongInBuiltPrePostProcessFunction() {
205199
COHERE_EMBEDDING
206200
);
207201
Throwable exception = assertThrows(IllegalArgumentException.class, () -> action1.validatePrePostProcessFunctions(Map.of()));
208-
assertEquals(
209-
"LLM service is cohere, so PreProcessFunction should be connector.pre_process.cohere.embedding"
210-
+ " or connector.pre_process.cohere.multimodal_embedding or connector.pre_process.cohere.rerank",
211-
exception.getMessage()
212-
);
202+
assertEquals("LLM service is cohere, so PreProcessFunction should be cohere PreProcessFunction", exception.getMessage());
213203
ConnectorAction action2 = new ConnectorAction(
214204
TEST_ACTION_TYPE,
215205
TEST_METHOD_HTTP,
@@ -220,13 +210,7 @@ public void cohereConnectorWithWrongInBuiltPrePostProcessFunction() {
220210
OPENAI_EMBEDDING
221211
);
222212
exception = assertThrows(IllegalArgumentException.class, () -> action2.validatePrePostProcessFunctions(Map.of()));
223-
assertEquals(
224-
"LLM service is cohere, so PostProcessFunction should be connector.post_process.cohere.embedding"
225-
+ " or connector.post_process.cohere.rerank or connector.post_process.cohere_v2.embedding.float"
226-
+ " or connector.post_process.cohere_v2.embedding.int8 or connector.post_process.cohere_v2.embedding.uint8"
227-
+ " or connector.post_process.cohere_v2.embedding.binary or connector.post_process.cohere_v2.embedding.ubinary",
228-
exception.getMessage()
229-
);
213+
assertEquals("LLM service is cohere, so PostProcessFunction should be cohere PostProcessFunction", exception.getMessage());
230214
}
231215

232216
@Test
@@ -278,11 +262,7 @@ public void bedrockConnectorWithWrongInBuiltPrePostProcessFunction() {
278262
BEDROCK_EMBEDDING
279263
);
280264
Throwable exception = assertThrows(IllegalArgumentException.class, () -> action1.validatePrePostProcessFunctions(Map.of()));
281-
assertEquals(
282-
"LLM service is bedrock, so PreProcessFunction should be connector.pre_process.bedrock.embedding"
283-
+ " or connector.pre_process.bedrock.multimodal_embedding or connector.pre_process.bedrock.rerank",
284-
exception.getMessage()
285-
);
265+
assertEquals("LLM service is bedrock, so PreProcessFunction should be bedrock PreProcessFunction", exception.getMessage());
286266
ConnectorAction action2 = new ConnectorAction(
287267
TEST_ACTION_TYPE,
288268
TEST_METHOD_HTTP,
@@ -293,12 +273,7 @@ public void bedrockConnectorWithWrongInBuiltPrePostProcessFunction() {
293273
COHERE_EMBEDDING
294274
);
295275
exception = assertThrows(IllegalArgumentException.class, () -> action2.validatePrePostProcessFunctions(Map.of()));
296-
assertEquals(
297-
"LLM service is bedrock, so PostProcessFunction should be connector.post_process.bedrock.embedding"
298-
+ " or connector.post_process.bedrock.batch_job_arn or connector.post_process.bedrock.rerank"
299-
+ " or connector.post_process.bedrock_v2.embedding.float or connector.post_process.bedrock_v2.embedding.binary",
300-
exception.getMessage()
301-
);
276+
assertEquals("LLM service is bedrock, so PostProcessFunction should be bedrock PostProcessFunction", exception.getMessage());
302277
}
303278

304279
@Test

0 commit comments

Comments
 (0)