Skip to content

Commit f641cb4

Browse files
committed
Move get_all_env_var_names() to _model_provider
1 parent 148aa55 commit f641cb4

File tree

3 files changed

+15
-19
lines changed

3 files changed

+15
-19
lines changed

llm-service/app/services/models/providers/_model_provider.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
# DATA.
3737
#
3838
import abc
39+
import itertools
3940
import os
4041

4142
from llama_index.core.base.embeddings.base import BaseEmbedding
@@ -128,3 +129,12 @@ def get_reranking_model(name: str, top_n: int) -> BaseNodePostprocessor:
128129
@abc.abstractmethod
129130
def get_model_source() -> ModelSource:
130131
raise NotImplementedError
132+
133+
134+
def get_all_env_var_names() -> set[str]:
135+
"""Return the names of all the env vars required by all model providers."""
136+
return set(
137+
itertools.chain.from_iterable(
138+
subcls.get_env_var_names() for subcls in ModelProvider.__subclasses__()
139+
)
140+
)

llm-service/app/tests/provider_mocks/bedrock.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from app.services.caii.types import ModelResponse
5050
from app.services.models import ModelProvider
5151
from app.services.models.providers import BedrockModelProvider
52+
from app.services.models.providers._model_provider import get_all_env_var_names
5253

5354
TEXT_MODELS = [
5455
("test.unavailable-text-model-v1", "NOT_AVAILABLE"),
@@ -175,15 +176,6 @@ def list_reranking_models() -> list[ModelResponse]:
175176
yield
176177

177178

178-
def get_all_env_var_names() -> set[str]:
179-
"""Return the names of all the env vars required by all model providers."""
180-
return set(
181-
itertools.chain.from_iterable(
182-
subcls.get_env_var_names() for subcls in ModelProvider.__subclasses__()
183-
)
184-
)
185-
186-
187179
# TODO: move this test function to a discoverable place
188180
def test_bedrock(mock_bedrock, client) -> None:
189181
response = client.get("/llm-service/models/model_source")

llm-service/app/tests/services/test_models.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,10 @@
4343
from app.services.caii import caii
4444
from app.services.caii.types import ListEndpointEntry
4545
from app.services.models.providers import BedrockModelProvider
46-
from app.services.models.providers._model_provider import ModelProvider
47-
48-
49-
def get_all_env_var_names() -> set[str]:
50-
"""Return the names of all the env vars required by all model providers."""
51-
return set(
52-
itertools.chain.from_iterable(
53-
subcls.get_env_var_names() for subcls in ModelProvider.__subclasses__()
54-
)
55-
)
46+
from app.services.models.providers._model_provider import (
47+
ModelProvider,
48+
get_all_env_var_names,
49+
)
5650

5751

5852
@pytest.fixture(params=ModelProvider.__subclasses__())

0 commit comments

Comments
 (0)