-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Add gateway/...:... to Known Model Names
#3593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
293b3f8
00d05bb
90680fc
e3b090a
4af4513
9d213ce
6e05859
dc3fd1c
0359386
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |
| from typing_extensions import TypedDict | ||
|
|
||
| from pydantic_ai.models import KnownModelName | ||
| from pydantic_ai.providers.gateway import ModelProviders | ||
|
|
||
| from ..conftest import try_import | ||
|
|
||
|
|
@@ -49,6 +50,22 @@ def vcr_config(): # pragma: lax no cover | |
| } | ||
|
|
||
|
|
||
| _PROVIDER_TO_MODEL_NAMES = { | ||
| 'anthropic': AnthropicModelName, | ||
| 'bedrock': BedrockModelName, | ||
| 'cohere': CohereModelName, | ||
| 'deepseek': Literal['deepseek-chat', 'deepseek-reasoner'], | ||
|
||
| 'google-gla': GoogleModelName, | ||
| 'google-vertex': GoogleModelName, | ||
| 'grok': GrokModelName, | ||
| 'groq': GroqModelName, | ||
| 'huggingface': HuggingFaceModelName, | ||
| 'mistral': MistralModelName, | ||
| 'moonshotai': MoonshotAIModelName, | ||
| 'openai': OpenAIModelName, | ||
| } | ||
|
|
||
|
|
||
| def test_known_model_names(): # pragma: lax no cover | ||
| # Coverage seems to be misbehaving..? | ||
| def get_model_names(model_name_type: Any) -> Iterator[str]: | ||
|
|
@@ -58,39 +75,23 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: | |
| else: | ||
| yield from get_model_names(arg) | ||
|
|
||
| anthropic_names = [f'anthropic:{n}' for n in get_model_names(AnthropicModelName)] | ||
| cohere_names = [f'cohere:{n}' for n in get_model_names(CohereModelName)] | ||
| google_names = [f'google-gla:{n}' for n in get_model_names(GoogleModelName)] + [ | ||
| f'google-vertex:{n}' for n in get_model_names(GoogleModelName) | ||
| all_generated_names = [ | ||
| f'{provider}:{n}' | ||
| for provider, model_names in _PROVIDER_TO_MODEL_NAMES.items() | ||
| for n in get_model_names(model_names) | ||
| ] | ||
| grok_names = [f'grok:{n}' for n in get_model_names(GrokModelName)] | ||
| groq_names = [f'groq:{n}' for n in get_model_names(GroqModelName)] | ||
| moonshotai_names = [f'moonshotai:{n}' for n in get_model_names(MoonshotAIModelName)] | ||
| mistral_names = [f'mistral:{n}' for n in get_model_names(MistralModelName)] | ||
| openai_names = [f'openai:{n}' for n in get_model_names(OpenAIModelName)] | ||
| bedrock_names = [f'bedrock:{n}' for n in get_model_names(BedrockModelName)] | ||
| deepseek_names = ['deepseek:deepseek-chat', 'deepseek:deepseek-reasoner'] | ||
| huggingface_names = [f'huggingface:{n}' for n in get_model_names(HuggingFaceModelName)] | ||
| heroku_names = get_heroku_model_names() | ||
|
|
||
| cerebras_names = get_cerebras_model_names() | ||
| heroku_names = get_heroku_model_names() | ||
| gateway_names = [ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I realize you couldn't see this comment in a private Slack channel, but I responded to Samuel (and he agreed):
So we should NOT hard-code this, but dynamically build it from the known model names of the providers that are known to work with the gateway.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is now done. |
||
| f'gateway/{provider}:{model_name}' | ||
| for provider in ModelProviders.__args__ | ||
| for model_name in get_model_names(_PROVIDER_TO_MODEL_NAMES[provider]) | ||
| ] | ||
|
|
||
| extra_names = ['test'] | ||
|
|
||
| generated_names = sorted( | ||
| anthropic_names | ||
| + cohere_names | ||
| + google_names | ||
| + grok_names | ||
| + groq_names | ||
| + mistral_names | ||
| + moonshotai_names | ||
| + openai_names | ||
| + bedrock_names | ||
| + deepseek_names | ||
| + huggingface_names | ||
| + heroku_names | ||
| + cerebras_names | ||
| + extra_names | ||
| ) | ||
| generated_names = sorted(all_generated_names + gateway_names + heroku_names + cerebras_names + extra_names) | ||
|
|
||
| known_model_names = sorted(get_args(KnownModelName.__value__)) | ||
| assert generated_names == known_model_names | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a type, so singular :)