Skip to content

Commit 236fda2

Browse files
committed
[Enhancement] Enhance validation for create connector API
This change will address the second part of validation "pre and post processing function validation". Partially resolves opensearch-project#2993 Signed-off-by: Abdul Muneer Kolarkunnu <[email protected]>
1 parent 7df638e commit 236fda2

File tree

2 files changed

+42
-5
lines changed

2 files changed

+42
-5
lines changed

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

+37-3
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,15 @@
99
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_BATCH_JOB_ARN;
1010
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_EMBEDDING;
1111
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_RERANK;
12+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_V2_EMBEDDING_BINARY;
13+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_V2_EMBEDDING_FLOAT;
1214
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_EMBEDDING;
1315
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_RERANK;
16+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_BINARY;
17+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_FLOAT32;
18+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_INT8;
19+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_UBINARY;
20+
import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_UINT8;
1421
import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_EMBEDDING;
1522
import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_RERANK;
1623
import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING;
@@ -288,16 +295,39 @@ private void validatePostProcessFunctions(String remoteServer) {
288295
}
289296
break;
290297
case COHERE:
291-
if (!(COHERE_EMBEDDING.equals(postProcessFunction) || COHERE_RERANK.equals(postProcessFunction))) {
298+
if (!(COHERE_EMBEDDING.equals(postProcessFunction)
299+
|| COHERE_RERANK.equals(postProcessFunction)
300+
|| COHERE_V2_EMBEDDING_FLOAT32.equals(postProcessFunction)
301+
|| COHERE_V2_EMBEDDING_INT8.equals(postProcessFunction)
302+
|| COHERE_V2_EMBEDDING_UINT8.equals(postProcessFunction)
303+
|| COHERE_V2_EMBEDDING_BINARY.equals(postProcessFunction)
304+
|| COHERE_V2_EMBEDDING_UBINARY.equals(postProcessFunction))) {
292305
throw new IllegalArgumentException(
293-
"LLM service is " + COHERE + ", so PostProcessFunction should be " + COHERE_EMBEDDING + " or " + COHERE_RERANK
306+
"LLM service is "
307+
+ COHERE
308+
+ ", so PostProcessFunction should be "
309+
+ COHERE_EMBEDDING
310+
+ " or "
311+
+ COHERE_RERANK
312+
+ " or "
313+
+ COHERE_V2_EMBEDDING_FLOAT32
314+
+ " or "
315+
+ COHERE_V2_EMBEDDING_INT8
316+
+ " or "
317+
+ COHERE_V2_EMBEDDING_UINT8
318+
+ " or "
319+
+ COHERE_V2_EMBEDDING_BINARY
320+
+ " or "
321+
+ COHERE_V2_EMBEDDING_UBINARY
294322
);
295323
}
296324
break;
297325
case BEDROCK:
298326
if (!(BEDROCK_EMBEDDING.equals(postProcessFunction)
299327
|| BEDROCK_BATCH_JOB_ARN.equals(postProcessFunction)
300-
|| BEDROCK_RERANK.equals(postProcessFunction))) {
328+
|| BEDROCK_RERANK.equals(postProcessFunction)
329+
|| BEDROCK_V2_EMBEDDING_FLOAT.equals(postProcessFunction)
330+
|| BEDROCK_V2_EMBEDDING_BINARY.equals(postProcessFunction))) {
301331
throw new IllegalArgumentException(
302332
"LLM service is "
303333
+ BEDROCK
@@ -307,6 +337,10 @@ private void validatePostProcessFunctions(String remoteServer) {
307337
+ BEDROCK_BATCH_JOB_ARN
308338
+ " or "
309339
+ BEDROCK_RERANK
340+
+ " or "
341+
+ BEDROCK_V2_EMBEDDING_FLOAT
342+
+ " or "
343+
+ BEDROCK_V2_EMBEDDING_BINARY
310344
);
311345
}
312346
break;

common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java

+5-2
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,9 @@ public void cohereConnectorWithWrongInBuiltPrePostProcessFunction() {
222222
exception = assertThrows(IllegalArgumentException.class, () -> action2.validatePrePostProcessFunctions(Map.of()));
223223
assertEquals(
224224
"LLM service is cohere, so PostProcessFunction should be connector.post_process.cohere.embedding"
225-
+ " or connector.post_process.cohere.rerank",
225+
+ " or connector.post_process.cohere.rerank or connector.post_process.cohere_v2.embedding.float"
226+
+ " or connector.post_process.cohere_v2.embedding.int8 or connector.post_process.cohere_v2.embedding.uint8"
227+
+ " or connector.post_process.cohere_v2.embedding.binary or connector.post_process.cohere_v2.embedding.ubinary",
226228
exception.getMessage()
227229
);
228230
}
@@ -293,7 +295,8 @@ public void bedrockConnectorWithWrongInBuiltPrePostProcessFunction() {
293295
exception = assertThrows(IllegalArgumentException.class, () -> action2.validatePrePostProcessFunctions(Map.of()));
294296
assertEquals(
295297
"LLM service is bedrock, so PostProcessFunction should be connector.post_process.bedrock.embedding"
296-
+ " or connector.post_process.bedrock.batch_job_arn or connector.post_process.bedrock.rerank",
298+
+ " or connector.post_process.bedrock.batch_job_arn or connector.post_process.bedrock.rerank"
299+
+ " or connector.post_process.bedrock_v2.embedding.float or connector.post_process.bedrock_v2.embedding.binary",
297300
exception.getMessage()
298301
);
299302
}

0 commit comments

Comments
 (0)