Skip to content

Commit 521214b

Browse files
2.x consistent get interactions (#1334) (#1335)
* consistent getInteractions response when security/no security Signed-off-by: HenryL27 <[email protected]> * fix deletion race condition Signed-off-by: HenryL27 <[email protected]> * cleanup Signed-off-by: HenryL27 <[email protected]> --------- Signed-off-by: HenryL27 <[email protected]> (cherry picked from commit 24a629b) Co-authored-by: HenryL27 <[email protected]>
1 parent 63fe2c7 commit 521214b

File tree

5 files changed

+50
-23
lines changed

5 files changed

+50
-23
lines changed

memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java

+12-6
Original file line numberDiff line numberDiff line change
@@ -270,17 +270,17 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
270270
String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
271271
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
272272
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
273-
// If security is off - User doesn't exist - you have permission
274-
if (userstr == null || User.parse(userstr) == null) {
275-
internalListener.onResponse(true);
276-
return;
277-
}
278273
GetRequest getRequest = Requests.getRequest(indexName).id(conversationId);
279274
ActionListener<GetResponse> al = ActionListener.wrap(getResponse -> {
280275
// If the conversation doesn't exist, fail
281276
if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) {
282277
throw new ResourceNotFoundException("Conversation [" + conversationId + "] not found");
283278
}
279+
// If security is off - User doesn't exist - you have permission
280+
if (userstr == null || User.parse(userstr) == null) {
281+
internalListener.onResponse(true);
282+
return;
283+
}
284284
ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap());
285285
String user = User.parse(userstr).getName();
286286
// If you're not the owner of this conversation, you do not have permission
@@ -290,7 +290,13 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
290290
}
291291
internalListener.onResponse(true);
292292
}, e -> { internalListener.onFailure(e); });
293-
client.get(getRequest, al);
293+
client
294+
.admin()
295+
.indices()
296+
.refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { client.get(getRequest, al); }, e -> {
297+
log.error("Failed to refresh conversations index during check access ", e);
298+
internalListener.onFailure(e);
299+
}));
294300
} catch (Exception e) {
295301
listener.onFailure(e);
296302
}

memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java

+14-5
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,12 @@
3333

3434
import com.google.common.annotations.VisibleForTesting;
3535

36+
import lombok.extern.log4j.Log4j2;
37+
3638
/**
3739
* Class for handling all Conversational Memory operactions
3840
*/
41+
@Log4j2
3942
public class OpenSearchConversationalMemoryHandler implements ConversationalMemoryHandler {
4043

4144
private ConversationMetaIndex conversationMetaIndex;
@@ -247,19 +250,25 @@ public ActionFuture<List<ConversationMeta>> getConversations(int maxResults) {
247250
public void deleteConversation(String conversationId, ActionListener<Boolean> listener) {
248251
StepListener<Boolean> accessListener = new StepListener<>();
249252
conversationMetaIndex.checkAccess(conversationId, accessListener);
250-
253+
log.info("DELETING CONVERSATION " + conversationId);
251254
accessListener.whenComplete(access -> {
252255
if (access) {
253256
StepListener<Boolean> metaDeleteListener = new StepListener<>();
254257
StepListener<Boolean> interactionsListener = new StepListener<>();
255258

256-
conversationMetaIndex.deleteConversation(conversationId, metaDeleteListener);
257259
interactionsIndex.deleteConversation(conversationId, interactionsListener);
258260

259-
metaDeleteListener.whenComplete(metaResult -> {
260-
interactionsListener
261-
.whenComplete(interactionResult -> { listener.onResponse(metaResult && interactionResult); }, listener::onFailure);
261+
interactionsListener
262+
.whenComplete(
263+
interactionResult -> { conversationMetaIndex.deleteConversation(conversationId, metaDeleteListener); },
264+
listener::onFailure
265+
);
266+
267+
metaDeleteListener.whenComplete(metaDeleteResult -> {
268+
log.info("SUCCESSFUL DELETION OF CONVERSATION " + conversationId);
269+
listener.onResponse(metaDeleteResult && interactionsListener.result());
262270
}, listener::onFailure);
271+
263272
} else {
264273
listener.onResponse(false);
265274
}

memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -249,16 +249,18 @@ public void testCanDeleteConversations() {
249249
});
250250

251251
StepListener<List<Interaction>> inters2 = new StepListener<>();
252-
inters1.whenComplete(ints -> { cmHandler.getInteractions(cid2.result(), 0, 10, inters2); }, e -> {
252+
inters1.whenComplete(ints -> {
253253
cdl.countDown();
254254
assert (false);
255+
}, e -> {
256+
assert (e.getMessage().startsWith("Conversation ["));
257+
cmHandler.getInteractions(cid2.result(), 0, 10, inters2);
255258
});
256259

257260
LatchedActionListener<List<Interaction>> finishAndAssert = new LatchedActionListener<>(ActionListener.wrap(r -> {
258261
assert (del.result());
259262
assert (conversations.result().size() == 1);
260263
assert (conversations.result().get(0).getId().equals(cid2.result()));
261-
assert (inters1.result().size() == 0);
262264
assert (inters2.result().size() == 1);
263265
assert (inters2.result().get(0).getId().equals(iid3.result()));
264266
}, e -> { assert (false); }), cdl);

memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java

+4
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ public void testDelete_DeleteFails_ThenFail() {
404404

405405
public void testCheckAccess_DoesNotExist_ThenFail() {
406406
setupUser("user");
407+
setupRefreshSuccess();
407408
doReturn(true).when(metadata).hasIndex(anyString());
408409
GetResponse response = mock(GetResponse.class);
409410
doReturn(false).when(response).isExists();
@@ -423,6 +424,7 @@ public void testCheckAccess_DoesNotExist_ThenFail() {
423424

424425
public void testCheckAccess_WrongId_ThenFail() {
425426
setupUser("user");
427+
setupRefreshSuccess();
426428
doReturn(true).when(metadata).hasIndex(anyString());
427429
GetResponse response = mock(GetResponse.class);
428430
doReturn(true).when(response).isExists();
@@ -443,6 +445,7 @@ public void testCheckAccess_WrongId_ThenFail() {
443445

444446
public void testCheckAccess_GetFails_ThenFail() {
445447
setupUser("user");
448+
setupRefreshSuccess();
446449
doReturn(true).when(metadata).hasIndex(anyString());
447450
doAnswer(invocation -> {
448451
ActionListener<GetResponse> al = invocation.getArgument(1);
@@ -459,6 +462,7 @@ public void testCheckAccess_GetFails_ThenFail() {
459462

460463
public void testCheckAccess_ClientFails_ThenFail() {
461464
setupUser("user");
465+
setupRefreshSuccess();
462466
doReturn(true).when(metadata).hasIndex(anyString());
463467
doThrow(new RuntimeException("Client Test Fail")).when(client).get(any(), any());
464468
@SuppressWarnings("unchecked")

plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java

+16-10
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.apache.http.message.BasicHeader;
2727
import org.junit.Before;
2828
import org.opensearch.client.Response;
29+
import org.opensearch.client.ResponseException;
2930
import org.opensearch.core.rest.RestStatus;
3031
import org.opensearch.ml.common.conversation.ActionConstants;
3132
import org.opensearch.ml.settings.MLCommonsSettings;
@@ -163,15 +164,20 @@ public void testDeleteConversation_WithInteractions() throws IOException {
163164
assert (!gcmap.containsKey("next_token"));
164165
assert (((ArrayList) gcmap.get("conversations")).size() == 0);
165166

166-
Response giresponse = TestHelper
167-
.makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), null, "", null);
168-
assert (giresponse != null);
169-
assert (TestHelper.restStatus(giresponse) == RestStatus.OK);
170-
HttpEntity gihttpEntity = giresponse.getEntity();
171-
String gientityString = TestHelper.httpEntityToString(gihttpEntity);
172-
Map gimap = gson.fromJson(gientityString, Map.class);
173-
assert (gimap.containsKey("interactions"));
174-
assert (!gimap.containsKey("next_token"));
175-
assert (((ArrayList) gimap.get("interactions")).size() == 0);
167+
try {
168+
Response giresponse = TestHelper
169+
.makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), null, "", null);
170+
assert (giresponse != null);
171+
assert (TestHelper.restStatus(giresponse) == RestStatus.OK);
172+
HttpEntity gihttpEntity = giresponse.getEntity();
173+
String gientityString = TestHelper.httpEntityToString(gihttpEntity);
174+
Map gimap = gson.fromJson(gientityString, Map.class);
175+
assert (gimap.containsKey("interactions"));
176+
assert (!gimap.containsKey("next_token"));
177+
assert (((ArrayList) gimap.get("interactions")).size() == 0);
178+
assert (false);
179+
} catch (ResponseException e) {
180+
assert (TestHelper.restStatus(e.getResponse()) == RestStatus.NOT_FOUND);
181+
}
176182
}
177183
}

0 commit comments

Comments
 (0)