Skip to content
Closed

bug fix #4148

Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,13 @@ public static MLMemory parse(XContentParser parser) throws IOException {
lastUpdatedTime = Instant.ofEpochMilli(parser.longValue());
break;
case MEMORY_EMBEDDING_FIELD:
// Parse embedding as generic object (could be array or sparse map)
memoryEmbedding = parser.map();
if (parser.currentToken() == XContentParser.Token.START_ARRAY) {
memoryEmbedding = parser.list(); // Simple list parsing like ModelTensor
} else if (parser.currentToken() == XContentParser.Token.START_OBJECT) {
memoryEmbedding = parser.map(); // For sparse embeddings
} else {
parser.skipChildren();
}
break;
default:
parser.skipChildren();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.io.IOException;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Before;
Expand Down Expand Up @@ -442,4 +443,192 @@ public void testSpecialCharactersInFields() throws IOException {
assertEquals(specialMemory.getRole(), parsed.getRole());
assertEquals(specialMemory.getTags(), parsed.getTags());
}

@Test
public void testParseWithDenseEmbeddingArray() throws IOException {
String jsonString = "{"
+ "\"session_id\":\"session-123\","
+ "\"memory\":\"Test memory with dense embedding\","
+ "\"memory_type\":\"FACT\","
+ "\"created_time\":"
+ testCreatedTime.toEpochMilli()
+ ","
+ "\"last_updated_time\":"
+ testUpdatedTime.toEpochMilli()
+ ","
+ "\"memory_embedding\":[0.1,0.2,0.3,0.4]"
+ "}";

XContentParser parser = XContentType.JSON
.xContent()
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString);
parser.nextToken();

MLMemory parsed = MLMemory.parse(parser);

assertEquals("session-123", parsed.getSessionId());
assertEquals("Test memory with dense embedding", parsed.getMemory());
assertNotNull(parsed.getMemoryEmbedding());
assertTrue(parsed.getMemoryEmbedding() instanceof List);
@SuppressWarnings("unchecked")
List<Double> embedding = (List<Double>) parsed.getMemoryEmbedding();
assertEquals(4, embedding.size());
assertEquals(0.1, embedding.get(0), 0.001);
assertEquals(0.2, embedding.get(1), 0.001);
assertEquals(0.3, embedding.get(2), 0.001);
assertEquals(0.4, embedding.get(3), 0.001);
}

@Test
public void testParseWithSparseEmbeddingObject() throws IOException {
String jsonString = "{"
+ "\"session_id\":\"session-456\","
+ "\"memory\":\"Test memory with sparse embedding\","
+ "\"memory_type\":\"FACT\","
+ "\"created_time\":"
+ testCreatedTime.toEpochMilli()
+ ","
+ "\"last_updated_time\":"
+ testUpdatedTime.toEpochMilli()
+ ","
+ "\"memory_embedding\":{\"token1\":0.5,\"token2\":0.8,\"token3\":0.2}"
+ "}";

XContentParser parser = XContentType.JSON
.xContent()
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString);
parser.nextToken();

MLMemory parsed = MLMemory.parse(parser);

assertEquals("session-456", parsed.getSessionId());
assertEquals("Test memory with sparse embedding", parsed.getMemory());
assertNotNull(parsed.getMemoryEmbedding());
assertTrue(parsed.getMemoryEmbedding() instanceof Map);
@SuppressWarnings("unchecked")
Map<String, Object> embedding = (Map<String, Object>) parsed.getMemoryEmbedding();
assertEquals(3, embedding.size());
assertEquals(0.5, embedding.get("token1"));
assertEquals(0.8, embedding.get("token2"));
assertEquals(0.2, embedding.get("token3"));
}

@Test
public void testParseWithWrappedDenseEmbedding() throws IOException {
String jsonString = "{"
+ "\"session_id\":\"session-789\","
+ "\"memory\":\"Test memory with wrapped dense embedding\","
+ "\"memory_type\":\"FACT\","
+ "\"created_time\":"
+ testCreatedTime.toEpochMilli()
+ ","
+ "\"last_updated_time\":"
+ testUpdatedTime.toEpochMilli()
+ ","
+ "\"memory_embedding\":{\"values\":[0.1,0.2,0.3]}"
+ "}";

XContentParser parser = XContentType.JSON
.xContent()
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString);
parser.nextToken();

MLMemory parsed = MLMemory.parse(parser);

assertEquals("session-789", parsed.getSessionId());
assertEquals("Test memory with wrapped dense embedding", parsed.getMemory());
assertNotNull(parsed.getMemoryEmbedding());
assertTrue(parsed.getMemoryEmbedding() instanceof Map);
@SuppressWarnings("unchecked")
Map<String, Object> embedding = (Map<String, Object>) parsed.getMemoryEmbedding();
assertTrue(embedding.containsKey("values"));
assertTrue(embedding.get("values") instanceof List);
}

@Test
public void testParseWithEmptyEmbeddingArray() throws IOException {
String jsonString = "{"
+ "\"session_id\":\"session-empty\","
+ "\"memory\":\"Test memory with empty embedding\","
+ "\"memory_type\":\"FACT\","
+ "\"created_time\":"
+ testCreatedTime.toEpochMilli()
+ ","
+ "\"last_updated_time\":"
+ testUpdatedTime.toEpochMilli()
+ ","
+ "\"memory_embedding\":[]"
+ "}";

XContentParser parser = XContentType.JSON
.xContent()
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString);
parser.nextToken();

MLMemory parsed = MLMemory.parse(parser);

assertEquals("session-empty", parsed.getSessionId());
assertNotNull(parsed.getMemoryEmbedding());
assertTrue(parsed.getMemoryEmbedding() instanceof List);
@SuppressWarnings("unchecked")
List<?> embedding = (List<?>) parsed.getMemoryEmbedding();
assertTrue(embedding.isEmpty());
}

@Test
public void testParseWithEmptyEmbeddingObject() throws IOException {
String jsonString = "{"
+ "\"session_id\":\"session-empty-obj\","
+ "\"memory\":\"Test memory with empty embedding object\","
+ "\"memory_type\":\"FACT\","
+ "\"created_time\":"
+ testCreatedTime.toEpochMilli()
+ ","
+ "\"last_updated_time\":"
+ testUpdatedTime.toEpochMilli()
+ ","
+ "\"memory_embedding\":{}"
+ "}";

XContentParser parser = XContentType.JSON
.xContent()
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString);
parser.nextToken();

MLMemory parsed = MLMemory.parse(parser);

assertEquals("session-empty-obj", parsed.getSessionId());
assertNotNull(parsed.getMemoryEmbedding());
assertTrue(parsed.getMemoryEmbedding() instanceof Map);
@SuppressWarnings("unchecked")
Map<?, ?> embedding = (Map<?, ?>) parsed.getMemoryEmbedding();
assertTrue(embedding.isEmpty());
}

@Test
public void testParseWithInvalidEmbeddingType() throws IOException {
String jsonString = "{"
+ "\"session_id\":\"session-invalid\","
+ "\"memory\":\"Test memory with invalid embedding\","
+ "\"memory_type\":\"FACT\","
+ "\"created_time\":"
+ testCreatedTime.toEpochMilli()
+ ","
+ "\"last_updated_time\":"
+ testUpdatedTime.toEpochMilli()
+ ","
+ "\"memory_embedding\":\"invalid_string_embedding\""
+ "}";

XContentParser parser = XContentType.JSON
.xContent()
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString);
parser.nextToken();

MLMemory parsed = MLMemory.parse(parser);

assertEquals("session-invalid", parsed.getSessionId());
// Should gracefully handle invalid embedding type by skipping it
assertNull(parsed.getMemoryEmbedding());
}
}
Loading
Loading