Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import base64
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator
from collections.abc import AsyncIterator, Callable, Iterator
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass, field, replace
from datetime import datetime
Expand Down Expand Up @@ -47,7 +47,7 @@
)
from ..output import OutputMode
from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
from ..providers import infer_provider
from ..providers import Provider, infer_provider
from ..settings import ModelSettings, merge_model_settings
from ..tools import ToolDefinition
from ..usage import RequestUsage
Expand Down Expand Up @@ -677,8 +677,17 @@ def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]:
ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition]


def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
"""Infer the model from the name."""
def infer_model( # noqa: C901
model: Model | KnownModelName | str, provider_factory: Callable[[str], Provider[Any]] = infer_provider
) -> Model:
"""Infer the model from the name.

Args:
model:
Model name to instantiate, in the format of `provider:model`. Use the string "test" to instantiate TestModel.
provider_factory:
Function that instantiates a provider object. The provider name is passed into the function parameter. Defaults to `provider.infer_provider`.
"""
if isinstance(model, Model):
return model
elif model == 'test':
Expand Down Expand Up @@ -713,7 +722,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901
)
provider_name = 'google-vertex'

provider = infer_provider(provider_name)
provider: Provider[Any] = provider_factory(provider_name)

model_kind = provider_name
if model_kind.startswith('gateway/'):
Expand Down
11 changes: 11 additions & 0 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,17 @@ def test_infer_model(
assert m2 is m


def test_infer_model_with_provider():
from pydantic_ai.providers import openai

provider_class = openai.OpenAIProvider(api_key='1234', base_url='http://test')
m = infer_model('openai:gpt-5', lambda x: provider_class)

assert isinstance(m, OpenAIChatModel)
assert m._provider is provider_class # type: ignore
assert m._provider.base_url == 'http://test' # type: ignore


def test_infer_str_unknown():
with pytest.raises(UserError, match='Unknown model: foobar'):
infer_model('foobar')
Loading