import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.utils.TestHelper;
+ import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
@@ -147,6 +148,35 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
    private static final String BEDROCK_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null
        ? BEDROCK_CONNECTOR_BLUEPRINT2
        : BEDROCK_CONNECTOR_BLUEPRINT1;
+ 
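+     // Connector blueprint for Cohere's public /v1/chat endpoint using the "command" model; the key is read from the COHERE_API_KEY environment variable.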
+     private static final String COHERE_API_KEY = System.getenv("COHERE_API_KEY");
+     private static final String COHERE_CONNECTOR_BLUEPRINT = "{\n"
+         + "  \"name\": \"Cohere Chat Model\",\n"
+         + "  \"description\": \"The connector to Cohere's public chat API\",\n"
+         + "  \"version\": \"1\",\n"
+         + "  \"protocol\": \"http\",\n"
+         + "  \"credential\": {\n"
+         + "    \"cohere_key\": \""
+         + COHERE_API_KEY
+         + "\"\n"
+         + "  },\n"
+         + "  \"parameters\": {\n"
+         + "    \"model\": \"command\"\n"
+         + "  },\n"
+         + "  \"actions\": [\n"
+         + "    {\n"
+         + "      \"action_type\": \"predict\",\n"
+         + "      \"method\": \"POST\",\n"
+         + "      \"url\": \"https://api.cohere.ai/v1/chat\",\n"
+         + "      \"headers\": {\n"
+         + "        \"Authorization\": \"Bearer ${credential.cohere_key}\",\n"
+         + "        \"Request-Source\": \"unspecified:opensearch\"\n"
+         + "      },\n"
+         + "      \"request_body\": \"{ \\\"message\\\": \\\"${parameters.inputs}\\\", \\\"model\\\": \\\"${parameters.model}\\\" }\"\n"
+         + "    }\n"
+         + "  ]\n"
+         + "}";
+ 
    private static final String PIPELINE_TEMPLATE = "{\n"
        + "  \"response_processors\": [\n"
        + "    {\n"
@@ -199,6 +229,23 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
        + "  }\n"
        + "}";

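+     // Same request shape as BM25_SEARCH_REQUEST_TEMPLATE, with an additional "llm_response_field" entry in generative_qa_parameters.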
+     private static final String BM25_SEARCH_REQUEST_WITH_LLM_RESPONSE_FIELD_TEMPLATE = "{\n"
+         + "  \"_source\": [\"%s\"],\n"
+         + "  \"query\": {\n"
+         + "    \"match\": {\"%s\": \"%s\"}\n"
+         + "  },\n"
+         + "  \"ext\": {\n"
+         + "    \"generative_qa_parameters\": {\n"
+         + "      \"llm_model\": \"%s\",\n"
+         + "      \"llm_question\": \"%s\",\n"
+         + "      \"system_prompt\": \"%s\",\n"
+         + "      \"user_instructions\": \"%s\",\n"
+         + "      \"context_size\": %d,\n"
+         + "      \"message_size\": %d,\n"
+         + "      \"timeout\": %d,\n"
+         + "      \"llm_response_field\": \"%s\"\n"
+         + "    }\n"
+         + "  }\n"
+         + "}";
+ 
    private static final String OPENAI_MODEL = "gpt-3.5-turbo";
    private static final String BEDROCK_ANTHROPIC_CLAUDE = "bedrock/anthropic-claude";
    private static final String TEST_DOC_PATH = "org/opensearch/ml/rest/test_data/";
@@ -472,6 +519,111 @@ public void testBM25WithBedrockWithConversation() throws Exception {
        assertNotNull(interactionId);
    }

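+     // Full RAG flow against Cohere: create the connector, register and deploy the remote model, create a search pipeline, then run a BM25 search through it. Skipped when COHERE_API_KEY is not set.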
+     public void testBM25WithCohere() throws Exception {
+         // Skip test if key is null
+         if (COHERE_API_KEY == null) {
+             return;
+         }
+         Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT);
+         Map responseMap = parseResponseToMap(response);
+         String connectorId = (String) responseMap.get("connector_id");
+         response = registerRemoteModel("Cohere Chat Completion v1", connectorId);
+         responseMap = parseResponseToMap(response);
+         String taskId = (String) responseMap.get("task_id");
+         waitForTask(taskId, MLTaskState.COMPLETED);
+         response = getTask(taskId);
+         responseMap = parseResponseToMap(response);
+         String modelId = (String) responseMap.get("model_id");
+         response = deployRemoteModel(modelId);
+         responseMap = parseResponseToMap(response);
+         taskId = (String) responseMap.get("task_id");
+         waitForTask(taskId, MLTaskState.COMPLETED);
+ 
+         PipelineParameters pipelineParameters = new PipelineParameters();
+         pipelineParameters.tag = "testBM25WithCohere";
+         pipelineParameters.description = "desc";
+         pipelineParameters.modelId = modelId;
+         pipelineParameters.systemPrompt = "You are a helpful assistant";
+         pipelineParameters.userInstructions = "none";
+         pipelineParameters.context_field = "text";
+         Response response1 = createSearchPipeline("pipeline_test", pipelineParameters);
+         assertEquals(200, response1.getStatusLine().getStatusCode());
+ 
+         SearchRequestParameters requestParameters = new SearchRequestParameters();
+         requestParameters.source = "text";
+         requestParameters.match = "president";
+         requestParameters.llmModel = LlmIOUtil.COHERE_PROVIDER_PREFIX + "command";
+         requestParameters.llmQuestion = "who is lincoln";
+         requestParameters.contextSize = 5;
+         requestParameters.interactionSize = 5;
+         requestParameters.timeout = 60;
+         Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters);
+         assertEquals(200, response2.getStatusLine().getStatusCode());
+ 
+         Map responseMap2 = parseResponseToMap(response2);
+         Map ext = (Map) responseMap2.get("ext");
+         assertNotNull(ext);
+         Map rag = (Map) ext.get("retrieval_augmented_generation");
+         assertNotNull(rag);
+ 
+         // TODO handle errors such as throttling
+         String answer = (String) rag.get("answer");
+         assertNotNull(answer);
+     }
+ 
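+     // Same flow as testBM25WithCohere, but the request uses the bare "command" model name and sets llm_response_field to "text".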
+     public void testBM25WithCohereUsingLlmResponseField() throws Exception {
+         // Skip test if key is null
+         if (COHERE_API_KEY == null) {
+             return;
+         }
+         Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT);
+         Map responseMap = parseResponseToMap(response);
+         String connectorId = (String) responseMap.get("connector_id");
+         response = registerRemoteModel("Cohere Chat Completion v1", connectorId);
+         responseMap = parseResponseToMap(response);
+         String taskId = (String) responseMap.get("task_id");
+         waitForTask(taskId, MLTaskState.COMPLETED);
+         response = getTask(taskId);
+         responseMap = parseResponseToMap(response);
+         String modelId = (String) responseMap.get("model_id");
+         response = deployRemoteModel(modelId);
+         responseMap = parseResponseToMap(response);
+         taskId = (String) responseMap.get("task_id");
+         waitForTask(taskId, MLTaskState.COMPLETED);
+ 
+         PipelineParameters pipelineParameters = new PipelineParameters();
+         pipelineParameters.tag = "testBM25WithCohereLlmResponseField";
+         pipelineParameters.description = "desc";
+         pipelineParameters.modelId = modelId;
+         pipelineParameters.systemPrompt = "You are a helpful assistant";
+         pipelineParameters.userInstructions = "none";
+         pipelineParameters.context_field = "text";
+         Response response1 = createSearchPipeline("pipeline_test", pipelineParameters);
+         assertEquals(200, response1.getStatusLine().getStatusCode());
+ 
+         SearchRequestParameters requestParameters = new SearchRequestParameters();
+         requestParameters.source = "text";
+         requestParameters.match = "president";
+         requestParameters.llmModel = "command";
+         requestParameters.llmQuestion = "who is lincoln";
+         requestParameters.contextSize = 5;
+         requestParameters.interactionSize = 5;
+         requestParameters.timeout = 60;
+         requestParameters.llmResponseField = "text";
+         Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters);
+         assertEquals(200, response2.getStatusLine().getStatusCode());
+ 
+         Map responseMap2 = parseResponseToMap(response2);
+         Map ext = (Map) responseMap2.get("ext");
+         assertNotNull(ext);
+         Map rag = (Map) ext.get("retrieval_augmented_generation");
+         assertNotNull(rag);
+ 
+         // TODO handle errors such as throttling
+         String answer = (String) rag.get("answer");
+         assertNotNull(answer);
+     }
+ 
    private Response createSearchPipeline(String pipeline, PipelineParameters parameters) throws Exception {
        return makeRequest(
            client(),
@@ -498,11 +650,11 @@ private Response createSearchPipeline(String pipeline, PipelineParameters parame
    private Response performSearch(String indexName, String pipeline, int size, SearchRequestParameters requestParameters)
        throws Exception {

-         String httpEntity = (requestParameters.conversationId == null)
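+         // Choose the request body template: the llm_response_field template when that field is set, otherwise the plain template when there is no conversation, otherwise the conversation template.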
+         String httpEntity = requestParameters.llmResponseField != null
            ? String
                .format(
                    Locale.ROOT,
-                     BM25_SEARCH_REQUEST_TEMPLATE,
+                     BM25_SEARCH_REQUEST_WITH_LLM_RESPONSE_FIELD_TEMPLATE,
                    requestParameters.source,
                    requestParameters.source,
                    requestParameters.match,
@@ -512,8 +664,25 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
                    requestParameters.userInstructions,
                    requestParameters.contextSize,
                    requestParameters.interactionSize,
-                     requestParameters.timeout
+                     requestParameters.timeout,
+                     requestParameters.llmResponseField
                )
+             : (requestParameters.conversationId == null)
+                 ? String
+                     .format(
+                         Locale.ROOT,
+                         BM25_SEARCH_REQUEST_TEMPLATE,
+                         requestParameters.source,
+                         requestParameters.source,
+                         requestParameters.match,
+                         requestParameters.llmModel,
+                         requestParameters.llmQuestion,
+                         requestParameters.systemPrompt,
+                         requestParameters.userInstructions,
+                         requestParameters.contextSize,
+                         requestParameters.interactionSize,
+                         requestParameters.timeout
+                     )
            : String
                .format(
                    Locale.ROOT,
@@ -572,5 +741,7 @@ static class SearchRequestParameters {
        int interactionSize;
        int timeout;
        String conversationId;
+ 
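+         // Optional; when set, performSearch formats BM25_SEARCH_REQUEST_WITH_LLM_RESPONSE_FIELD_TEMPLATE.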
+         String llmResponseField;
    }
}