diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MLMemory.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MLMemory.java index b3b0cff6c2..09a649f052 100644 --- a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MLMemory.java +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MLMemory.java @@ -203,8 +203,13 @@ public static MLMemory parse(XContentParser parser) throws IOException { lastUpdatedTime = Instant.ofEpochMilli(parser.longValue()); break; case MEMORY_EMBEDDING_FIELD: - // Parse embedding as generic object (could be array or sparse map) - memoryEmbedding = parser.map(); + if (parser.currentToken() == XContentParser.Token.START_ARRAY) { + memoryEmbedding = parser.list(); // Simple list parsing like ModelTensor + } else if (parser.currentToken() == XContentParser.Token.START_OBJECT) { + memoryEmbedding = parser.map(); // For sparse embeddings + } else { + parser.skipChildren(); + } break; default: parser.skipChildren(); diff --git a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MLMemoryTest.java b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MLMemoryTest.java index 4aa1d5a137..09978d0a7a 100644 --- a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MLMemoryTest.java +++ b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MLMemoryTest.java @@ -14,6 +14,7 @@ import java.io.IOException; import java.time.Instant; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.junit.Before; @@ -442,4 +443,192 @@ public void testSpecialCharactersInFields() throws IOException { assertEquals(specialMemory.getRole(), parsed.getRole()); assertEquals(specialMemory.getTags(), parsed.getTags()); } + + @Test + public void testParseWithDenseEmbeddingArray() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-123\"," + + "\"memory\":\"Test memory with dense embedding\"," + + "\"memory_type\":\"FACT\"," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "," + + "\"memory_embedding\":[0.1,0.2,0.3,0.4]" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-123", parsed.getSessionId()); + assertEquals("Test memory with dense embedding", parsed.getMemory()); + assertNotNull(parsed.getMemoryEmbedding()); + assertTrue(parsed.getMemoryEmbedding() instanceof List); + @SuppressWarnings("unchecked") + List embedding = (List) parsed.getMemoryEmbedding(); + assertEquals(4, embedding.size()); + assertEquals(0.1, embedding.get(0), 0.001); + assertEquals(0.2, embedding.get(1), 0.001); + assertEquals(0.3, embedding.get(2), 0.001); + assertEquals(0.4, embedding.get(3), 0.001); + } + + @Test + public void testParseWithSparseEmbeddingObject() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-456\"," + + "\"memory\":\"Test memory with sparse embedding\"," + + "\"memory_type\":\"FACT\"," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "," + + "\"memory_embedding\":{\"token1\":0.5,\"token2\":0.8,\"token3\":0.2}" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-456", parsed.getSessionId()); + assertEquals("Test memory with sparse embedding", parsed.getMemory()); + assertNotNull(parsed.getMemoryEmbedding()); + assertTrue(parsed.getMemoryEmbedding() instanceof Map); + @SuppressWarnings("unchecked") + Map embedding = (Map) parsed.getMemoryEmbedding(); + assertEquals(3, embedding.size()); + assertEquals(0.5, embedding.get("token1")); + assertEquals(0.8, embedding.get("token2")); + assertEquals(0.2, embedding.get("token3")); + } + + @Test + public void testParseWithWrappedDenseEmbedding() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-789\"," + + "\"memory\":\"Test memory with wrapped dense embedding\"," + + "\"memory_type\":\"FACT\"," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "," + + "\"memory_embedding\":{\"values\":[0.1,0.2,0.3]}" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-789", parsed.getSessionId()); + assertEquals("Test memory with wrapped dense embedding", parsed.getMemory()); + assertNotNull(parsed.getMemoryEmbedding()); + assertTrue(parsed.getMemoryEmbedding() instanceof Map); + @SuppressWarnings("unchecked") + Map embedding = (Map) parsed.getMemoryEmbedding(); + assertTrue(embedding.containsKey("values")); + assertTrue(embedding.get("values") instanceof List); + } + + @Test + public void testParseWithEmptyEmbeddingArray() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-empty\"," + + "\"memory\":\"Test memory with empty embedding\"," + + "\"memory_type\":\"FACT\"," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "," + + "\"memory_embedding\":[]" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-empty", parsed.getSessionId()); + assertNotNull(parsed.getMemoryEmbedding()); + assertTrue(parsed.getMemoryEmbedding() instanceof List); + @SuppressWarnings("unchecked") + List embedding = (List) parsed.getMemoryEmbedding(); + assertTrue(embedding.isEmpty()); + } + + @Test + public void testParseWithEmptyEmbeddingObject() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-empty-obj\"," + + "\"memory\":\"Test memory with empty embedding object\"," + + "\"memory_type\":\"FACT\"," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "," + + "\"memory_embedding\":{}" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-empty-obj", parsed.getSessionId()); + assertNotNull(parsed.getMemoryEmbedding()); + assertTrue(parsed.getMemoryEmbedding() instanceof Map); + @SuppressWarnings("unchecked") + Map embedding = (Map) parsed.getMemoryEmbedding(); + assertTrue(embedding.isEmpty()); + } + + @Test + public void testParseWithInvalidEmbeddingType() throws IOException { + String jsonString = "{" + + "\"session_id\":\"session-invalid\"," + + "\"memory\":\"Test memory with invalid embedding\"," + + "\"memory_type\":\"FACT\"," + + "\"created_time\":" + + testCreatedTime.toEpochMilli() + + "," + + "\"last_updated_time\":" + + testUpdatedTime.toEpochMilli() + + "," + + "\"memory_embedding\":\"invalid_string_embedding\"" + + "}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString); + parser.nextToken(); + + MLMemory parsed = MLMemory.parse(parser); + + assertEquals("session-invalid", parsed.getSessionId()); + // Should gracefully handle invalid embedding type by skipping it + assertNull(parsed.getMemoryEmbedding()); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/TransportGetMemoryActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/TransportGetMemoryActionTests.java index 3054ba54c5..d10ee95796 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/TransportGetMemoryActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/TransportGetMemoryActionTests.java @@ -437,6 +437,74 @@ private String createMemoryJson() { + "}"; } + private String createMemoryJsonWithDenseEmbedding() { + long currentTimeEpoch = System.currentTimeMillis(); + return "{" + + "\"session_id\":\"test-session\"," + + "\"memory\":\"Test memory with dense embedding\"," + + "\"memory_type\":\"RAW_MESSAGE\"," + + "\"user_id\":\"test-user\"," + + "\"created_time\":" + + currentTimeEpoch + + "," + + "\"last_updated_time\":" + + currentTimeEpoch + + "," + + "\"memory_embedding\":[0.1,0.2,0.3,0.4]" + + "}"; + } + + private String createMemoryJsonWithSparseEmbedding() { + long currentTimeEpoch = System.currentTimeMillis(); + return "{" + + "\"session_id\":\"test-session\"," + + "\"memory\":\"Test memory with sparse embedding\"," + + "\"memory_type\":\"RAW_MESSAGE\"," + + "\"user_id\":\"test-user\"," + + "\"created_time\":" + + currentTimeEpoch + + "," + + "\"last_updated_time\":" + + currentTimeEpoch + + "," + + "\"memory_embedding\":{\"token1\":0.5,\"token2\":0.8,\"token3\":0.2}" + + "}"; + } + + private String createMemoryJsonWithWrappedEmbedding() { + long currentTimeEpoch = System.currentTimeMillis(); + return "{" + + "\"session_id\":\"test-session\"," + + "\"memory\":\"Test memory with wrapped embedding\"," + + "\"memory_type\":\"RAW_MESSAGE\"," + + "\"user_id\":\"test-user\"," + + "\"created_time\":" + + currentTimeEpoch + + "," + + "\"last_updated_time\":" + + currentTimeEpoch + + "," + + "\"memory_embedding\":{\"values\":[0.1,0.2,0.3]}" + + "}"; + } + + private String createMemoryJsonWithInvalidEmbedding() { + long currentTimeEpoch = System.currentTimeMillis(); + return "{" + + "\"session_id\":\"test-session\"," + + "\"memory\":\"Test memory with invalid embedding\"," + + "\"memory_type\":\"RAW_MESSAGE\"," + + "\"user_id\":\"test-user\"," + + "\"created_time\":" + + currentTimeEpoch + + "," + + "\"last_updated_time\":" + + currentTimeEpoch + + "," + + "\"memory_embedding\":\"invalid_string_embedding\"" + + "}"; + } + @Test public void testDoExecuteWithFeatureDisabled() { // Setup feature flag to be disabled @@ -461,4 +529,149 @@ public void testDoExecuteWithFeatureDisabled() { // Verify that no other operations were attempted verify(memoryContainerHelper, never()).getMemoryContainer(any(String.class), any(ActionListener.class)); } + + @Test + public void testDoExecuteWithDenseEmbeddingArray() { + // Setup request + MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_ID); + + // Setup memory container helper to return container + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(testMemoryContainer); + return null; + }).when(memoryContainerHelper).getMemoryContainer(any(String.class), any(ActionListener.class)); + + // Setup client to return response with dense embedding as array + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + org.opensearch.action.get.GetResponse getResponse = mock(org.opensearch.action.get.GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsString()).thenReturn(createMemoryJsonWithDenseEmbedding()); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(org.opensearch.action.get.GetRequest.class), any(ActionListener.class)); + + // Execute + action.doExecute(task, getRequest, actionListener); + + // Verify successful response + ArgumentCaptor responseCaptor = forClass(MLGetMemoryResponse.class); + verify(actionListener).onResponse(responseCaptor.capture()); + + MLGetMemoryResponse capturedResponse = responseCaptor.getValue(); + assertNotNull(capturedResponse); + MLMemory returnedMemory = capturedResponse.getMlMemory(); + assertNotNull(returnedMemory); + assertNotNull(returnedMemory.getMemoryEmbedding()); + } + + @Test + public void testDoExecuteWithSparseEmbeddingObject() { + // Setup request + MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_ID); + + // Setup memory container helper to return container + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(testMemoryContainer); + return null; + }).when(memoryContainerHelper).getMemoryContainer(any(String.class), any(ActionListener.class)); + + // Setup client to return response with sparse embedding as object + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + org.opensearch.action.get.GetResponse getResponse = mock(org.opensearch.action.get.GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsString()).thenReturn(createMemoryJsonWithSparseEmbedding()); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(org.opensearch.action.get.GetRequest.class), any(ActionListener.class)); + + // Execute + action.doExecute(task, getRequest, actionListener); + + // Verify successful response + ArgumentCaptor responseCaptor = forClass(MLGetMemoryResponse.class); + verify(actionListener).onResponse(responseCaptor.capture()); + + MLGetMemoryResponse capturedResponse = responseCaptor.getValue(); + assertNotNull(capturedResponse); + MLMemory returnedMemory = capturedResponse.getMlMemory(); + assertNotNull(returnedMemory); + assertNotNull(returnedMemory.getMemoryEmbedding()); + } + + @Test + public void testDoExecuteWithWrappedDenseEmbedding() { + // Setup request + MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_ID); + + // Setup memory container helper to return container + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(testMemoryContainer); + return null; + }).when(memoryContainerHelper).getMemoryContainer(any(String.class), any(ActionListener.class)); + + // Setup client to return response with wrapped dense embedding + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + org.opensearch.action.get.GetResponse getResponse = mock(org.opensearch.action.get.GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsString()).thenReturn(createMemoryJsonWithWrappedEmbedding()); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(org.opensearch.action.get.GetRequest.class), any(ActionListener.class)); + + // Execute + action.doExecute(task, getRequest, actionListener); + + // Verify successful response + ArgumentCaptor responseCaptor = forClass(MLGetMemoryResponse.class); + verify(actionListener).onResponse(responseCaptor.capture()); + + MLGetMemoryResponse capturedResponse = responseCaptor.getValue(); + assertNotNull(capturedResponse); + MLMemory returnedMemory = capturedResponse.getMlMemory(); + assertNotNull(returnedMemory); + assertNotNull(returnedMemory.getMemoryEmbedding()); + } + + @Test + public void testDoExecuteWithInvalidEmbeddingType() { + // Setup request + MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_ID); + + // Setup memory container helper to return container + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(testMemoryContainer); + return null; + }).when(memoryContainerHelper).getMemoryContainer(any(String.class), any(ActionListener.class)); + + // Setup client to return response with invalid embedding type (string) + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + org.opensearch.action.get.GetResponse getResponse = mock(org.opensearch.action.get.GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + when(getResponse.getSourceAsString()).thenReturn(createMemoryJsonWithInvalidEmbedding()); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(org.opensearch.action.get.GetRequest.class), any(ActionListener.class)); + + // Execute + action.doExecute(task, getRequest, actionListener); + + // Verify successful response (should gracefully handle invalid embedding) + ArgumentCaptor responseCaptor = forClass(MLGetMemoryResponse.class); + verify(actionListener).onResponse(responseCaptor.capture()); + + MLGetMemoryResponse capturedResponse = responseCaptor.getValue(); + assertNotNull(capturedResponse); + MLMemory returnedMemory = capturedResponse.getMlMemory(); + assertNotNull(returnedMemory); + // Invalid embedding should be null (gracefully handled) + assertNull(returnedMemory.getMemoryEmbedding()); + } }