|
6 | 6 | package org.opensearch.ml.common.connector;
|
7 | 7 |
|
8 | 8 | import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
|
9 |
| -import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_BATCH_JOB_ARN; |
10 |
| -import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_EMBEDDING; |
11 |
| -import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_RERANK; |
12 |
| -import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_V2_EMBEDDING_BINARY; |
13 |
| -import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_V2_EMBEDDING_FLOAT; |
14 |
| -import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_EMBEDDING; |
15 |
| -import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_RERANK; |
16 |
| -import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_BINARY; |
17 |
| -import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_FLOAT32; |
18 |
| -import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_INT8; |
19 |
| -import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_UBINARY; |
20 |
| -import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_V2_EMBEDDING_UINT8; |
21 | 9 | import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_EMBEDDING;
|
22 | 10 | import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_RERANK;
|
23 |
| -import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING; |
24 |
| -import static org.opensearch.ml.common.connector.MLPreProcessFunction.IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT; |
25 |
| -import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT; |
26 |
| -import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT; |
27 | 11 | import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT;
|
28 |
| -import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; |
29 |
| -import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT; |
30 |
| -import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT; |
31 |
| -import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT; |
32 | 12 | import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_SIMILARITY_TO_DEFAULT_INPUT;
|
33 | 13 |
|
34 | 14 | import java.io.IOException;
|
@@ -227,140 +207,75 @@ public void validatePrePostProcessFunctions(Map<String, String> parameters) {
|
227 | 207 | }
|
228 | 208 |
|
229 | 209 | private void validatePreProcessFunctions(String remoteServer) {
|
230 |
| - if (isInBuiltFunction(preProcessFunction)) { |
231 |
| - switch (remoteServer) { |
232 |
| - case OPENAI: |
233 |
| - if (!TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT.equals(preProcessFunction)) { |
234 |
| - throw new IllegalArgumentException( |
235 |
| - "LLM service is " + OPENAI + ", so PreProcessFunction should be " + TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT |
236 |
| - ); |
237 |
| - } |
238 |
| - break; |
239 |
| - case COHERE: |
240 |
| - if (!(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT.equals(preProcessFunction) |
241 |
| - || IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT.equals(preProcessFunction) |
242 |
| - || TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT.equals(preProcessFunction))) { |
243 |
| - throw new IllegalArgumentException( |
244 |
| - "LLM service is " |
245 |
| - + COHERE |
246 |
| - + ", so PreProcessFunction should be " |
247 |
| - + TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT |
248 |
| - + " or " |
249 |
| - + IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT |
250 |
| - + " or " |
251 |
| - + TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT |
252 |
| - ); |
253 |
| - } |
254 |
| - break; |
255 |
| - case BEDROCK: |
256 |
| - if (!(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT.equals(preProcessFunction) |
257 |
| - || TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT.equals(preProcessFunction) |
258 |
| - || TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT.equals(preProcessFunction))) { |
259 |
| - throw new IllegalArgumentException( |
260 |
| - "LLM service is " |
261 |
| - + BEDROCK |
262 |
| - + ", so PreProcessFunction should be " |
263 |
| - + TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT |
264 |
| - + " or " |
265 |
| - + TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT |
266 |
| - + " or " |
267 |
| - + TEXT_SIMILARITY_TO_BEDROCK_RERANK_INPUT |
268 |
| - ); |
269 |
| - } |
270 |
| - break; |
271 |
| - case SAGEMAKER: |
272 |
| - if (!(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT.equals(preProcessFunction) |
273 |
| - || TEXT_SIMILARITY_TO_DEFAULT_INPUT.equals(preProcessFunction))) { |
274 |
| - throw new IllegalArgumentException( |
275 |
| - "LLM service is " |
276 |
| - + SAGEMAKER |
277 |
| - + ", so PreProcessFunction should be " |
278 |
| - + TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT |
279 |
| - + " or " |
280 |
| - + TEXT_SIMILARITY_TO_DEFAULT_INPUT |
281 |
| - ); |
282 |
| - } |
283 |
| - } |
| 210 | + if (!isInBuiltProcessFunction(preProcessFunction)) { |
| 211 | + return; |
| 212 | + } |
| 213 | + switch (remoteServer) { |
| 214 | + case OPENAI: |
| 215 | + if (!preProcessFunction.contains(OPENAI)) { |
| 216 | + throw new IllegalArgumentException(invalidProcessFuncExcText(OPENAI, "PreProcessFunction")); |
| 217 | + } |
| 218 | + break; |
| 219 | + case COHERE: |
| 220 | + if (!preProcessFunction.contains(COHERE)) { |
| 221 | + throw new IllegalArgumentException(invalidProcessFuncExcText(COHERE, "PreProcessFunction")); |
| 222 | + } |
| 223 | + break; |
| 224 | + case BEDROCK: |
| 225 | + if (!preProcessFunction.contains(BEDROCK)) { |
| 226 | + throw new IllegalArgumentException(invalidProcessFuncExcText(BEDROCK, "PreProcessFunction")); |
| 227 | + } |
| 228 | + break; |
| 229 | + case SAGEMAKER: |
| 230 | + if (!(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT.equals(preProcessFunction) |
| 231 | + || TEXT_SIMILARITY_TO_DEFAULT_INPUT.equals(preProcessFunction))) { |
| 232 | + throw new IllegalArgumentException( |
| 233 | + "LLM service is " |
| 234 | + + SAGEMAKER |
| 235 | + + ", so PreProcessFunction should be " |
| 236 | + + TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT |
| 237 | + + " or " |
| 238 | + + TEXT_SIMILARITY_TO_DEFAULT_INPUT |
| 239 | + ); |
| 240 | + } |
284 | 241 | }
|
285 | 242 | }
|
286 | 243 |
|
287 | 244 | private void validatePostProcessFunctions(String remoteServer) {
|
288 |
| - if (isInBuiltFunction(postProcessFunction)) { |
289 |
| - switch (remoteServer) { |
290 |
| - case OPENAI: |
291 |
| - if (!OPENAI_EMBEDDING.equals(postProcessFunction)) { |
292 |
| - throw new IllegalArgumentException( |
293 |
| - "LLM service is " + OPENAI + ", so PostProcessFunction should be " + OPENAI_EMBEDDING |
294 |
| - ); |
295 |
| - } |
296 |
| - break; |
297 |
| - case COHERE: |
298 |
| - if (!(COHERE_EMBEDDING.equals(postProcessFunction) |
299 |
| - || COHERE_RERANK.equals(postProcessFunction) |
300 |
| - || COHERE_V2_EMBEDDING_FLOAT32.equals(postProcessFunction) |
301 |
| - || COHERE_V2_EMBEDDING_INT8.equals(postProcessFunction) |
302 |
| - || COHERE_V2_EMBEDDING_UINT8.equals(postProcessFunction) |
303 |
| - || COHERE_V2_EMBEDDING_BINARY.equals(postProcessFunction) |
304 |
| - || COHERE_V2_EMBEDDING_UBINARY.equals(postProcessFunction))) { |
305 |
| - throw new IllegalArgumentException( |
306 |
| - "LLM service is " |
307 |
| - + COHERE |
308 |
| - + ", so PostProcessFunction should be " |
309 |
| - + COHERE_EMBEDDING |
310 |
| - + " or " |
311 |
| - + COHERE_RERANK |
312 |
| - + " or " |
313 |
| - + COHERE_V2_EMBEDDING_FLOAT32 |
314 |
| - + " or " |
315 |
| - + COHERE_V2_EMBEDDING_INT8 |
316 |
| - + " or " |
317 |
| - + COHERE_V2_EMBEDDING_UINT8 |
318 |
| - + " or " |
319 |
| - + COHERE_V2_EMBEDDING_BINARY |
320 |
| - + " or " |
321 |
| - + COHERE_V2_EMBEDDING_UBINARY |
322 |
| - ); |
323 |
| - } |
324 |
| - break; |
325 |
| - case BEDROCK: |
326 |
| - if (!(BEDROCK_EMBEDDING.equals(postProcessFunction) |
327 |
| - || BEDROCK_BATCH_JOB_ARN.equals(postProcessFunction) |
328 |
| - || BEDROCK_RERANK.equals(postProcessFunction) |
329 |
| - || BEDROCK_V2_EMBEDDING_FLOAT.equals(postProcessFunction) |
330 |
| - || BEDROCK_V2_EMBEDDING_BINARY.equals(postProcessFunction))) { |
331 |
| - throw new IllegalArgumentException( |
332 |
| - "LLM service is " |
333 |
| - + BEDROCK |
334 |
| - + ", so PostProcessFunction should be " |
335 |
| - + BEDROCK_EMBEDDING |
336 |
| - + " or " |
337 |
| - + BEDROCK_BATCH_JOB_ARN |
338 |
| - + " or " |
339 |
| - + BEDROCK_RERANK |
340 |
| - + " or " |
341 |
| - + BEDROCK_V2_EMBEDDING_FLOAT |
342 |
| - + " or " |
343 |
| - + BEDROCK_V2_EMBEDDING_BINARY |
344 |
| - ); |
345 |
| - } |
346 |
| - break; |
347 |
| - case SAGEMAKER: |
348 |
| - if (!(DEFAULT_EMBEDDING.equals(postProcessFunction) || DEFAULT_RERANK.equals(postProcessFunction))) { |
349 |
| - throw new IllegalArgumentException( |
350 |
| - "LLM service is " |
351 |
| - + SAGEMAKER |
352 |
| - + ", so PostProcessFunction should be " |
353 |
| - + DEFAULT_EMBEDDING |
354 |
| - + " or " |
355 |
| - + DEFAULT_RERANK |
356 |
| - ); |
357 |
| - } |
358 |
| - } |
| 245 | + if (!isInBuiltProcessFunction(postProcessFunction)) { |
| 246 | + return; |
359 | 247 | }
|
| 248 | + switch (remoteServer) { |
| 249 | + case OPENAI: |
| 250 | + if (!postProcessFunction.contains(OPENAI)) { |
| 251 | + throw new IllegalArgumentException(invalidProcessFuncExcText(OPENAI, "PostProcessFunction")); |
| 252 | + } |
| 253 | + break; |
| 254 | + case COHERE: |
| 255 | + if (!postProcessFunction.contains(COHERE)) { |
| 256 | + throw new IllegalArgumentException(invalidProcessFuncExcText(COHERE, "PostProcessFunction")); |
| 257 | + } |
| 258 | + break; |
| 259 | + case BEDROCK: |
| 260 | + if (!postProcessFunction.contains(BEDROCK)) { |
| 261 | + throw new IllegalArgumentException(invalidProcessFuncExcText(BEDROCK, "PostProcessFunction")); |
| 262 | + } |
| 263 | + break; |
| 264 | + case SAGEMAKER: |
| 265 | + if (!(DEFAULT_EMBEDDING.equals(postProcessFunction) || DEFAULT_RERANK.equals(postProcessFunction))) { |
| 266 | + throw new IllegalArgumentException( |
| 267 | + "LLM service is " + SAGEMAKER + ", so PostProcessFunction should be " + DEFAULT_EMBEDDING + " or " + DEFAULT_RERANK |
| 268 | + ); |
| 269 | + } |
| 270 | + } |
| 271 | + } |
| 272 | + |
| 273 | + private String invalidProcessFuncExcText(String remoteServer, String func) { |
| 274 | + return "LLM service is " + remoteServer + ", so " + func + " should be " + remoteServer + " " + func; |
360 | 275 | }
|
361 | 276 |
|
362 |
| - private boolean isInBuiltFunction(String function) { |
363 |
| - return (function != null && function.startsWith(INBUILT_FUNC_PREFIX)); |
| 277 | + private boolean isInBuiltProcessFunction(String processFunction) { |
| 278 | + return (processFunction != null && processFunction.startsWith(INBUILT_FUNC_PREFIX)); |
364 | 279 | }
|
365 | 280 |
|
366 | 281 | public static String getRemoteServerFromURL(String url) {
|
|
0 commit comments