 # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
 # DATA.
 #
+import itertools
 from typing import Generator
 from unittest.mock import patch
 from urllib.parse import urljoin

 import botocore
 import pytest
 import responses
+from llama_index.llms.bedrock_converse.utils import BEDROCK_MODELS

 from app.config import settings
+from app.services.caii.types import ModelResponse
+from app.services.models import ModelProvider
+from app.services.models.providers import BedrockModelProvider
+
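+# (model_id, availability) pairs used to build the mocked Bedrock responses and the
+# expected results asserted in test_bedrock below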
+TEXT_MODELS = [
+    ("test.unavailable-text-model-v1", "NOT_AVAILABLE"),
+    ("test.available-text-model-v1", "AVAILABLE"),
+]
+EMBEDDING_MODELS = [
+    ("test.unavailable-embedding-model-v1", "NOT_AVAILABLE"),
+    ("test.available-embedding-model-v1", "AVAILABLE"),
+]
+RERANKING_MODELS = [
+    ("test.available-reranking-model-v1", "AVAILABLE"),
+]


 @pytest.fixture
-def mock_bedrock() -> Generator[None, None, None]:
-    BEDROCK_URL_BASE = f"https://bedrock.{settings.aws_default_region}.amazonaws.com/"
-    TEXT_MODELS = [
-        ("test.unavailable-text-model-v1", "NOT_AVAILABLE"),
-        ("test.available-text-model-v1", "AVAILABLE"),
-    ]
-    EMBEDDING_MODELS = [
-        ("test.unavailable-embedding-model-v1", "NOT_AVAILABLE"),
-        ("test.available-embedding-model-v1", "AVAILABLE"),
-    ]
+def mock_bedrock(monkeypatch) -> Generator[None, None, None]:
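+    # make Bedrock the only model provider whose required env vars are set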
+    for name in BedrockModelProvider.get_env_var_names():
+        monkeypatch.setenv(name, "test")
+    for name in get_all_env_var_names() - BedrockModelProvider.get_env_var_names():
+        monkeypatch.delenv(name, raising=False)

+    # mock calls made directly through `requests`
+    bedrock_url_base = f"https://bedrock.{settings.aws_default_region}.amazonaws.com/"
     r_mock = responses.RequestsMock(assert_all_requests_are_fired=False)
     for model_id, availability in TEXT_MODELS + EMBEDDING_MODELS:
         r_mock.get(
             urljoin(
-                BEDROCK_URL_BASE,
-                f"foundation-model-availability/{model_id}:0",
+                bedrock_url_base,
+                f"foundation-model-availability/{model_id}",
             ),
             json={
                 "agreementAvailability": {
@@ -77,67 +91,54 @@ def mock_bedrock() -> Generator[None, None, None]:
             },
         )

+    # mock calls made through `boto3`
     make_api_call = botocore.client.BaseClient._make_api_call

     def mock_make_api_call(self, operation_name: str, api_params: dict[str, str]):
-        """Mock Bedrock calls, since moto doesn't have full coverage.
+        """Mock Boto3 Bedrock operations, since moto doesn't have full coverage.

         Based on https://docs.getmoto.org/en/latest/docs/services/patching_other_services.html.

         """
         if operation_name == "ListFoundationModels":
             modality = api_params["byOutputModality"]
-            if modality == "TEXT":
-                return {
-                    "modelSummaries": [
-                        {
-                            "modelArn": f"arn:aws:bedrock:{settings.aws_default_region}::foundation-model/{model_id}:0",
-                            "modelId": f"{model_id}:0",
-                            "modelName": model_id.upper(),
-                            "providerName": "Test",
-                            "inputModalities": ["TEXT"],
-                            "outputModalities": ["TEXT"],
-                            "responseStreamingSupported": True,
-                            "customizationsSupported": [],
-                            "inferenceTypesSupported": ["ON_DEMAND"],
-                            "modelLifecycle": {"status": "ACTIVE"},
-                        }
-                        for model_id, _ in TEXT_MODELS
-                    ],
-                }
-            elif modality == "EMBEDDING":
-                return {
-                    "modelSummaries": [
-                        {
-                            "modelArn": f"arn:aws:bedrock:{settings.aws_default_region}::foundation-model/{model_id}:0",
-                            "modelId": f"{model_id}:0",
-                            "modelName": model_id.upper(),
-                            "providerName": "Test",
-                            "inputModalities": ["TEXT"],
-                            "outputModalities": ["EMBEDDING"],
-                            "responseStreamingSupported": False,
-                            "customizationsSupported": [],
-                            "inferenceTypesSupported": ["ON_DEMAND"],
-                            "modelLifecycle": {"status": "ACTIVE"},
-                        }
-                        for model_id, _ in EMBEDDING_MODELS
-                    ],
-                }
-            else:
+            models = {
+                "TEXT": TEXT_MODELS,
+                "EMBEDDING": EMBEDDING_MODELS,
+            }.get(modality)
+            if models is None:
                 raise ValueError(f"test encountered unexpected modality {modality}")
+
+            return {
+                "modelSummaries": [
+                    {
+                        "modelArn": f"arn:aws:bedrock:{settings.aws_default_region}::foundation-model/{model_id}",
+                        "modelId": model_id,
+                        "modelName": model_id.upper(),
+                        "providerName": "Test",
+                        "inputModalities": ["TEXT"],
+                        "outputModalities": [modality],
+                        "responseStreamingSupported": modality == "TEXT",  # arbitrary
+                        "customizationsSupported": [],
+                        "inferenceTypesSupported": ["ON_DEMAND"],
+                        "modelLifecycle": {"status": "ACTIVE"},
+                    }
+                    for model_id, _ in models
+                ],
+            }
         elif operation_name == "ListInferenceProfiles":
             return {
                 "inferenceProfileSummaries": [
                     {
                         "inferenceProfileName": f"US {model_id.upper()}",
                         "description": f"Routes requests to {model_id.upper()} in {settings.aws_default_region}.",
-                        "inferenceProfileArn": f"arn:aws:bedrock:{settings.aws_default_region}:123456789012:inference-profile/{model_id}:0",
+                        "inferenceProfileArn": f"arn:aws:bedrock:{settings.aws_default_region}:123456789012:inference-profile/{model_id}",
                         "models": [
                             {
-                                "modelArn": f"arn:aws:bedrock:{settings.aws_default_region}::foundation-model/{model_id}:0"
+                                "modelArn": f"arn:aws:bedrock:{settings.aws_default_region}::foundation-model/{model_id}"
                             },
                         ],
-                        "inferenceProfileId": f"{model_id}:0",
+                        "inferenceProfileId": model_id,
                         "status": "ACTIVE",
                         "type": "SYSTEM_DEFINED",
                     }
@@ -149,13 +150,66 @@ def mock_make_api_call(self, operation_name: str, api_params: dict[str, str]):
         # passthrough
         return make_api_call(self, operation_name, api_params)

-    with patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call):
-        with r_mock:
-            yield
+    # mock reranking models, which are hard-coded in our app
+    def list_reranking_models() -> list[ModelResponse]:
+        return [
+            ModelResponse(model_id=model_id, name=model_id.upper())
+            for model_id, _ in RERANKING_MODELS
+        ]

+    with (
+        r_mock,
+        patch(
+            "botocore.client.BaseClient._make_api_call",
+            new=mock_make_api_call,
+        ),
+        patch(
+            "app.services.models.providers.BedrockModelProvider.list_reranking_models",
+            new=list_reranking_models,
+        ),
+        patch(  # work around a llama-index filter we have in list_llm_models()
+            "app.services.models.providers.bedrock.BEDROCK_MODELS",
+            new=BEDROCK_MODELS | {model_id: 128000 for model_id, _ in TEXT_MODELS},
+        ),
173+ ),
174+ ):
+        yield

-def test_bedrock(mock_bedrock) -> None:
-    from app.services.models.providers import BedrockModelProvider

-    BedrockModelProvider.list_available_models()
-    BedrockModelProvider._get_model_arns()
+def get_all_env_var_names() -> set[str]:
+    """Return the names of all the env vars required by all model providers."""
+    return set(
+        itertools.chain.from_iterable(
+            subcls.get_env_var_names() for subcls in ModelProvider.__subclasses__()
+        )
+    )
+
+
+# TODO: move this test function to a discoverable place
+def test_bedrock(mock_bedrock, client) -> None:
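+    # with only Bedrock's env vars set by the fixture, the app should report Bedrock
+    # as its model source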
+    response = client.get("/llm-service/models/model_source")
+    assert response.status_code == 200
+    assert response.json() == "Bedrock"
+
+    response = client.get("/llm-service/models/embeddings")
+    assert response.status_code == 200
+    assert [model["model_id"] for model in response.json()] == [
+        model_id
+        for model_id, availability in EMBEDDING_MODELS
+        if availability == "AVAILABLE"
+    ]
+
+    response = client.get("/llm-service/models/llm")
+    assert response.status_code == 200
+    assert [model["model_id"] for model in response.json()] == [
+        model_id
+        for model_id, availability in TEXT_MODELS
+        if availability == "AVAILABLE"
+    ]
+
+    response = client.get("/llm-service/models/reranking")
+    assert response.status_code == 200
+    assert [model["model_id"] for model in response.json()] == [
+        model_id
+        for model_id, availability in RERANKING_MODELS
+        if availability == "AVAILABLE"
+    ]