Skip to content

Commit 3b26b7a

Browse files
committed
Mock ChatHistoryManager (in a friendly banter sort of way)
1 parent ff02d76 commit 3b26b7a

File tree

5 files changed

+142
-4
lines changed

5 files changed

+142
-4
lines changed

llm-service/app/tests/provider_mocks/bedrock.py renamed to llm-service/app/tests/model_provider_mocks/bedrock.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@
4848
from app.config import settings
4949
from app.services.caii.types import ModelResponse
5050
from app.services.models.providers import BedrockModelProvider
51+
from .testing_chat_history_manager import (
52+
patch_get_chat_history_manager,
53+
TestingChatHistoryManager,
54+
)
5155
from .utils import patch_env_vars
5256

5357
TEXT_MODELS = [
@@ -149,7 +153,7 @@ def mock_make_api_call(self, operation_name: str, api_params: dict[str, str]):
149153
return patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call)
150154

151155

152-
@pytest.fixture
156+
@pytest.fixture(autouse=True)
153157
def mock_bedrock(monkeypatch) -> Iterator[None]:
154158
with patch_env_vars(BedrockModelProvider):
155159
with (
@@ -170,8 +174,8 @@ def mock_bedrock(monkeypatch) -> Iterator[None]:
170174
yield
171175

172176

173-
# TODO: move this test function to a discoverable place
174-
def test_bedrock(mock_bedrock, client) -> None:
177+
# TODO: move test functions to a discoverable place
178+
def test_bedrock_models(client) -> None:
175179
response = client.get("/llm-service/models/model_source")
176180
assert response.status_code == 200
177181
assert response.json() == "Bedrock"
@@ -199,3 +203,38 @@ def test_bedrock(mock_bedrock, client) -> None:
199203
for model_id, availability in RERANKING_MODELS
200204
if availability == "AVAILABLE"
201205
]
206+
207+
# response = client.get("/llm-service/models/embedding/cohere.embed-english-v3/test")
208+
209+
210+
def test_bedrock_sessions(client) -> None:
211+
session_id = 1
212+
with patch_get_chat_history_manager() as get_testing_chat_history_manager:
213+
chat_history = get_testing_chat_history_manager().retrieve_chat_history(
214+
session_id=session_id
215+
)
216+
217+
response = client.get(f"/llm-service/sessions/{session_id}/chat-history")
218+
assert response.status_code == 200
219+
assert response.json()["data"] == [msg.model_dump() for msg in chat_history]
220+
221+
msg = chat_history[0] # TODO: randomize?
222+
response = client.get(
223+
f"/llm-service/sessions/{msg.session_id}/chat-history/{msg.id}",
224+
)
225+
assert response.status_code == 200
226+
assert response.json() == msg.model_dump()
227+
228+
# TODO: maybe call the chat endpoint and see if history changes
229+
230+
# response = client.post("/llm-service/sessions/1/rename-session")
231+
# assert response.status_code == 200
232+
233+
234+
# def test_bedrock_chat(client) -> None:
235+
# response = client.post("/llm-service/sessions/suggest-questions")
236+
# print(f"{response.json()=}")
237+
# assert response.status_code == 200
238+
#
239+
# response = client.post("/llm-service/sessions/1/stream-completion")
240+
# assert response.status_code == 200
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#
2+
# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
3+
# (C) Cloudera, Inc. 2025
4+
# All rights reserved.
5+
#
6+
# Applicable Open Source License: Apache 2.0
7+
#
8+
# NOTE: Cloudera open source products are modular software products
9+
# made up of hundreds of individual components, each of which was
10+
# individually copyrighted. Each Cloudera open source product is a
11+
# collective work under U.S. Copyright Law. Your license to use the
12+
# collective work is as provided in your written agreement with
13+
# Cloudera. Used apart from the collective work, this file is
14+
# licensed for your use pursuant to the open source license
15+
# identified above.
16+
#
17+
# This code is provided to you pursuant a written agreement with
18+
# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
19+
# this code. If you do not have a written agreement with Cloudera nor
20+
# with an authorized and properly licensed third party, you do not
21+
# have any rights to access nor to use this code.
22+
#
23+
# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the
24+
# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
25+
# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
26+
# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
27+
# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
28+
# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
29+
# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
30+
# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
31+
# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
32+
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
33+
# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
34+
# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
35+
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
36+
# DATA.
37+
#
38+
import time
39+
import uuid
40+
from contextlib import AbstractContextManager
41+
from typing import Callable
42+
from unittest.mock import patch
43+
44+
from alembic.testing.fixtures import testing_config
45+
46+
from app.services.chat_history.chat_history_manager import (
47+
ChatHistoryManager,
48+
RagStudioChatMessage,
49+
RagMessage,
50+
)
51+
from app.services.models import get_provider_class
52+
53+
54+
class TestingChatHistoryManager(ChatHistoryManager):
55+
def __init__(self):
56+
self._chat_history: dict[int, list[RagStudioChatMessage]] = dict()
57+
58+
def retrieve_chat_history(self, session_id: int) -> list[RagStudioChatMessage]:
59+
return self._chat_history.get(session_id, [])
60+
61+
def clear_chat_history(self, session_id: int) -> None:
62+
self._chat_history[session_id] = []
63+
64+
def delete_chat_history(self, session_id: int) -> None:
65+
del self._chat_history[session_id]
66+
67+
def append_to_history(
68+
self, session_id: int, messages: list[RagStudioChatMessage]
69+
) -> None:
70+
self._chat_history.setdefault(session_id, []).extend(messages)
71+
72+
73+
def patch_get_chat_history_manager() -> (
74+
AbstractContextManager[Callable[[], TestingChatHistoryManager]]
75+
):
76+
session_id = 1
77+
testing_chat_history_manager = TestingChatHistoryManager()
78+
testing_chat_history_manager.append_to_history(
79+
session_id,
80+
[
81+
RagStudioChatMessage(
82+
id=str(uuid.uuid4()),
83+
session_id=session_id,
84+
source_nodes=[],
85+
inference_model=get_provider_class()
86+
.list_llm_models()[0] # TODO: randomize?
87+
.model_id,
88+
rag_message=RagMessage(user="test question", assistant="test answer"),
89+
evaluations=[],
90+
timestamp=time.time(),
91+
condensed_question=None,
92+
)
93+
],
94+
)
95+
96+
return patch(
97+
"app.services.chat_history.chat_history_manager._get_chat_history_manager",
98+
new=lambda: testing_chat_history_manager,
99+
)

llm-service/app/tests/services/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from app.services.caii.types import ListEndpointEntry
4545
from app.services.models.providers import BedrockModelProvider
4646
from app.services.models.providers._model_provider import _ModelProvider
47-
from app.tests.provider_mocks.utils import patch_env_vars
47+
from app.tests.model_provider_mocks.utils import patch_env_vars
4848

4949

5050
@pytest.fixture(params=_ModelProvider.__subclasses__())

0 commit comments

Comments
 (0)