Skip to content

Commit de1d580

Browse files
committed
Mostly satisfy linters
1 parent 3b26b7a commit de1d580

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

llm-service/app/tests/model_provider_mocks/bedrock.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,21 +36,21 @@
3636
# DATA.
3737
#
3838
from contextlib import AbstractContextManager
39-
from typing import Iterator
39+
from typing import Iterator, Callable, Any
4040
from unittest.mock import patch
4141
from urllib.parse import urljoin
4242

4343
import botocore
4444
import pytest
4545
import responses
46+
from fastapi.testclient import TestClient
4647
from llama_index.llms.bedrock_converse.utils import BEDROCK_MODELS
4748

4849
from app.config import settings
4950
from app.services.caii.types import ModelResponse
5051
from app.services.models.providers import BedrockModelProvider
5152
from .testing_chat_history_manager import (
5253
patch_get_chat_history_manager,
53-
TestingChatHistoryManager,
5454
)
5555
from .utils import patch_env_vars
5656

@@ -67,7 +67,7 @@
6767
]
6868

6969

70-
def _patch_requests() -> AbstractContextManager:
70+
def _patch_requests() -> AbstractContextManager[responses.RequestsMock]:
7171
bedrock_url_base = f"https://bedrock.{settings.aws_default_region}.amazonaws.com/"
7272
r_mock = responses.RequestsMock(assert_all_requests_are_fired=False)
7373
for model_id, availability in TEXT_MODELS + EMBEDDING_MODELS:
@@ -91,10 +91,17 @@ def _patch_requests() -> AbstractContextManager:
9191
return r_mock
9292

9393

94-
def _patch_boto3() -> AbstractContextManager:
95-
make_api_call = botocore.client.BaseClient._make_api_call
94+
make_api_callable = Callable[[type, str, dict[str, str]], Any]
9695

97-
def mock_make_api_call(self, operation_name: str, api_params: dict[str, str]):
96+
97+
def _patch_boto3() -> AbstractContextManager[make_api_callable]:
98+
make_api_call: make_api_callable = botocore.client.BaseClient._make_api_call # type: ignore
99+
100+
def mock_make_api_call(
101+
self: type,
102+
operation_name: str,
103+
api_params: dict[str, str],
104+
) -> Any:
98105
"""Mock Boto3 Bedrock operations, since moto doesn't have full coverage.
99106
100107
Based on https://docs.getmoto.org/en/latest/docs/services/patching_other_services.html.
@@ -154,7 +161,7 @@ def mock_make_api_call(self, operation_name: str, api_params: dict[str, str]):
154161

155162

156163
@pytest.fixture(autouse=True)
157-
def mock_bedrock(monkeypatch) -> Iterator[None]:
164+
def mock_bedrock(monkeypatch: pytest.MonkeyPatch) -> Iterator[None]:
158165
with patch_env_vars(BedrockModelProvider):
159166
with (
160167
_patch_requests(),
@@ -175,7 +182,7 @@ def mock_bedrock(monkeypatch) -> Iterator[None]:
175182

176183

177184
# TODO: move test functions to a discoverable place
178-
def test_bedrock_models(client) -> None:
185+
def test_bedrock_models(client: TestClient) -> None:
179186
response = client.get("/llm-service/models/model_source")
180187
assert response.status_code == 200
181188
assert response.json() == "Bedrock"
@@ -207,7 +214,7 @@ def test_bedrock_models(client) -> None:
207214
# response = client.get("/llm-service/models/embedding/cohere.embed-english-v3/test")
208215

209216

210-
def test_bedrock_sessions(client) -> None:
217+
def test_bedrock_sessions(client: TestClient) -> None:
211218
session_id = 1
212219
with patch_get_chat_history_manager() as get_testing_chat_history_manager:
213220
chat_history = get_testing_chat_history_manager().retrieve_chat_history(
@@ -231,7 +238,7 @@ def test_bedrock_sessions(client) -> None:
231238
# assert response.status_code == 200
232239

233240

234-
# def test_bedrock_chat(client) -> None:
241+
# def test_bedrock_chat(client: TestClient) -> None:
235242
# response = client.post("/llm-service/sessions/suggest-questions")
236243
# print(f"{response.json()=}")
237244
# assert response.status_code == 200

llm-service/app/tests/model_provider_mocks/testing_chat_history_manager.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,16 @@
4141
from typing import Callable
4242
from unittest.mock import patch
4343

44-
from alembic.testing.fixtures import testing_config
45-
4644
from app.services.chat_history.chat_history_manager import (
4745
ChatHistoryManager,
4846
RagStudioChatMessage,
4947
RagMessage,
5048
)
51-
from app.services.models import get_provider_class
49+
from app.services.models.providers import get_provider_class
5250

5351

5452
class TestingChatHistoryManager(ChatHistoryManager):
55-
def __init__(self):
53+
def __init__(self) -> None:
5654
self._chat_history: dict[int, list[RagStudioChatMessage]] = dict()
5755

5856
def retrieve_chat_history(self, session_id: int) -> list[RagStudioChatMessage]:
@@ -70,6 +68,7 @@ def append_to_history(
7068
self._chat_history.setdefault(session_id, []).extend(messages)
7169

7270

71+
# TODO: we might want to specifically patch S3 and Simple to test their implementations
7372
def patch_get_chat_history_manager() -> (
7473
AbstractContextManager[Callable[[], TestingChatHistoryManager]]
7574
):

0 commit comments

Comments
 (0)