3636# DATA.
3737#
3838from contextlib import AbstractContextManager
39- from typing import Iterator
39+ from typing import Iterator , Callable , Any
4040from unittest .mock import patch
4141from urllib .parse import urljoin
4242
4343import botocore
4444import pytest
4545import responses
46+ from fastapi .testclient import TestClient
4647from llama_index .llms .bedrock_converse .utils import BEDROCK_MODELS
4748
4849from app .config import settings
4950from app .services .caii .types import ModelResponse
5051from app .services .models .providers import BedrockModelProvider
5152from .testing_chat_history_manager import (
5253 patch_get_chat_history_manager ,
53- TestingChatHistoryManager ,
5454)
5555from .utils import patch_env_vars
5656
6767]
6868
6969
70- def _patch_requests () -> AbstractContextManager :
70+ def _patch_requests () -> AbstractContextManager [ responses . RequestsMock ] :
7171 bedrock_url_base = f"https://bedrock.{ settings .aws_default_region } .amazonaws.com/"
7272 r_mock = responses .RequestsMock (assert_all_requests_are_fired = False )
7373 for model_id , availability in TEXT_MODELS + EMBEDDING_MODELS :
@@ -91,10 +91,17 @@ def _patch_requests() -> AbstractContextManager:
9191 return r_mock
9292
9393
94- def _patch_boto3 () -> AbstractContextManager :
95- make_api_call = botocore .client .BaseClient ._make_api_call
94+ make_api_callable = Callable [[type , str , dict [str , str ]], Any ]
9695
97- def mock_make_api_call (self , operation_name : str , api_params : dict [str , str ]):
96+
97+ def _patch_boto3 () -> AbstractContextManager [make_api_callable ]:
98+ make_api_call : make_api_callable = botocore .client .BaseClient ._make_api_call # type: ignore
99+
100+ def mock_make_api_call (
101+ self : type ,
102+ operation_name : str ,
103+ api_params : dict [str , str ],
104+ ) -> Any :
98105 """Mock Boto3 Bedrock operations, since moto doesn't have full coverage.
99106
100107 Based on https://docs.getmoto.org/en/latest/docs/services/patching_other_services.html.
@@ -154,7 +161,7 @@ def mock_make_api_call(self, operation_name: str, api_params: dict[str, str]):
154161
155162
156163@pytest .fixture (autouse = True )
157- def mock_bedrock (monkeypatch ) -> Iterator [None ]:
164+ def mock_bedrock (monkeypatch : pytest . MonkeyPatch ) -> Iterator [None ]:
158165 with patch_env_vars (BedrockModelProvider ):
159166 with (
160167 _patch_requests (),
@@ -175,7 +182,7 @@ def mock_bedrock(monkeypatch) -> Iterator[None]:
175182
176183
177184# TODO: move test functions to a discoverable place
178- def test_bedrock_models (client ) -> None :
185+ def test_bedrock_models (client : TestClient ) -> None :
179186 response = client .get ("/llm-service/models/model_source" )
180187 assert response .status_code == 200
181188 assert response .json () == "Bedrock"
@@ -207,7 +214,7 @@ def test_bedrock_models(client) -> None:
207214 # response = client.get("/llm-service/models/embedding/cohere.embed-english-v3/test")
208215
209216
210- def test_bedrock_sessions (client ) -> None :
217+ def test_bedrock_sessions (client : TestClient ) -> None :
211218 session_id = 1
212219 with patch_get_chat_history_manager () as get_testing_chat_history_manager :
213220 chat_history = get_testing_chat_history_manager ().retrieve_chat_history (
@@ -231,7 +238,7 @@ def test_bedrock_sessions(client) -> None:
231238 # assert response.status_code == 200
232239
233240
234- # def test_bedrock_chat(client) -> None:
241+ # def test_bedrock_chat(client: TestClient ) -> None:
235242# response = client.post("/llm-service/sessions/suggest-questions")
236243# print(f"{response.json()=}")
237244# assert response.status_code == 200
0 commit comments