Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ public class AgentMLInput extends MLInput {
@Setter
private Boolean isAsync;

@Getter
@Setter
private Map<String, Object> memory;

@Builder(builderMethodName = "AgentMLInputBuilder")
public AgentMLInput(String agentId, String tenantId, FunctionName functionName, MLInputDataset inputDataset) {
this(agentId, tenantId, functionName, inputDataset, false);
Expand All @@ -72,6 +76,13 @@ public void writeTo(StreamOutput out) throws IOException {
if (streamOutputVersion.onOrAfter(AgentMLInput.MINIMAL_SUPPORTED_VERSION_FOR_ASYNC_EXECUTION)) {
out.writeOptionalBoolean(isAsync);
}
// Serialize memory field
if (memory != null) {
out.writeBoolean(true);
out.writeMap(memory);
} else {
out.writeBoolean(false);
}
}

public AgentMLInput(StreamInput in) throws IOException {
Expand All @@ -82,6 +93,12 @@ public AgentMLInput(StreamInput in) throws IOException {
if (streamInputVersion.onOrAfter(AgentMLInput.MINIMAL_SUPPORTED_VERSION_FOR_ASYNC_EXECUTION)) {
this.isAsync = in.readOptionalBoolean();
}
// Deserialize memory field
if (in.readBoolean()) {
this.memory = in.readMap();
} else {
this.memory = null;
}
}

public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOException {
Expand All @@ -103,6 +120,9 @@ public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOE
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
inputDataset = new RemoteInferenceInputDataSet(parameters);
break;
case "memory":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use constants (leaving nitpicks so we can track them, they can be addressed later)

memory = parser.map();
break;
case ASYNC_FIELD:
isAsync = parser.booleanValue();
break;
Expand All @@ -112,5 +132,4 @@ public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOE
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;

import com.google.gson.JsonSyntaxException;
import com.google.gson.reflect.TypeToken;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.PathNotFoundException;
Expand Down Expand Up @@ -93,9 +94,42 @@ public static Map<String, String> extractInputParameters(Map<String, String> par
StringSubstitutor stringSubstitutor = new StringSubstitutor(parameters, "${parameters.", "}");
String input = stringSubstitutor.replace(parameters.get("input"));
extractedParameters.put("input", input);
Map<String, String> inputParameters = gson
.fromJson(input, TypeToken.getParameterized(Map.class, String.class, String.class).getType());
extractedParameters.putAll(inputParameters);

// Check if input is a JSON object or array
String trimmedInput = input.trim();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason for these changes? i would avoid parameter parsing changes unless 100% required for memory related changes

if (trimmedInput.startsWith("{")) {
// Input is a JSON object - try parsing as Map<String, String> first (existing behavior)
try {
Map<String, String> inputParameters = gson
.fromJson(input, TypeToken.getParameterized(Map.class, String.class, String.class).getType());
extractedParameters.putAll(inputParameters);
Comment on lines +103 to +105
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use this please:

Map<String, Object> parsedInputParameters = gson.fromJson(input, TypeToken.getParameterized(Map.class, String.class, Object.class).getType());
extractedParameters.putAll(StringUtils.getParameterMap(parsedInputParameters));

Ref: https://github.com/opensearch-project/ml-commons/pull/4138/files#diff-28a59745543ec384ed12c99d11a3d4f9909aa4032449587135b9bdcf313c6293

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's keep this method as in the above PR if possible, thanks

} catch (JsonSyntaxException e) {
// Fallback: handle mixed types (arrays, objects, etc.) for cases like {"index": ["*"]}
try {
Map<String, Object> inputParameters = gson
.fromJson(input, TypeToken.getParameterized(Map.class, String.class, Object.class).getType());

// Convert non-string values to JSON strings for tool compatibility
for (Map.Entry<String, Object> entry : inputParameters.entrySet()) {
String key = entry.getKey();
Object value = entry.getValue();

if (value instanceof String) {
extractedParameters.put(key, (String) value); // Keep strings as-is
} else {
extractedParameters.put(key, gson.toJson(value)); // Convert arrays/objects to JSON strings
}
}
} catch (Exception fallbackException) {
// If both approaches fail, log original error and continue
log.info("fail extract parameters from key 'input' due to" + e.getMessage());
}
}
} else if (trimmedInput.startsWith("[")) {
// Input is a JSON array - skip parsing as it's not a parameter map
log.debug("Input is a JSON array, skipping parameter extraction");
}
// If it's neither object nor array, it's likely a plain string - keep as is
} catch (Exception exception) {
log.info("fail extract parameters from key 'input' due to" + exception.getMessage());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,26 @@ public void testConstructorWithStreamInput_VersionCompatibility() throws IOExcep
assertEquals("testAgentId", inputNewVersion.getAgentId());
assertEquals("testTenantId", inputNewVersion.getTenantId()); // tenantId should be populated for newer versions
}

@Test
public void testMemoryGetterSetter() {
// Test memory field getter/setter functionality
AgentMLInput input = new AgentMLInput("testAgent", null, FunctionName.AGENT, null);

// Initially memory should be null
assertNull("Memory should be null initially", input.getMemory());

// Set memory and verify
Map<String, Object> memoryMap = new HashMap<>();
memoryMap.put("type", "bedrock_agentcore_memory");
memoryMap.put("memory_arn", "test-arn");
memoryMap.put("region", "us-east-1");

input.setMemory(memoryMap);

assertNotNull("Memory should not be null after setting", input.getMemory());
assertEquals("bedrock_agentcore_memory", input.getMemory().get("type"));
assertEquals("test-arn", input.getMemory().get("memory_arn"));
assertEquals("us-east-1", input.getMemory().get("region"));
Comment on lines +150 to +155
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this test is only testing the setter, not sure how useful that is

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,95 @@ public void testFilterToolOutput_ComplexNestedPath() {
// Should contain only the targeted deep value
assertEquals("targetValue", result);
}

@Test
public void testExtractInputParameters_ExistingStringValues() {
// Test existing behavior with string-only JSON (should work exactly as before)
Map<String, String> parameters = new HashMap<>();
parameters.put("input", "{\"query\":\"test\",\"limit\":\"10\"}");

Map<String, String> result = ToolUtils.extractInputParameters(parameters, null);

assertEquals("test", result.get("query"));
assertEquals("10", result.get("limit"));
assertEquals("{\"query\":\"test\",\"limit\":\"10\"}", result.get("input"));
}

@Test
public void testExtractInputParameters_ArrayValues() {
// Test new functionality with array values (should work now instead of failing)
Map<String, String> parameters = new HashMap<>();
parameters.put("input", "{\"index\":[\"*\"]}");

Map<String, String> result = ToolUtils.extractInputParameters(parameters, null);

assertEquals("[\"*\"]", result.get("index"));
assertEquals("{\"index\":[\"*\"]}", result.get("input"));
}

@Test
public void testExtractInputParameters_MixedTypes() {
// Test mixed types: strings, arrays, numbers
Map<String, String> parameters = new HashMap<>();
parameters.put("input", "{\"index\":[\"*\",\"logs\"],\"limit\":10,\"query\":\"test\"}");

Map<String, String> result = ToolUtils.extractInputParameters(parameters, null);

assertEquals("[\"*\",\"logs\"]", result.get("index"));
// Numbers are converted using gson.toJson() which gives "10.0" for integers parsed as doubles
assertTrue(
"Expected limit to be numeric string, got: " + result.get("limit"),
"10".equals(result.get("limit")) || "10.0".equals(result.get("limit"))
);
assertEquals("test", result.get("query"));
assertEquals("{\"index\":[\"*\",\"logs\"],\"limit\":10,\"query\":\"test\"}", result.get("input"));
}

@Test
public void testExtractInputParameters_ComplexArrays() {
// Test complex nested arrays and objects
Map<String, String> parameters = new HashMap<>();
parameters.put("input", "{\"indices\":[\"index1\",\"index2\"],\"filters\":{\"term\":\"value\"}}");

Map<String, String> result = ToolUtils.extractInputParameters(parameters, null);

assertEquals("[\"index1\",\"index2\"]", result.get("indices"));
assertEquals("{\"term\":\"value\"}", result.get("filters"));
}

@Test
public void testExtractInputParameters_InvalidJSON() {
// Test that invalid JSON still logs error and continues (existing behavior)
Map<String, String> parameters = new HashMap<>();
parameters.put("input", "{invalid json}");

Map<String, String> result = ToolUtils.extractInputParameters(parameters, null);

// Should still contain the input parameter even if parsing failed
assertEquals("{invalid json}", result.get("input"));
}

@Test
public void testExtractInputParameters_PlainString() {
// Test plain string input (existing behavior should be unchanged)
Map<String, String> parameters = new HashMap<>();
parameters.put("input", "plain string input");

Map<String, String> result = ToolUtils.extractInputParameters(parameters, null);

assertEquals("plain string input", result.get("input"));
}

@Test
public void testExtractInputParameters_JSONArray() {
// Test JSON array input (existing behavior should be unchanged)
Map<String, String> parameters = new HashMap<>();
parameters.put("input", "[\"item1\", \"item2\"]");

Map<String, String> result = ToolUtils.extractInputParameters(parameters, null);

assertEquals("[\"item1\", \"item2\"]", result.get("input"));
// Should not extract individual parameters from array
assertFalse(result.containsKey("item1"));
}
}
15 changes: 8 additions & 7 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -68,24 +68,25 @@ dependencies {
}
}

implementation platform('software.amazon.awssdk:bom:2.30.18')
api 'software.amazon.awssdk:auth:2.30.18'
implementation platform('software.amazon.awssdk:bom:2.32.31')
api 'software.amazon.awssdk:auth:2.32.31'
implementation 'software.amazon.awssdk:apache-client'
implementation 'software.amazon.awssdk:bedrockagentcore'
implementation ('com.amazonaws:aws-encryption-sdk-java:2.4.1') {
exclude group: 'org.bouncycastle', module: 'bcprov-ext-jdk18on'
}
// needed by aws-encryption-sdk-java
implementation "org.bouncycastle:bc-fips:${versions.bouncycastle_jce}"
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: "2.30.18"
compileOnly group: 'software.amazon.awssdk', name: 's3', version: "2.30.18"
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: "2.30.18"
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: "2.32.31"
compileOnly group: 'software.amazon.awssdk', name: 's3', version: "2.32.31"
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: "2.32.31"

implementation ('com.jayway.jsonpath:json-path:2.9.0') {
exclude group: 'net.minidev', module: 'json-smart'
}
implementation('net.minidev:json-smart:2.5.2')
implementation group: 'org.json', name: 'json', version: '20231013'
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "2.30.18"
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "2.32.31"
api('io.modelcontextprotocol.sdk:mcp:0.9.0')
testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}")
Expand All @@ -99,7 +100,7 @@ lombok {
configurations.all {
resolutionStrategy.force 'com.google.protobuf:protobuf-java:3.25.5'
resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0'
resolutionStrategy.force 'software.amazon.awssdk:bom:2.30.18'
resolutionStrategy.force 'software.amazon.awssdk:bom:2.32.31'
}

jacocoTestReport {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,9 @@ public static String addToolsToPromptString(
toolsBuilder.append(toolsSuffix);
Map<String, String> toolsPromptMap = new HashMap<>();
toolsPromptMap.put(TOOL_DESCRIPTIONS, toolsBuilder.toString());
toolsPromptMap.put(TOOL_NAMES, toolNamesBuilder.substring(0, toolNamesBuilder.length() - 1));
// Fix: Handle empty toolNamesBuilder to prevent StringIndexOutOfBoundsException
String toolNames = toolNamesBuilder.length() > 0 ? toolNamesBuilder.substring(0, toolNamesBuilder.length() - 1) : "";
toolsPromptMap.put(TOOL_NAMES, toolNames);

if (parameters.containsKey(TOOL_DESCRIPTIONS)) {
toolsPromptMap.put(TOOL_DESCRIPTIONS, parameters.get(TOOL_DESCRIPTIONS));
Expand Down
Loading
Loading