Skip to content

Commit 148aa55

Browse files
committed
Get a working test
1 parent 921bdea commit 148aa55

File tree

1 file changed

+114
-60
lines changed

1 file changed

+114
-60
lines changed

llm-service/app/tests/provider_mocks/bedrock.py

Lines changed: 114 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -35,35 +35,49 @@
3535
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
3636
# DATA.
3737
#
38+
import itertools
3839
from typing import Generator
3940
from unittest.mock import patch
4041
from urllib.parse import urljoin
4142

4243
import botocore
4344
import pytest
4445
import responses
46+
from llama_index.llms.bedrock_converse.utils import BEDROCK_MODELS
4547

4648
from app.config import settings
49+
from app.services.caii.types import ModelResponse
50+
from app.services.models import ModelProvider
51+
from app.services.models.providers import BedrockModelProvider
52+
53+
TEXT_MODELS = [
54+
("test.unavailable-text-model-v1", "NOT_AVAILABLE"),
55+
("test.available-text-model-v1", "AVAILABLE"),
56+
]
57+
EMBEDDING_MODELS = [
58+
("test.unavailable-embedding-model-v1", "NOT_AVAILABLE"),
59+
("test.available-embedding-model-v1", "AVAILABLE"),
60+
]
61+
RERANKING_MODELS = [
62+
("test.available-reranking-model-v1", "AVAILABLE"),
63+
]
4764

4865

4966
@pytest.fixture
50-
def mock_bedrock() -> Generator[None, None, None]:
51-
BEDROCK_URL_BASE = f"https://bedrock.{settings.aws_default_region}.amazonaws.com/"
52-
TEXT_MODELS = [
53-
("test.unavailable-text-model-v1", "NOT_AVAILABLE"),
54-
("test.available-text-model-v1", "AVAILABLE"),
55-
]
56-
EMBEDDING_MODELS = [
57-
("test.unavailable-embedding-model-v1", "NOT_AVAILABLE"),
58-
("test.available-embedding-model-v1", "AVAILABLE"),
59-
]
67+
def mock_bedrock(monkeypatch) -> Generator[None, None, None]:
68+
for name in BedrockModelProvider.get_env_var_names():
69+
monkeypatch.setenv(name, "test")
70+
for name in get_all_env_var_names() - BedrockModelProvider.get_env_var_names():
71+
monkeypatch.delenv(name, raising=False)
6072

73+
# mock calls made directly through `requests`
74+
bedrock_url_base = f"https://bedrock.{settings.aws_default_region}.amazonaws.com/"
6175
r_mock = responses.RequestsMock(assert_all_requests_are_fired=False)
6276
for model_id, availability in TEXT_MODELS + EMBEDDING_MODELS:
6377
r_mock.get(
6478
urljoin(
65-
BEDROCK_URL_BASE,
66-
f"foundation-model-availability/{model_id}:0",
79+
bedrock_url_base,
80+
f"foundation-model-availability/{model_id}",
6781
),
6882
json={
6983
"agreementAvailability": {
@@ -77,67 +91,54 @@ def mock_bedrock() -> Generator[None, None, None]:
7791
},
7892
)
7993

94+
# mock calls made through `boto3`
8095
make_api_call = botocore.client.BaseClient._make_api_call
8196

8297
def mock_make_api_call(self, operation_name: str, api_params: dict[str, str]):
83-
"""Mock Bedrock calls, since moto doesn't have full coverage.
98+
"""Mock Boto3 Bedrock operations, since moto doesn't have full coverage.
8499
85100
Based on https://docs.getmoto.org/en/latest/docs/services/patching_other_services.html.
86101
87102
"""
88103
if operation_name == "ListFoundationModels":
89104
modality = api_params["byOutputModality"]
90-
if modality == "TEXT":
91-
return {
92-
"modelSummaries": [
93-
{
94-
"modelArn": f"arn:aws:bedrock:{settings.aws_default_region}::foundation-model/{model_id}:0",
95-
"modelId": f"{model_id}:0",
96-
"modelName": model_id.upper(),
97-
"providerName": "Test",
98-
"inputModalities": ["TEXT"],
99-
"outputModalities": ["TEXT"],
100-
"responseStreamingSupported": True,
101-
"customizationsSupported": [],
102-
"inferenceTypesSupported": ["ON_DEMAND"],
103-
"modelLifecycle": {"status": "ACTIVE"},
104-
}
105-
for model_id, _ in TEXT_MODELS
106-
],
107-
}
108-
elif modality == "EMBEDDING":
109-
return {
110-
"modelSummaries": [
111-
{
112-
"modelArn": f"arn:aws:bedrock:{settings.aws_default_region}::foundation-model/{model_id}:0",
113-
"modelId": f"{model_id}:0",
114-
"modelName": model_id.upper(),
115-
"providerName": "Test",
116-
"inputModalities": ["TEXT"],
117-
"outputModalities": ["EMBEDDING"],
118-
"responseStreamingSupported": False,
119-
"customizationsSupported": [],
120-
"inferenceTypesSupported": ["ON_DEMAND"],
121-
"modelLifecycle": {"status": "ACTIVE"},
122-
}
123-
for model_id, _ in EMBEDDING_MODELS
124-
],
125-
}
126-
else:
105+
models = {
106+
"TEXT": TEXT_MODELS,
107+
"EMBEDDING": EMBEDDING_MODELS,
108+
}.get(modality)
109+
if models is None:
127110
raise ValueError(f"test encountered unexpected modality {modality}")
111+
112+
return {
113+
"modelSummaries": [
114+
{
115+
"modelArn": f"arn:aws:bedrock:{settings.aws_default_region}::foundation-model/{model_id}",
116+
"modelId": model_id,
117+
"modelName": model_id.upper(),
118+
"providerName": "Test",
119+
"inputModalities": ["TEXT"],
120+
"outputModalities": [modality],
121+
"responseStreamingSupported": modality == "TEXT", # arbitrary
122+
"customizationsSupported": [],
123+
"inferenceTypesSupported": ["ON_DEMAND"],
124+
"modelLifecycle": {"status": "ACTIVE"},
125+
}
126+
for model_id, _ in models
127+
],
128+
}
128129
elif operation_name == "ListInferenceProfiles":
129130
return {
130131
"inferenceProfileSummaries": [
131132
{
132133
"inferenceProfileName": f"US {model_id.upper()}",
133134
"description": f"Routes requests to {model_id.upper()} in {settings.aws_default_region}.",
134-
"inferenceProfileArn": f"arn:aws:bedrock:{settings.aws_default_region}:123456789012:inference-profile/{model_id}:0",
135+
"inferenceProfileArn": f"arn:aws:bedrock:{settings.aws_default_region}:123456789012:inference-profile/{model_id}",
135136
"models": [
136137
{
137-
"modelArn": f"arn:aws:bedrock:{settings.aws_default_region}::foundation-model/{model_id}:0"
138+
"modelArn": f"arn:aws:bedrock:{settings.aws_default_region}::foundation-model/{model_id}"
138139
},
139140
],
140-
"inferenceProfileId": f"{model_id}:0",
141+
"inferenceProfileId": model_id,
141142
"status": "ACTIVE",
142143
"type": "SYSTEM_DEFINED",
143144
}
@@ -149,13 +150,66 @@ def mock_make_api_call(self, operation_name: str, api_params: dict[str, str]):
149150
# passthrough
150151
return make_api_call(self, operation_name, api_params)
151152

152-
with patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call):
153-
with r_mock:
154-
yield
153+
# mock reranking models, which are hard-coded in our app
154+
def list_reranking_models() -> list[ModelResponse]:
155+
return [
156+
ModelResponse(model_id=model_id, name=model_id.upper())
157+
for model_id, _ in RERANKING_MODELS
158+
]
155159

160+
with (
161+
r_mock,
162+
patch(
163+
"botocore.client.BaseClient._make_api_call",
164+
new=mock_make_api_call,
165+
),
166+
patch(
167+
"app.services.models.providers.BedrockModelProvider.list_reranking_models",
168+
new=list_reranking_models,
169+
),
170+
patch( # work around a llama-index filter we have in list_llm_models()
171+
"app.services.models.providers.bedrock.BEDROCK_MODELS",
172+
new=BEDROCK_MODELS | {model_id: 128000 for model_id, _ in TEXT_MODELS},
173+
),
174+
):
175+
yield
156176

157-
def test_bedrock(mock_bedrock) -> None:
158-
from app.services.models.providers import BedrockModelProvider
159177

160-
BedrockModelProvider.list_available_models()
161-
BedrockModelProvider._get_model_arns()
178+
def get_all_env_var_names() -> set[str]:
    """Collect the env var names declared by the model providers.

    Only the direct subclasses of ``ModelProvider`` are consulted; each one
    contributes whatever its ``get_env_var_names()`` yields.

    Returns:
        The union of those names, de-duplicated.
    """
    names: set[str] = set()
    for provider_cls in ModelProvider.__subclasses__():
        names.update(provider_cls.get_env_var_names())
    return names
185+
186+
187+
# TODO: move this test function to a discoverable place
188+
def test_bedrock(mock_bedrock, client) -> None:
    """Check the model endpoints against the mocked Bedrock provider.

    Verifies that the reported model source is Bedrock, and that each
    listing endpoint returns exactly the mock models marked ``AVAILABLE``,
    in declaration order.
    """
    source_response = client.get("/llm-service/models/model_source")
    assert source_response.status_code == 200
    assert source_response.json() == "Bedrock"

    # Each listing endpoint is paired with the mock catalogue it should reflect.
    listing_cases = (
        ("/llm-service/models/embeddings", EMBEDDING_MODELS),
        ("/llm-service/models/llm", TEXT_MODELS),
        ("/llm-service/models/reranking", RERANKING_MODELS),
    )
    for endpoint, catalogue in listing_cases:
        expected_ids = [
            model_id
            for model_id, availability in catalogue
            if availability == "AVAILABLE"
        ]

        listing_response = client.get(endpoint)
        assert listing_response.status_code == 200

        returned_ids = [model["model_id"] for model in listing_response.json()]
        assert returned_ids == expected_ids

0 commit comments

Comments
 (0)