diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 06258bc6ba..c8c59fceac 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -9,7 +9,7 @@ import base64 import warnings from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Iterator +from collections.abc import AsyncIterator, Callable, Iterator from contextlib import asynccontextmanager, contextmanager from dataclasses import dataclass, field, replace from datetime import datetime @@ -47,7 +47,7 @@ ) from ..output import OutputMode from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec -from ..providers import infer_provider +from ..providers import Provider, infer_provider from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from ..usage import RequestUsage @@ -724,8 +724,17 @@ def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]: ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition] -def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 - """Infer the model from the name.""" +def infer_model( # noqa: C901 + model: Model | KnownModelName | str, provider_factory: Callable[[str], Provider[Any]] = infer_provider +) -> Model: + """Infer the model from the name. + + Args: + model: + Model name to instantiate, in the format of `provider:model`. Use the string "test" to instantiate TestModel. + provider_factory: + Function that instantiates a provider object. The provider name is passed into the function parameter. Defaults to `provider.infer_provider`. + """ if isinstance(model, Model): return model elif model == 'test': @@ -760,7 +769,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 ) provider_name = 'google-vertex' - provider = infer_provider(provider_name) + provider: Provider[Any] = provider_factory(provider_name) model_kind = provider_name if model_kind.startswith('gateway/'): diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 20a50699d8..9e8879de7f 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -242,6 +242,17 @@ def test_infer_model( assert m2 is m +def test_infer_model_with_provider(): + from pydantic_ai.providers import openai + + provider_class = openai.OpenAIProvider(api_key='1234', base_url='http://test') + m = infer_model('openai:gpt-5', lambda x: provider_class) + + assert isinstance(m, OpenAIChatModel) + assert m._provider is provider_class # type: ignore + assert m._provider.base_url == 'http://test' # type: ignore + + def test_infer_str_unknown(): with pytest.raises(UserError, match='Unknown model: foobar'): infer_model('foobar')