|
6 | 6 | package org.opensearch.ml.rest; |
7 | 7 |
|
8 | 8 | import static org.hamcrest.Matchers.containsString; |
| 9 | +import static org.opensearch.ml.common.output.model.ModelTensor.DATA_AS_MAP_FIELD; |
9 | 10 | import static org.opensearch.ml.rest.RestMLRemoteInferenceIT.disableClusterConnectorAccessControl; |
10 | 11 |
|
11 | 12 | import java.io.IOException; |
| 13 | +import java.util.List; |
| 14 | +import java.util.Map; |
| 15 | +import java.util.Optional; |
12 | 16 |
|
13 | 17 | import org.apache.hc.core5.http.ParseException; |
14 | 18 | import org.hamcrest.MatcherAssert; |
15 | 19 | import org.junit.After; |
16 | 20 | import org.junit.Before; |
| 21 | +import org.opensearch.client.Response; |
17 | 22 | import org.opensearch.client.ResponseException; |
| 23 | +import org.opensearch.core.rest.RestStatus; |
| 24 | +import org.opensearch.ml.common.output.model.ModelTensorOutput; |
| 25 | +import org.opensearch.ml.common.output.model.ModelTensors; |
| 26 | +import org.opensearch.ml.utils.TestHelper; |
18 | 27 |
|
19 | 28 | public class RestConnectorToolIT extends RestBaseAgentToolsIT { |
20 | 29 | private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); |
@@ -143,8 +152,25 @@ public void testConnectorToolInFlowAgent() throws IOException, ParseException { |
143 | 152 | + "}"; |
144 | 153 | String agentId = createAgent(registerAgentRequestBody); |
145 | 154 | String agentInput = "{\n" + " \"parameters\": {\n" + " \"messages\": \"hello\"\n" + " }\n" + "}"; |
146 | | - String result = executeAgent(agentId, agentInput); |
| 155 | + Response response = TestHelper |
| 156 | + .makeRequest(client(), "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, agentInput, null); |
| 157 | + String result = parseResponseFromResponse(response); |
| 158 | + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); |
147 | 159 | assertNotNull(result); |
148 | 160 | } |
149 | 161 |
|
| 162 | + private String parseResponseFromResponse(Response response) throws IOException, ParseException { |
| 163 | + Map<String, Object> responseInMap = parseResponseToMap(response); |
| 164 | + return Optional |
| 165 | + .ofNullable(responseInMap) |
| 166 | + .map(m -> (List<Object>) m.get(ModelTensorOutput.INFERENCE_RESULT_FIELD)) |
| 167 | + .filter(l -> !l.isEmpty()) |
| 168 | + .map(l -> (Map<String, Object>) l.get(0)) |
| 169 | + .map(m -> (List<Object>) m.get(ModelTensors.OUTPUT_FIELD)) |
| 170 | + .filter(l -> !l.isEmpty()) |
| 171 | + .map(l -> (Map<String, Object>) l.get(0)) |
| 172 | + .map(m -> (Map<String, Object>) m.get(DATA_AS_MAP_FIELD)) |
| 173 | + .map(m -> (String) m.get("response")) |
| 174 | + .orElseThrow(() -> new AssertionError("Unable to parse response from agent execution")); |
| 175 | + } |
150 | 176 | } |
0 commit comments