|
5 | 5 |
|
6 | 6 | package org.opensearch.ml.action.memorycontainer.memory; |
7 | 7 |
|
| 8 | +import static org.junit.Assert.assertFalse; |
| 9 | +import static org.junit.Assert.assertNotEquals; |
8 | 10 | import static org.junit.Assert.assertTrue; |
9 | 11 | import static org.mockito.ArgumentMatchers.any; |
10 | 12 | import static org.mockito.ArgumentMatchers.eq; |
11 | 13 | import static org.mockito.Mockito.doAnswer; |
12 | 14 | import static org.mockito.Mockito.mock; |
13 | 15 | import static org.mockito.Mockito.verify; |
14 | 16 | import static org.mockito.Mockito.when; |
| 17 | +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.JSON_ENFORCEMENT_MESSAGE; |
| 18 | +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.USER_PREFERENCE_FACTS_EXTRACTION_PROMPT; |
| 19 | +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.USER_PREFERENCE_JSON_ENFORCEMENT_MESSAGE; |
15 | 20 | import static org.opensearch.ml.utils.TestHelper.createTestContent; |
16 | 21 |
|
17 | 22 | import java.util.ArrayList; |
@@ -980,4 +985,90 @@ public void testExtractFactsFromConversation_JsonEnforcementMessageAppended() { |
980 | 985 |
|
981 | 986 | verify(client).execute(any(), any(), any()); |
982 | 987 | } |
| 988 | + |
| 989 | + @Test |
| 990 | + public void testUserPreferencePromptFormat() { |
| 991 | + // Test that the new user preference prompt contains required elements |
| 992 | + String prompt = USER_PREFERENCE_FACTS_EXTRACTION_PROMPT; |
| 993 | + |
| 994 | + // Verify key improvements are present |
| 995 | + assertTrue("Should have character limit", prompt.contains("< 350 chars")); |
| 996 | + assertTrue("Should specify natural language format", prompt.contains("Context: <why/how>. Categories:")); |
| 997 | + assertTrue("Should contain example categories", prompt.contains("tools,tech,apps")); |
| 998 | + assertTrue("Should be role-based", prompt.contains("USER PREFERENCE EXTRACTOR")); |
| 999 | + |
| 1000 | + // Verify old problematic format is removed |
| 1001 | + assertFalse("Should not use pipe delimiters", prompt.contains("preference | context:")); |
| 1002 | + } |
| 1003 | + |
| 1004 | + @Test |
| 1005 | + public void testUserPreferenceEnforcementMessage() { |
| 1006 | + // Test that enforcement message matches the new format |
| 1007 | + String enforcement = USER_PREFERENCE_JSON_ENFORCEMENT_MESSAGE; |
| 1008 | + |
| 1009 | + assertTrue("Should specify natural language format", enforcement.contains("Context: <why/how>. Categories:")); |
| 1010 | + assertFalse("Should not use old pipe format", enforcement.contains("preference | context:")); |
| 1011 | + } |
| 1012 | + |
| 1013 | + @Test |
| 1014 | + public void testEnforcementMessageSelection() { |
| 1015 | + // Test that correct enforcement message is selected based on strategy type |
| 1016 | + MemoryStrategy userPrefStrategy = new MemoryStrategy( |
| 1017 | + "id", |
| 1018 | + true, |
| 1019 | + MemoryStrategyType.USER_PREFERENCE, |
| 1020 | + Arrays.asList("user_id"), |
| 1021 | + new HashMap<>() |
| 1022 | + ); |
| 1023 | + MemoryStrategy semanticStrategy = new MemoryStrategy( |
| 1024 | + "id", |
| 1025 | + true, |
| 1026 | + MemoryStrategyType.SEMANTIC, |
| 1027 | + Arrays.asList("user_id"), |
| 1028 | + new HashMap<>() |
| 1029 | + ); |
| 1030 | + |
| 1031 | + // This tests the logic in MemoryProcessingService.java lines 165-168 |
| 1032 | + // We can't easily test the private method, but we can verify the constants exist and are different |
| 1033 | + assertNotEquals( |
| 1034 | + "User preference and semantic should have different enforcement messages", |
| 1035 | + USER_PREFERENCE_JSON_ENFORCEMENT_MESSAGE, |
| 1036 | + JSON_ENFORCEMENT_MESSAGE |
| 1037 | + ); |
| 1038 | + |
| 1039 | + assertTrue( |
| 1040 | + "User preference enforcement should be for natural format", |
| 1041 | + USER_PREFERENCE_JSON_ENFORCEMENT_MESSAGE.contains("Context: <why/how>") |
| 1042 | + ); |
| 1043 | + assertTrue("Semantic enforcement should be for standard format", JSON_ENFORCEMENT_MESSAGE.contains("fact1")); |
| 1044 | + } |
| 1045 | + |
| 1046 | + @Test |
| 1047 | + public void testUserPreferenceExtractionScenarios() { |
| 1048 | + // Test various user preference extraction scenarios |
| 1049 | + String prompt = USER_PREFERENCE_FACTS_EXTRACTION_PROMPT; |
| 1050 | + |
| 1051 | + // Verify explicit preference handling |
| 1052 | + assertTrue("Should handle explicit preferences", prompt.contains("user states a preference")); |
| 1053 | + assertTrue("Should handle implicit preferences", prompt.contains("repeated choices")); |
| 1054 | + |
| 1055 | + // Verify format requirements |
| 1056 | + assertTrue("Should require JSON format", prompt.contains("{\"facts\":[")); |
| 1057 | + assertTrue("Should specify context format", prompt.contains("Context: <why/how>")); |
| 1058 | + assertTrue("Should limit character count", prompt.contains("< 350 chars")); |
| 1059 | + } |
| 1060 | + |
| 1061 | + @Test |
| 1062 | + public void testMultiTurnConversationHandling() { |
| 1063 | + // Test that prompt correctly handles multi-turn conversations |
| 1064 | + String prompt = USER_PREFERENCE_FACTS_EXTRACTION_PROMPT; |
| 1065 | + |
| 1066 | + // Verify assistant message handling |
| 1067 | + assertTrue("Should use assistant messages as context only", prompt.contains("Assistant messages are context only")); |
| 1068 | + assertTrue("Should extract from USER messages", prompt.contains("Extract preferences only from USER messages")); |
| 1069 | + |
| 1070 | + // Verify role clarity |
| 1071 | + assertTrue("Should not be a chat assistant", prompt.contains("not a chat assistant")); |
| 1072 | + assertTrue("Should only output JSON facts", prompt.contains("only job is to output JSON facts")); |
| 1073 | + } |
983 | 1074 | } |
0 commit comments