Skip to content

Commit a85bec6

Browse files
Add support for Cohere and others. (opensearch-project#2238)
* Add support for Cohere and others. Signed-off-by: Austin Lee <[email protected]> * Fix spotless and improve test coverage. Signed-off-by: Austin Lee <[email protected]> * Remove unused code. Signed-off-by: Austin Lee <[email protected]> * Apply review comments, add more tests, simplify code. Signed-off-by: Austin Lee <[email protected]> * Add test coverage for error cases. Signed-off-by: Austin Lee <[email protected]> * Add test coverage. Signed-off-by: Austin Lee <[email protected]> --------- Signed-off-by: Austin Lee <[email protected]>
1 parent 148672e commit a85bec6

File tree

15 files changed

+594
-78
lines changed

15 files changed

+594
-78
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java

Lines changed: 174 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.opensearch.core.rest.RestStatus;
3535
import org.opensearch.ml.common.MLTaskState;
3636
import org.opensearch.ml.utils.TestHelper;
37+
import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
3738

3839
import com.google.common.collect.ImmutableList;
3940
import com.google.common.collect.ImmutableMap;
@@ -147,6 +148,35 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
147148
private static final String BEDROCK_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null
148149
? BEDROCK_CONNECTOR_BLUEPRINT2
149150
: BEDROCK_CONNECTOR_BLUEPRINT1;
151+
152+
// Cohere API key, read from the COHERE_API_KEY environment variable (null when the variable is not set).
// The Cohere tests in this class return early when this is null.
private static final String COHERE_API_KEY = System.getenv("COHERE_API_KEY");
153+
// Connector definition (JSON) for Cohere's chat endpoint. It declares a single "predict" action that
// POSTs a body of the form { "message": ${parameters.inputs}, "model": ${parameters.model} } to
// https://api.cohere.ai/v1/chat, with the "model" parameter set to "command" and the key above
// embedded as the "cohere_key" credential used in the Authorization header.
private static final String COHERE_CONNECTOR_BLUEPRINT = "{\n"
154+
+ " \"name\": \"Cohere Chat Model\",\n"
155+
+ " \"description\": \"The connector to Cohere's public chat API\",\n"
156+
+ " \"version\": \"1\",\n"
157+
+ " \"protocol\": \"http\",\n"
158+
+ " \"credential\": {\n"
159+
+ " \"cohere_key\": \""
160+
+ COHERE_API_KEY
161+
+ "\"\n"
162+
+ " },\n"
163+
+ " \"parameters\": {\n"
164+
+ " \"model\": \"command\"\n"
165+
+ " },\n"
166+
+ " \"actions\": [\n"
167+
+ " {\n"
168+
+ " \"action_type\": \"predict\",\n"
169+
+ " \"method\": \"POST\",\n"
170+
+ " \"url\": \"https://api.cohere.ai/v1/chat\",\n"
171+
+ " \"headers\": {\n"
172+
+ " \"Authorization\": \"Bearer ${credential.cohere_key}\",\n"
173+
+ " \"Request-Source\": \"unspecified:opensearch\"\n"
174+
+ " },\n"
175+
+ " \"request_body\": \"{ \\\"message\\\": \\\"${parameters.inputs}\\\", \\\"model\\\": \\\"${parameters.model}\\\" }\" \n"
176+
+ " }\n"
177+
+ " ]\n"
178+
+ "}";
179+
150180
private static final String PIPELINE_TEMPLATE = "{\n"
151181
+ " \"response_processors\": [\n"
152182
+ " {\n"
@@ -199,6 +229,23 @@ public class RestMLRAGSearchProcessorIT extends RestMLRemoteInferenceIT {
199229
+ " }\n"
200230
+ "}";
201231

232+
// String.format template for a BM25 "match" search request whose generative_qa_parameters include
// llm_response_field. Placeholders, in order: _source field (%s), match field (%s), match text (%s),
// llm_model (%s), llm_question (%s), context_size (%d), message_size (%d), timeout (%d),
// llm_response_field (%s).
// NOTE(review): this template has no system_prompt / user_instructions placeholders, yet the
// String.format call in performSearch that uses it visibly passes requestParameters.userInstructions
// ahead of contextSize. Confirm the argument list lines up with the placeholders above; if it does
// not, values shift by position (e.g. an int ends up in llm_response_field).
private static final String BM25_SEARCH_REQUEST_WITH_LLM_RESPONSE_FIELD_TEMPLATE = "{\n"
233+
+ " \"_source\": [\"%s\"],\n"
234+
+ " \"query\" : {\n"
235+
+ " \"match\": {\"%s\": \"%s\"}\n"
236+
+ " },\n"
237+
+ " \"ext\": {\n"
238+
+ " \"generative_qa_parameters\": {\n"
239+
+ " \"llm_model\": \"%s\",\n"
240+
+ " \"llm_question\": \"%s\",\n"
241+
+ " \"context_size\": %d,\n"
242+
+ " \"message_size\": %d,\n"
243+
+ " \"timeout\": %d,\n"
244+
+ " \"llm_response_field\": \"%s\"\n"
245+
+ " }\n"
246+
+ " }\n"
247+
+ "}";
248+
202249
private static final String OPENAI_MODEL = "gpt-3.5-turbo";
203250
private static final String BEDROCK_ANTHROPIC_CLAUDE = "bedrock/anthropic-claude";
204251
private static final String TEST_DOC_PATH = "org/opensearch/ml/rest/test_data/";
@@ -472,6 +519,111 @@ public void testBM25WithBedrockWithConversation() throws Exception {
472519
assertNotNull(interactionId);
473520
}
474521

522+
/**
 * Runs a BM25 search through a RAG search pipeline backed by a Cohere remote model, selecting the
 * provider through the llm_model value (LlmIOUtil.COHERE_PROVIDER_PREFIX + "command").
 *
 * Steps: create the Cohere connector, register and deploy a remote model on it, create the
 * "pipeline_test" search pipeline, run the search, and check that the response carries a non-null
 * ext.retrieval_augmented_generation.answer.
 *
 * Returns immediately (without asserting anything) when COHERE_API_KEY is null.
 */
public void testBM25WithCohere() throws Exception {
523+
// Skip test if key is null
524+
if (COHERE_API_KEY == null) {
525+
return;
526+
}
527+
// Create the connector, then register a remote model on it and wait for the register task.
Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT);
528+
Map responseMap = parseResponseToMap(response);
529+
String connectorId = (String) responseMap.get("connector_id");
530+
response = registerRemoteModel("Cohere Chat Completion v1", connectorId);
531+
responseMap = parseResponseToMap(response);
532+
String taskId = (String) responseMap.get("task_id");
533+
waitForTask(taskId, MLTaskState.COMPLETED);
534+
// The completed register task is read back to obtain the model id, which is then deployed.
response = getTask(taskId);
535+
responseMap = parseResponseToMap(response);
536+
String modelId = (String) responseMap.get("model_id");
537+
response = deployRemoteModel(modelId);
538+
responseMap = parseResponseToMap(response);
539+
taskId = (String) responseMap.get("task_id");
540+
waitForTask(taskId, MLTaskState.COMPLETED);
541+
542+
PipelineParameters pipelineParameters = new PipelineParameters();
543+
pipelineParameters.tag = "testBM25WithCohere";
544+
pipelineParameters.description = "desc";
545+
pipelineParameters.modelId = modelId;
546+
pipelineParameters.systemPrompt = "You are a helpful assistant";
547+
pipelineParameters.userInstructions = "none";
548+
pipelineParameters.context_field = "text";
549+
Response response1 = createSearchPipeline("pipeline_test", pipelineParameters);
550+
assertEquals(200, response1.getStatusLine().getStatusCode());
551+
552+
SearchRequestParameters requestParameters = new SearchRequestParameters();
553+
requestParameters.source = "text";
554+
requestParameters.match = "president";
555+
// The Cohere provider is indicated by prefixing the model name.
requestParameters.llmModel = LlmIOUtil.COHERE_PROVIDER_PREFIX + "command";
556+
requestParameters.llmQuestion = "who is lincoln";
557+
requestParameters.contextSize = 5;
558+
requestParameters.interactionSize = 5;
559+
requestParameters.timeout = 60;
560+
Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters);
561+
assertEquals(200, response2.getStatusLine().getStatusCode());
562+
563+
Map responseMap2 = parseResponseToMap(response2);
564+
Map ext = (Map) responseMap2.get("ext");
565+
assertNotNull(ext);
566+
Map rag = (Map) ext.get("retrieval_augmented_generation");
567+
assertNotNull(rag);
568+
569+
// TODO handle errors such as throttling
570+
String answer = (String) rag.get("answer");
571+
assertNotNull(answer);
572+
}
573+
574+
/**
 * Same flow as the Cohere BM25 test, but the request uses the bare model name "command" (no provider
 * prefix) and sets llmResponseField to "text".
 * NOTE(review): presumably this makes the processor read the answer from the "text" field of the
 * Cohere response instead of relying on provider detection -- confirm against the LLM implementation.
 *
 * Steps: create the Cohere connector, register and deploy a remote model on it, create the
 * "pipeline_test" search pipeline, run the search, and check that the response carries a non-null
 * ext.retrieval_augmented_generation.answer.
 *
 * Returns immediately (without asserting anything) when COHERE_API_KEY is null.
 */
public void testBM25WithCohereUsingLlmResponseField() throws Exception {
575+
// Skip test if key is null
576+
if (COHERE_API_KEY == null) {
577+
return;
578+
}
579+
// Create the connector, then register a remote model on it and wait for the register task.
Response response = createConnector(COHERE_CONNECTOR_BLUEPRINT);
580+
Map responseMap = parseResponseToMap(response);
581+
String connectorId = (String) responseMap.get("connector_id");
582+
response = registerRemoteModel("Cohere Chat Completion v1", connectorId);
583+
responseMap = parseResponseToMap(response);
584+
String taskId = (String) responseMap.get("task_id");
585+
waitForTask(taskId, MLTaskState.COMPLETED);
586+
// The completed register task is read back to obtain the model id, which is then deployed.
response = getTask(taskId);
587+
responseMap = parseResponseToMap(response);
588+
String modelId = (String) responseMap.get("model_id");
589+
response = deployRemoteModel(modelId);
590+
responseMap = parseResponseToMap(response);
591+
taskId = (String) responseMap.get("task_id");
592+
waitForTask(taskId, MLTaskState.COMPLETED);
593+
594+
PipelineParameters pipelineParameters = new PipelineParameters();
595+
pipelineParameters.tag = "testBM25WithCohereLlmResponseField";
596+
pipelineParameters.description = "desc";
597+
pipelineParameters.modelId = modelId;
598+
pipelineParameters.systemPrompt = "You are a helpful assistant";
599+
pipelineParameters.userInstructions = "none";
600+
pipelineParameters.context_field = "text";
601+
Response response1 = createSearchPipeline("pipeline_test", pipelineParameters);
602+
assertEquals(200, response1.getStatusLine().getStatusCode());
603+
604+
SearchRequestParameters requestParameters = new SearchRequestParameters();
605+
requestParameters.source = "text";
606+
requestParameters.match = "president";
607+
// No provider prefix here; the response field is named explicitly below instead.
requestParameters.llmModel = "command";
608+
requestParameters.llmQuestion = "who is lincoln";
609+
requestParameters.contextSize = 5;
610+
requestParameters.interactionSize = 5;
611+
requestParameters.timeout = 60;
612+
// A non-null llmResponseField makes performSearch use the LLM-response-field request template.
requestParameters.llmResponseField = "text";
613+
Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters);
614+
assertEquals(200, response2.getStatusLine().getStatusCode());
615+
616+
Map responseMap2 = parseResponseToMap(response2);
617+
Map ext = (Map) responseMap2.get("ext");
618+
assertNotNull(ext);
619+
Map rag = (Map) ext.get("retrieval_augmented_generation");
620+
assertNotNull(rag);
621+
622+
// TODO handle errors such as throttling
623+
String answer = (String) rag.get("answer");
624+
assertNotNull(answer);
625+
}
626+
475627
private Response createSearchPipeline(String pipeline, PipelineParameters parameters) throws Exception {
476628
return makeRequest(
477629
client(),
@@ -498,11 +650,11 @@ private Response createSearchPipeline(String pipeline, PipelineParameters parame
498650
private Response performSearch(String indexName, String pipeline, int size, SearchRequestParameters requestParameters)
499651
throws Exception {
500652

501-
String httpEntity = (requestParameters.conversationId == null)
653+
String httpEntity = requestParameters.llmResponseField != null
502654
? String
503655
.format(
504656
Locale.ROOT,
505-
BM25_SEARCH_REQUEST_TEMPLATE,
657+
BM25_SEARCH_REQUEST_WITH_LLM_RESPONSE_FIELD_TEMPLATE,
506658
requestParameters.source,
507659
requestParameters.source,
508660
requestParameters.match,
@@ -512,8 +664,25 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
512664
requestParameters.userInstructions,
513665
requestParameters.contextSize,
514666
requestParameters.interactionSize,
515-
requestParameters.timeout
667+
requestParameters.timeout,
668+
requestParameters.llmResponseField
516669
)
670+
: (requestParameters.conversationId == null)
671+
? String
672+
.format(
673+
Locale.ROOT,
674+
BM25_SEARCH_REQUEST_TEMPLATE,
675+
requestParameters.source,
676+
requestParameters.source,
677+
requestParameters.match,
678+
requestParameters.llmModel,
679+
requestParameters.llmQuestion,
680+
requestParameters.systemPrompt,
681+
requestParameters.userInstructions,
682+
requestParameters.contextSize,
683+
requestParameters.interactionSize,
684+
requestParameters.timeout
685+
)
517686
: String
518687
.format(
519688
Locale.ROOT,
@@ -572,5 +741,7 @@ static class SearchRequestParameters {
572741
int interactionSize;
573742
int timeout;
574743
String conversationId;
744+
745+
String llmResponseField;
575746
}
576747
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,8 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
165165
llmQuestion,
166166
chatHistory,
167167
searchResults,
168-
timeout
168+
timeout,
169+
params.getLlmResponseField()
169170
)
170171
);
171172
log.info("doChatCompletion complete. ({})", getDuration(start));

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,16 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
7171
// from a remote inference endpoint before timing out the request.
7272
private static final ParseField TIMEOUT = new ParseField("timeout");
7373

74+
// Optional parameter: this parameter allows request-level customization of the "system" (role) prompt.
7475
private static final ParseField SYSTEM_PROMPT = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT);
7576

77+
// Optional parameter: this parameter allows request-level customization of the "user" (role) prompt.
7678
private static final ParseField USER_INSTRUCTIONS = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS);
7779

80+
// Optional parameter: this parameter indicates the name of the field in the LLM response
81+
// that contains the chat completion text, i.e. the "answer".
82+
private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field");
83+
7884
public static final int SIZE_NULL_VALUE = -1;
7985

8086
static {
@@ -87,6 +93,7 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
8793
PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE);
8894
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
8995
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
96+
PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD);
9097
}
9198

9299
@Setter
@@ -121,6 +128,10 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {
121128
@Getter
122129
private String userInstructions;
123130

131+
@Setter
132+
@Getter
133+
private String llmResponseField;
134+
124135
public GenerativeQAParameters(
125136
String conversationId,
126137
String llmModel,
@@ -129,7 +140,8 @@ public GenerativeQAParameters(
129140
String userInstructions,
130141
Integer contextSize,
131142
Integer interactionSize,
132-
Integer timeout
143+
Integer timeout,
144+
String llmResponseField
133145
) {
134146
this.conversationId = conversationId;
135147
this.llmModel = llmModel;
@@ -143,6 +155,7 @@ public GenerativeQAParameters(
143155
this.contextSize = (contextSize == null) ? SIZE_NULL_VALUE : contextSize;
144156
this.interactionSize = (interactionSize == null) ? SIZE_NULL_VALUE : interactionSize;
145157
this.timeout = (timeout == null) ? SIZE_NULL_VALUE : timeout;
158+
this.llmResponseField = llmResponseField;
146159
}
147160

148161
public GenerativeQAParameters(StreamInput input) throws IOException {
@@ -154,6 +167,7 @@ public GenerativeQAParameters(StreamInput input) throws IOException {
154167
this.contextSize = input.readInt();
155168
this.interactionSize = input.readInt();
156169
this.timeout = input.readInt();
170+
this.llmResponseField = input.readOptionalString();
157171
}
158172

159173
@Override
@@ -166,7 +180,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
166180
.field(USER_INSTRUCTIONS.getPreferredName(), this.userInstructions)
167181
.field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
168182
.field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
169-
.field(TIMEOUT.getPreferredName(), this.timeout);
183+
.field(TIMEOUT.getPreferredName(), this.timeout)
184+
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField);
170185
}
171186

172187
@Override
@@ -181,6 +196,7 @@ public void writeTo(StreamOutput out) throws IOException {
181196
out.writeInt(contextSize);
182197
out.writeInt(interactionSize);
183198
out.writeInt(timeout);
199+
out.writeOptionalString(llmResponseField);
184200
}
185201

186202
public static GenerativeQAParameters parse(XContentParser parser) throws IOException {
@@ -204,6 +220,7 @@ public boolean equals(Object o) {
204220
&& Objects.equals(this.userInstructions, other.getUserInstructions())
205221
&& (this.contextSize == other.getContextSize())
206222
&& (this.interactionSize == other.getInteractionSize())
207-
&& (this.timeout == other.getTimeout());
223+
&& (this.timeout == other.getTimeout())
224+
&& Objects.equals(this.llmResponseField, other.getLlmResponseField());
208225
}
209226
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,5 @@ public class ChatCompletionInput {
4343
private String systemPrompt;
4444
private String userInstructions;
4545
private Llm.ModelProvider modelProvider;
46+
private String llmResponseField;
4647
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutput.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import java.util.List;
2121

22+
import org.opensearch.core.common.util.CollectionUtils;
23+
2224
import lombok.Getter;
2325
import lombok.Setter;
2426
import lombok.extern.log4j.Log4j2;
@@ -38,19 +40,12 @@ public class ChatCompletionOutput {
3840

3941
public ChatCompletionOutput(List<Object> answers, List<String> errors) {
4042

41-
if (answers == null && errors == null) {
43+
if (CollectionUtils.isEmpty(answers) && CollectionUtils.isEmpty(errors)) {
4244
throw new IllegalArgumentException("answers and errors can't both be null.");
4345
}
4446

45-
if (answers == null) {
46-
if (errors.isEmpty()) {
47-
throw new IllegalArgumentException("If answers is not provided, one or more errors must be provided.");
48-
}
47+
if (CollectionUtils.isEmpty(answers)) {
4948
this.errorOccurred = true;
50-
} else if (errors == null) {
51-
if (answers.isEmpty()) {
52-
throw new IllegalArgumentException("If errors is not provided, one or more answers must be provided.");
53-
}
5449
}
5550

5651
this.answers = answers;

0 commit comments

Comments
 (0)