From fc05106819c256b4c1db06f838f89d369ffeb724 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Sun, 19 Oct 2025 18:01:13 -0500 Subject: [PATCH 1/6] Add Cloudflare AI Gateway provider Implements CloudflareProvider for routing requests through Cloudflare's unified AI Gateway API with support for: - Multiple AI providers (OpenAI, Anthropic, Groq, Mistral, Cohere, etc.) - BYOK (bring your own key) mode - Stored keys mode (API keys managed in Cloudflare dashboard) - Authenticated gateways with cf-aig-authorization header - Intelligent model profiling for Groq and Cerebras models Generated with Claude Code https://claude.com/claude-code Co-Authored-By: Claude --- docs/api/providers.md | 2 + docs/models/openai.md | 48 +++ docs/models/overview.md | 1 + .../pydantic_ai/providers/__init__.py | 4 + .../pydantic_ai/providers/cloudflare.py | 334 ++++++++++++++++ tests/providers/test_cloudflare.py | 363 ++++++++++++++++++ tests/providers/test_provider_names.py | 2 + 7 files changed, 754 insertions(+) create mode 100644 pydantic_ai_slim/pydantic_ai/providers/cloudflare.py create mode 100644 tests/providers/test_cloudflare.py diff --git a/docs/api/providers.md b/docs/api/providers.md index 68c124ce67..3e60b7fb62 100644 --- a/docs/api/providers.md +++ b/docs/api/providers.md @@ -34,6 +34,8 @@ ::: pydantic_ai.providers.vercel.VercelProvider +::: pydantic_ai.providers.cloudflare.CloudflareProvider + ::: pydantic_ai.providers.huggingface.HuggingFaceProvider ::: pydantic_ai.providers.moonshotai.MoonshotAIProvider diff --git a/docs/models/openai.md b/docs/models/openai.md index e28dc8374e..6235803b85 100644 --- a/docs/models/openai.md +++ b/docs/models/openai.md @@ -441,6 +441,54 @@ agent = Agent(model) ... ``` +### Cloudflare AI Gateway + +To use [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/), first set up a gateway in your [Cloudflare dashboard](https://dash.cloudflare.com/?to=/:account/ai/ai-gateway) and obtain your account ID and gateway ID. + +!!! note + This provider uses Cloudflare's [unified API endpoint](https://developers.cloudflare.com/ai-gateway/usage/chat-completion/) for routing requests to multiple AI providers. For the full list of supported providers, see [Cloudflare's documentation](https://developers.cloudflare.com/ai-gateway/usage/chat-completion/#supported-providers). + +You can set the `CLOUDFLARE_ACCOUNT_ID`, `CLOUDFLARE_GATEWAY_ID`, and optionally `CLOUDFLARE_AI_GATEWAY_AUTH` environment variables and use [`CloudflareProvider`][pydantic_ai.providers.cloudflare.CloudflareProvider]: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.providers.cloudflare import CloudflareProvider + +model = OpenAIChatModel( + 'openai/gpt-4o', + provider=CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + api_key='your-openai-api-key', + ), +) +agent = Agent(model) +... +``` + +For authenticated gateways with stored API keys in Cloudflare's dashboard: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.providers.cloudflare import CloudflareProvider + +model = OpenAIChatModel( + 'anthropic/claude-3-5-sonnet', + provider=CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + cf_aig_authorization='your-gateway-token', + use_gateway_keys=True, + ), +) +agent = Agent(model) +... +``` + +See [`CloudflareProvider`][pydantic_ai.providers.cloudflare.CloudflareProvider] for additional configuration options including BYOK modes and authenticated gateways. + ### Grok (xAI) Go to [xAI API Console](https://console.x.ai/) and create an API key. diff --git a/docs/models/overview.md b/docs/models/overview.md index 45af29c862..46e40f70fc 100644 --- a/docs/models/overview.md +++ b/docs/models/overview.md @@ -20,6 +20,7 @@ In addition, many providers are compatible with the OpenAI API, and can be used - [Ollama](openai.md#ollama) - [OpenRouter](openai.md#openrouter) - [Vercel AI Gateway](openai.md#vercel-ai-gateway) +- [Cloudflare AI Gateway](openai.md#cloudflare-ai-gateway) - [Perplexity](openai.md#perplexity) - [Fireworks AI](openai.md#fireworks-ai) - [Together AI](openai.md#together-ai) diff --git a/pydantic_ai_slim/pydantic_ai/providers/__init__.py b/pydantic_ai_slim/pydantic_ai/providers/__init__.py index f71f2d94e0..6d6b78904a 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/providers/__init__.py @@ -69,6 +69,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901 from .vercel import VercelProvider return VercelProvider + elif provider == 'cloudflare': + from .cloudflare import CloudflareProvider + + return CloudflareProvider elif provider == 'azure': from .azure import AzureProvider diff --git a/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py b/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py new file mode 100644 index 0000000000..475a10fa51 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py @@ -0,0 +1,334 @@ +from __future__ import annotations as _annotations + +import os +from typing import overload + +import httpx + +from pydantic_ai import ModelProfile +from pydantic_ai.exceptions import UserError +from pydantic_ai.models import cached_async_http_client +from pydantic_ai.profiles.anthropic import anthropic_model_profile +from pydantic_ai.profiles.cohere import cohere_model_profile +from pydantic_ai.profiles.deepseek import deepseek_model_profile +from pydantic_ai.profiles.google import google_model_profile +from pydantic_ai.profiles.grok import grok_model_profile +from pydantic_ai.profiles.mistral import mistral_model_profile +from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile +from pydantic_ai.providers import Provider + +from .cerebras import CerebrasProvider +from .groq import GroqProvider + +try: + from openai import AsyncOpenAI +except ImportError as _import_error: # pragma: no cover + raise ImportError( + 'Please install the `openai` package to use the Cloudflare provider, ' + 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`' + ) from _import_error + + +def _groq_model_profile_cloudflare(model_name: str) -> ModelProfile | None: + """Get the model profile for Groq models routed through Cloudflare's unified API. + + Cloudflare routes to Groq's OpenAI-compatible endpoint, so we use prefix matching + similar to the native GroqProvider to determine the appropriate profile. + """ + return GroqProvider().model_profile(model_name) + + +def _cerebras_model_profile_cloudflare(model_name: str) -> ModelProfile | None: + """Get the model profile for Cerebras models routed through Cloudflare's unified API. + + Similar to the native CerebrasProvider, use prefix matching to determine profiles. + """ + return CerebrasProvider().model_profile(model_name) + + +class CloudflareProvider(Provider[AsyncOpenAI]): + """Provider for Cloudflare AI Gateway API. + + Cloudflare AI Gateway provides a unified OpenAI-compatible endpoint that routes + requests to various AI providers while adding features like caching, rate limiting, + analytics, and logging. + + !!! note + This provider uses Cloudflare's unified API endpoint for routing requests. + For the full list of supported providers, see + [Cloudflare's documentation](https://developers.cloudflare.com/ai-gateway/usage/chat-completion/#supported-providers). + + This provider looks for these environment variables if they are not provided as parameters: + - account_id: `CLOUDFLARE_ACCOUNT_ID` + - gateway_id: `CLOUDFLARE_GATEWAY_ID` + - cf_aig_authorization: `CLOUDFLARE_AI_GATEWAY_AUTH` (optional) + + There are three usage modes: + + 1. BYOK with unauthenticated gateway (bring your own API key): + ```python + from pydantic_ai import Agent + from pydantic_ai.providers.cloudflare import CloudflareProvider + + provider = CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + api_key='your-openai-api-key' # Your own provider API key + ) + agent = Agent('openai/gpt-4o', provider=provider) + ``` + + 2. BYOK with authenticated gateway (API key + gateway authentication): + ```python + provider = CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + api_key='your-openai-api-key', + cf_aig_authorization='your-gateway-token' + ) + agent = Agent('anthropic/claude-3-5-sonnet', provider=provider) + ``` + + 3. Stored keys mode (use API keys stored in Cloudflare dashboard): + ```python + # Requires authenticated gateway - API keys are stored in your Cloudflare dashboard + # Set use_gateway_keys=True and provide cf_aig_authorization + provider = CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + cf_aig_authorization='your-gateway-token', + use_gateway_keys=True + ) + agent = Agent('openai/gpt-4o', provider=provider) + ``` + """ + + @property + def name(self) -> str: + return 'cloudflare' + + @property + def base_url(self) -> str: + return self._base_url + + @property + def client(self) -> AsyncOpenAI: + return self._client + + def model_profile(self, model_name: str) -> ModelProfile | None: + """Return the model profile for the given model name. + + Model names should be in the format 'provider/model', e.g., 'openai/gpt-4o', + 'anthropic/claude-3-5-sonnet', 'groq/llama-3.3-70b-versatile'. + + For the full list of supported providers, see + [Cloudflare's documentation](https://developers.cloudflare.com/ai-gateway/usage/chat-completion/#supported-providers). + """ + provider_to_profile = { + 'anthropic': anthropic_model_profile, + 'openai': openai_model_profile, + 'groq': _groq_model_profile_cloudflare, + 'mistral': mistral_model_profile, + 'cohere': cohere_model_profile, + 'deepseek': deepseek_model_profile, + # NOTE: this would be the first support for perplexity in pydantic-ai + # to remove this, the equivalent test in tests/providers/test_cloudflare.py::test_cloudflare_provider_model_profile would need to be removed + 'perplexity': openai_model_profile, # Perplexity uses OpenAI-compatible API + 'workers-ai': openai_model_profile, # Cloudflare Workers AI uses OpenAI-compatible API + 'workersai': openai_model_profile, # Alternative naming + 'google-ai-studio': google_model_profile, + 'grok': grok_model_profile, + 'xai': grok_model_profile, # xai is an alias for grok + 'cerebras': _cerebras_model_profile_cloudflare, + } + + profile = None + + try: + provider, model_name = model_name.split('/', 1) + except ValueError: + raise UserError(f"Model name must be in 'provider/model' format, got: {model_name!r}") + + if provider in provider_to_profile: + profile = provider_to_profile[provider](model_name) + # If provider is not recognized, profile remains None and we fall back to OpenAI-compatible behavior. + # This matches VercelProvider's behavior of silently supporting unknown providers through the unified API. + + # As CloudflareProvider is always used with OpenAIChatModel, which used to unconditionally use OpenAIJsonSchemaTransformer, + # we need to maintain that behavior unless json_schema_transformer is set explicitly + return OpenAIModelProfile( + json_schema_transformer=OpenAIJsonSchemaTransformer, + ).update(profile) + + @staticmethod + def _create_stored_keys_client(base_client: httpx.AsyncClient) -> httpx.AsyncClient: + """Create an HTTP client that strips the Authorization header for stored keys mode. + + When using Cloudflare's stored keys feature (API keys stored in the dashboard), + the Authorization header must NOT be sent. If sent, it takes precedence over + the stored keys, breaking the feature. + + This wraps the base client with an event hook that removes the Authorization header. + """ + + async def strip_auth_header(request: httpx.Request) -> None: + """Remove Authorization header so Cloudflare uses stored keys from dashboard.""" + if 'authorization' in request.headers: + del request.headers['authorization'] + + # Merge event hooks - preserve any existing hooks and add our strip_auth_header hook + existing_hooks = base_client.event_hooks + new_request_hooks = list(existing_hooks.get('request', [])) + new_request_hooks.append(strip_auth_header) + + merged_hooks = dict(existing_hooks) + merged_hooks['request'] = new_request_hooks + + # Create new client based on the base client's configuration + base_client._event_hooks = merged_hooks # type: ignore[attr-defined] + return base_client + + # Scenario 1: BYOK with unauthenticated gateway (api_key required) + @overload + def __init__(self, *, account_id: str, gateway_id: str, api_key: str) -> None: ... + + @overload + def __init__(self, *, account_id: str, gateway_id: str, api_key: str, http_client: httpx.AsyncClient) -> None: ... + + # Scenario 2: BYOK with authenticated gateway (api_key + cf_aig_authorization) + @overload + def __init__(self, *, account_id: str, gateway_id: str, api_key: str, cf_aig_authorization: str) -> None: ... + + @overload + def __init__( + self, + *, + account_id: str, + gateway_id: str, + api_key: str, + cf_aig_authorization: str, + http_client: httpx.AsyncClient, + ) -> None: ... + + # Scenario 3: Stored keys with authenticated gateway (use_gateway_keys=True, cf_aig_authorization required) + @overload + def __init__( + self, *, account_id: str, gateway_id: str, cf_aig_authorization: str, use_gateway_keys: bool = True + ) -> None: ... + + @overload + def __init__( + self, + *, + account_id: str, + gateway_id: str, + cf_aig_authorization: str, + http_client: httpx.AsyncClient, + use_gateway_keys: bool = True, + ) -> None: ... + + # Advanced: Pre-configured OpenAI client + @overload + def __init__(self, *, account_id: str, gateway_id: str, openai_client: AsyncOpenAI) -> None: ... + + def __init__( + self, + *, + account_id: str | None = None, + gateway_id: str | None = None, + api_key: str | None = None, + cf_aig_authorization: str | None = None, + openai_client: AsyncOpenAI | None = None, + http_client: httpx.AsyncClient | None = None, + use_gateway_keys: bool = False, + ) -> None: + """Initialize the Cloudflare AI Gateway provider. + + Args: + account_id: Your Cloudflare account ID. Can also be set via CLOUDFLARE_ACCOUNT_ID environment variable. + gateway_id: Your Cloudflare AI Gateway ID. Can also be set via CLOUDFLARE_GATEWAY_ID environment variable. + api_key: The API key for the upstream provider (OpenAI, Anthropic, etc.). + - Required for BYOK (bring your own key) mode (default) + - Do NOT provide when use_gateway_keys=True (conflicts with stored keys mode) + - Optional when using the openai_client parameter (pre-configured client) + cf_aig_authorization: Authorization token for authenticated gateways. + - Required when use_gateway_keys=True (stored keys mode) + - Optional for BYOK mode (provides additional gateway authentication) + - Can also be set via CLOUDFLARE_AI_GATEWAY_AUTH environment variable + openai_client: Optional pre-configured AsyncOpenAI client for advanced use cases. + http_client: Optional HTTP client to use for requests. + use_gateway_keys: Whether to use API keys stored in your Cloudflare dashboard (default: False). + - Set to True to use stored keys mode (requires cf_aig_authorization) + - When True, do not provide api_key (they are mutually exclusive) + - When False (default), you must provide api_key for BYOK mode + + Raises: + UserError: If use_gateway_keys=True and api_key is also provided (conflicting configuration). + """ + account_id = account_id or os.getenv('CLOUDFLARE_ACCOUNT_ID') + gateway_id = gateway_id or os.getenv('CLOUDFLARE_GATEWAY_ID') + + if not account_id: + raise UserError( + 'Set the `CLOUDFLARE_ACCOUNT_ID` environment variable ' + 'or pass it via `CloudflareProvider(account_id=...)` to use the Cloudflare provider.' + ) + + if not gateway_id: + raise UserError( + 'Set the `CLOUDFLARE_GATEWAY_ID` environment variable ' + 'or pass it via `CloudflareProvider(gateway_id=...)` to use the Cloudflare provider.' + ) + + cf_aig_authorization = cf_aig_authorization or os.getenv('CLOUDFLARE_AI_GATEWAY_AUTH') + + if use_gateway_keys: + # Stored keys mode requires authenticated gateway + if cf_aig_authorization is None: + raise UserError( + 'When use_gateway_keys=True, you must provide cf_aig_authorization.\n' + 'Stored keys (API keys stored in Cloudflare dashboard) require an authenticated gateway.' + ) + # Can't use both stored keys and provide your own api_key + if api_key is not None: + raise UserError( + 'When use_gateway_keys=True, do not provide an api_key.\n' + 'use_gateway_keys=True means using API keys stored in your Cloudflare dashboard,\n' + 'which is incompatible with providing your own api_key (BYOK mode).' + ) + # Use placeholder for AsyncOpenAI (required by AsyncOpenAI client) but we'll strip the Authorization header + api_key = 'stored-keys-placeholder' + elif api_key is None and openai_client is None: + # Not using stored keys, so api_key is required (unless using pre-configured openai_client) + raise UserError( + 'When use_gateway_keys=False (the default), you must provide an api_key for BYOK mode.\n' + 'To use API keys stored in your Cloudflare dashboard, set use_gateway_keys=True and provide cf_aig_authorization.' + ) + + self._base_url = f'https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/compat' + + default_headers = { + 'http-referer': 'https://ai.pydantic.dev/', + 'x-title': 'pydantic-ai', + } + + if cf_aig_authorization: + default_headers['cf-aig-authorization'] = cf_aig_authorization + + if openai_client is not None: + self._client = openai_client + elif http_client is not None: + # If user provided http_client and we're in stored keys mode, we need to wrap it + if use_gateway_keys: + http_client = self._create_stored_keys_client(http_client) + self._client = AsyncOpenAI( + base_url=self._base_url, api_key=api_key, http_client=http_client, default_headers=default_headers + ) + else: + http_client = cached_async_http_client(provider='cloudflare') + # In stored keys mode, wrap the client to strip Authorization header + if use_gateway_keys: + http_client = self._create_stored_keys_client(http_client) + self._client = AsyncOpenAI( + base_url=self._base_url, api_key=api_key, http_client=http_client, default_headers=default_headers + ) diff --git a/tests/providers/test_cloudflare.py b/tests/providers/test_cloudflare.py new file mode 100644 index 0000000000..58f007e924 --- /dev/null +++ b/tests/providers/test_cloudflare.py @@ -0,0 +1,363 @@ +import re + +import httpx +import pytest +from pytest_mock import MockerFixture + +from pydantic_ai import Agent +from pydantic_ai.exceptions import UserError +from pydantic_ai.profiles.anthropic import anthropic_model_profile +from pydantic_ai.profiles.cohere import cohere_model_profile +from pydantic_ai.profiles.deepseek import deepseek_model_profile +from pydantic_ai.profiles.google import GoogleJsonSchemaTransformer, google_model_profile +from pydantic_ai.profiles.grok import grok_model_profile +from pydantic_ai.profiles.meta import meta_model_profile +from pydantic_ai.profiles.mistral import mistral_model_profile +from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, openai_model_profile + +from ..conftest import TestEnv, try_import + +with try_import() as imports_successful: + import openai + + from pydantic_ai.providers.cloudflare import CloudflareProvider + + +pytestmark = [ + pytest.mark.skipif(not imports_successful(), reason='openai not installed'), + pytest.mark.vcr, + pytest.mark.anyio, +] + + +def test_cloudflare_provider(): + provider = CloudflareProvider(account_id='test-account-id', gateway_id='test-gateway-id', api_key='api-key') + assert provider.name == 'cloudflare' + assert provider.base_url == 'https://gateway.ai.cloudflare.com/v1/test-account-id/test-gateway-id/compat' + assert isinstance(provider.client, openai.AsyncOpenAI) + assert provider.client.api_key == 'api-key' + + +def test_cloudflare_provider_need_account_id(env: TestEnv) -> None: + env.remove('CLOUDFLARE_ACCOUNT_ID') + with pytest.raises( + UserError, + match=re.escape( + 'Set the `CLOUDFLARE_ACCOUNT_ID` environment variable ' + 'or pass it via `CloudflareProvider(account_id=...)` to use the Cloudflare provider.' + ), + ): + CloudflareProvider(gateway_id='test-gateway-id', api_key='api-key') # type: ignore[call-overload] + + +def test_cloudflare_provider_need_gateway_id(env: TestEnv) -> None: + env.remove('CLOUDFLARE_GATEWAY_ID') + with pytest.raises( + UserError, + match=re.escape( + 'Set the `CLOUDFLARE_GATEWAY_ID` environment variable ' + 'or pass it via `CloudflareProvider(gateway_id=...)` to use the Cloudflare provider.' + ), + ): + CloudflareProvider(account_id='test-account-id', api_key='api-key') # type: ignore[call-overload] + + +def test_cloudflare_provider_from_env(env: TestEnv) -> None: + env.set('CLOUDFLARE_ACCOUNT_ID', 'env-account-id') + env.set('CLOUDFLARE_GATEWAY_ID', 'env-gateway-id') + + # Test with explicit api_key (account_id and gateway_id from env) + provider = CloudflareProvider(api_key='env-api-key') # type: ignore[call-overload] + assert provider.base_url == 'https://gateway.ai.cloudflare.com/v1/env-account-id/env-gateway-id/compat' + assert provider.client.api_key == 'env-api-key' + + +def test_cloudflare_provider_with_cf_aig_authorization(): + provider = CloudflareProvider( + account_id='test-account-id', + gateway_id='test-gateway-id', + api_key='api-key', + cf_aig_authorization='gateway-token', + ) + assert provider.client.default_headers['cf-aig-authorization'] == 'gateway-token' + + +def test_cloudflare_provider_cf_aig_authorization_from_env(env: TestEnv) -> None: + env.set('CLOUDFLARE_ACCOUNT_ID', 'test-account-id') + env.set('CLOUDFLARE_GATEWAY_ID', 'test-gateway-id') + env.set('CLOUDFLARE_AI_GATEWAY_AUTH', 'env-gateway-token') + + provider = CloudflareProvider(api_key='api-key') # type: ignore[call-overload] + assert provider.client.default_headers['cf-aig-authorization'] == 'env-gateway-token' + + +def test_cloudflare_pass_openai_client() -> None: + openai_client = openai.AsyncOpenAI(api_key='api-key') + provider = CloudflareProvider( + account_id='test-account-id', gateway_id='test-gateway-id', openai_client=openai_client + ) + assert provider.client == openai_client + + +def test_cloudflare_provider_model_profile(mocker: MockerFixture): + provider = CloudflareProvider(account_id='test-account-id', gateway_id='test-gateway-id', api_key='api-key') + + ns = 'pydantic_ai.providers.cloudflare' + + # Mock all profile functions + anthropic_mock = mocker.patch(f'{ns}.anthropic_model_profile', wraps=anthropic_model_profile) + cohere_mock = mocker.patch(f'{ns}.cohere_model_profile', wraps=cohere_model_profile) + deepseek_mock = mocker.patch(f'{ns}.deepseek_model_profile', wraps=deepseek_model_profile) + google_mock = mocker.patch(f'{ns}.google_model_profile', wraps=google_model_profile) + grok_mock = mocker.patch(f'{ns}.grok_model_profile', wraps=grok_model_profile) + mistral_mock = mocker.patch(f'{ns}.mistral_model_profile', wraps=mistral_model_profile) + openai_mock = mocker.patch(f'{ns}.openai_model_profile', wraps=openai_model_profile) + + # Mock GroqProvider and CerebrasProvider since they handle complex model profiling internally + groq_provider_instance = mocker.MagicMock() + cerebras_provider_instance = mocker.MagicMock() + + # Configure the mocks to return appropriate profiles + def groq_profile_func(name: str): + if name.lower().startswith('llama'): + return meta_model_profile(name) + elif name.lower().startswith('gemma'): + return google_model_profile(name) + else: + return openai_model_profile(name) + + def cerebras_profile_func(name: str): + if name.lower().startswith('llama'): + return meta_model_profile(name) + else: + return openai_model_profile(name) + + groq_provider_instance.model_profile.side_effect = groq_profile_func + cerebras_provider_instance.model_profile.side_effect = cerebras_profile_func + + mocker.patch(f'{ns}.GroqProvider', return_value=groq_provider_instance) + mocker.patch(f'{ns}.CerebrasProvider', return_value=cerebras_provider_instance) + + # Test openai provider + profile = provider.model_profile('openai/gpt-4o') + openai_mock.assert_called_with('gpt-4o') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test anthropic provider + profile = provider.model_profile('anthropic/claude-3-sonnet') + anthropic_mock.assert_called_with('claude-3-sonnet') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test cohere provider + profile = provider.model_profile('cohere/command-r-plus') + cohere_mock.assert_called_with('command-r-plus') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test deepseek provider + profile = provider.model_profile('deepseek/deepseek-chat') + deepseek_mock.assert_called_with('deepseek-chat') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test mistral provider + profile = provider.model_profile('mistral/mistral-large') + mistral_mock.assert_called_with('mistral-large') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test google-ai-studio provider + profile = provider.model_profile('google-ai-studio/gemini-1.5-pro') + google_mock.assert_called_with('gemini-1.5-pro') + assert profile is not None + assert profile.json_schema_transformer == GoogleJsonSchemaTransformer + + # Test grok provider + profile = provider.model_profile('grok/grok-beta') + grok_mock.assert_called_with('grok-beta') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test xai provider (alias for grok) + profile = provider.model_profile('xai/grok-2') + grok_mock.assert_called_with('grok-2') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test groq provider with llama model (should delegate to GroqProvider) + profile = provider.model_profile('groq/llama-3.3-70b-versatile') + groq_provider_instance.model_profile.assert_called_with('llama-3.3-70b-versatile') + assert profile is not None + + # Test groq provider with gemma model (should delegate to GroqProvider) + profile = provider.model_profile('groq/gemma-7b-it') + groq_provider_instance.model_profile.assert_called_with('gemma-7b-it') + assert profile is not None + + # Test perplexity provider (uses OpenAI-compatible API) + profile = provider.model_profile('perplexity/llama-3.1-sonar-small-128k-online') + openai_mock.assert_called_with('llama-3.1-sonar-small-128k-online') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test workers-ai provider (Cloudflare's own AI service) + profile = provider.model_profile('workers-ai/@cf/meta/llama-3.1-8b-instruct') + openai_mock.assert_called_with('@cf/meta/llama-3.1-8b-instruct') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + # Test cerebras provider with llama model (should delegate to CerebrasProvider) + profile = provider.model_profile('cerebras/llama3.1-8b') + cerebras_provider_instance.model_profile.assert_called_with('llama3.1-8b') + assert profile is not None + + +def test_cloudflare_with_http_client(): + http_client = httpx.AsyncClient() + provider = CloudflareProvider( + account_id='test-account-id', gateway_id='test-gateway-id', api_key='test-key', http_client=http_client + ) + assert provider.client.api_key == 'test-key' + assert ( + str(provider.client.base_url) == 'https://gateway.ai.cloudflare.com/v1/test-account-id/test-gateway-id/compat/' + ) + + +def test_cloudflare_provider_invalid_model_name(): + provider = CloudflareProvider(account_id='test-account-id', gateway_id='test-gateway-id', api_key='api-key') + + with pytest.raises(UserError, match="Model name must be in 'provider/model' format"): + provider.model_profile('invalid-model-name') + + +def test_cloudflare_provider_unknown_provider(): + provider = CloudflareProvider(account_id='test-account-id', gateway_id='test-gateway-id', api_key='api-key') + + profile = provider.model_profile('unknown/gpt-4') + assert profile is not None + assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer + + +def test_cloudflare_default_headers(): + provider = CloudflareProvider(account_id='test-account-id', gateway_id='test-gateway-id', api_key='api-key') + + # Check that default headers are set + assert provider.client.default_headers['http-referer'] == 'https://ai.pydantic.dev/' + assert provider.client.default_headers['x-title'] == 'pydantic-ai' + + +def test_cloudflare_provider_stored_keys(): + """Test stored keys mode - API keys stored in Cloudflare dashboard (requires authenticated gateway).""" + provider = CloudflareProvider( + account_id='test-account-id', + gateway_id='test-gateway-id', + cf_aig_authorization='gateway-token', + use_gateway_keys=True, + ) + # api_key is set to placeholder for AsyncOpenAI, but Authorization header will be stripped + assert provider.client.api_key == 'stored-keys-placeholder' + assert provider.base_url == 'https://gateway.ai.cloudflare.com/v1/test-account-id/test-gateway-id/compat' + assert provider.client.default_headers['cf-aig-authorization'] == 'gateway-token' + + +def test_cloudflare_provider_missing_credentials(): + """Test that error is raised when api_key is missing and use_gateway_keys=False.""" + with pytest.raises( + UserError, + match=re.escape('When use_gateway_keys=False (the default), you must provide an api_key for BYOK mode.'), + ): + CloudflareProvider(account_id='test-account-id', gateway_id='test-gateway-id') # type: ignore[call-overload] + + +def test_cloudflare_provider_use_gateway_keys_with_api_key_conflict(): + """Test that error is raised when both use_gateway_keys=True and api_key are provided.""" + with pytest.raises( + UserError, + match=re.escape('When use_gateway_keys=True, do not provide an api_key.'), + ): + CloudflareProvider( # type: ignore[call-overload] + account_id='test-account-id', + gateway_id='test-gateway-id', + cf_aig_authorization='gateway-token', + api_key='sk-test', + use_gateway_keys=True, + ) + + +def test_cloudflare_provider_use_gateway_keys_without_auth(): + """Test that error is raised when use_gateway_keys=True but cf_aig_authorization is missing.""" + with pytest.raises( + UserError, + match=re.escape('When use_gateway_keys=True, you must provide cf_aig_authorization.'), + ): + CloudflareProvider( + account_id='test-account-id', + gateway_id='test-gateway-id', + use_gateway_keys=True, # type: ignore[call-overload] + ) + + +async def test_cloudflare_stored_keys_strips_auth_header(): + """Test that Authorization header is stripped in stored keys mode so Cloudflare uses stored keys.""" + import httpx + + # Create a custom http_client to pass to the provider + custom_client = httpx.AsyncClient() + + provider = CloudflareProvider( + account_id='test-account-id', + gateway_id='test-gateway-id', + cf_aig_authorization='gateway-token', + http_client=custom_client, + use_gateway_keys=True, + ) + + # Get the http_client from the AsyncOpenAI client (accessing private attribute for testing) + http_client: httpx.AsyncClient = provider.client._client # type: ignore[attr-defined] + + # Create a test request with Authorization header + request = http_client.build_request('POST', 'https://example.com', headers={'Authorization': 'Bearer test'}) + + # Trigger the request event hooks + for hook in http_client.event_hooks.get('request', []): + await hook(request) + + # Verify Authorization header was removed by the event hook + assert 'authorization' not in request.headers + + +def test_cloudflare_documented_patterns(): + """Test the exact usage patterns from the documentation work correctly. + + This test validates the examples shown in docs/models/openai.md work as documented. + """ + from pydantic_ai.models.openai import OpenAIChatModel + + # Example 1: Basic BYOK mode (from docs) + model = OpenAIChatModel( + 'openai/gpt-4o', + provider=CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + api_key='your-openai-api-key', + ), + ) + agent = Agent(model) + assert isinstance(agent.model, OpenAIChatModel) + assert agent.model.model_name == 'openai/gpt-4o' + + # Example 2: Stored keys mode (from docs) + model = OpenAIChatModel( + 'anthropic/claude-3-5-sonnet', + provider=CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + cf_aig_authorization='your-gateway-token', + use_gateway_keys=True, + ), + ) + agent = Agent(model) + assert isinstance(agent.model, OpenAIChatModel) + assert agent.model.model_name == 'anthropic/claude-3-5-sonnet' diff --git a/tests/providers/test_provider_names.py b/tests/providers/test_provider_names.py index d44ab68276..654272efd1 100644 --- a/tests/providers/test_provider_names.py +++ b/tests/providers/test_provider_names.py @@ -16,6 +16,7 @@ from pydantic_ai.providers.anthropic import AnthropicProvider from pydantic_ai.providers.azure import AzureProvider + from pydantic_ai.providers.cloudflare import CloudflareProvider from pydantic_ai.providers.cohere import CohereProvider from pydantic_ai.providers.deepseek import DeepSeekProvider from pydantic_ai.providers.fireworks import FireworksProvider @@ -37,6 +38,7 @@ test_infer_provider_params = [ ('anthropic', AnthropicProvider, 'ANTHROPIC_API_KEY'), + ('cloudflare', CloudflareProvider, 'CLOUDFLARE_ACCOUNT_ID'), ('cohere', CohereProvider, 'CO_API_KEY'), ('deepseek', DeepSeekProvider, 'DEEPSEEK_API_KEY'), ('openrouter', OpenRouterProvider, 'OPENROUTER_API_KEY'), From f816c091d6e5ca7ba1d4ede3a7f2735dec0a558c Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Sun, 19 Oct 2025 18:21:01 -0500 Subject: [PATCH 2/6] Fix missing imports in CloudflareProvider docstring examples Add missing import statements to examples 2 and 3 in the class docstring to resolve Ruff F821 errors (undefined name CloudflareProvider). Generated with Claude Code https://claude.com/claude-code Co-Authored-By: Claude --- pydantic_ai_slim/pydantic_ai/providers/cloudflare.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py b/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py index 475a10fa51..999f5828d5 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py +++ b/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py @@ -80,6 +80,9 @@ class CloudflareProvider(Provider[AsyncOpenAI]): 2. BYOK with authenticated gateway (API key + gateway authentication): ```python + from pydantic_ai import Agent + from pydantic_ai.providers.cloudflare import CloudflareProvider + provider = CloudflareProvider( account_id='your-account-id', gateway_id='your-gateway-id', @@ -91,6 +94,9 @@ class CloudflareProvider(Provider[AsyncOpenAI]): 3. Stored keys mode (use API keys stored in Cloudflare dashboard): ```python + from pydantic_ai import Agent + from pydantic_ai.providers.cloudflare import CloudflareProvider + # Requires authenticated gateway - API keys are stored in your Cloudflare dashboard # Set use_gateway_keys=True and provide cf_aig_authorization provider = CloudflareProvider( From 070c5620e14b430c50e6c27a6e5de775dae3c280 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Sun, 19 Oct 2025 18:29:56 -0500 Subject: [PATCH 3/6] Fix CloudflareProvider docstring examples to use OpenAIChatModel Replace broken Agent('model', provider=provider) pattern with the correct OpenAIChatModel pattern in all three usage examples. This matches the fix already applied to docs/models/openai.md. Generated with Claude Code https://claude.com/claude-code Co-Authored-By: Claude --- .../pydantic_ai/providers/cloudflare.py | 48 +++++++++++-------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py b/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py index 999f5828d5..6c347f019a 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py +++ b/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py @@ -68,44 +68,54 @@ class CloudflareProvider(Provider[AsyncOpenAI]): 1. BYOK with unauthenticated gateway (bring your own API key): ```python from pydantic_ai import Agent + from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.providers.cloudflare import CloudflareProvider - provider = CloudflareProvider( - account_id='your-account-id', - gateway_id='your-gateway-id', - api_key='your-openai-api-key' # Your own provider API key + model = OpenAIChatModel( + 'openai/gpt-4o', + provider=CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + api_key='your-openai-api-key', + ), ) - agent = Agent('openai/gpt-4o', provider=provider) + agent = Agent(model) ``` 2. BYOK with authenticated gateway (API key + gateway authentication): ```python from pydantic_ai import Agent + from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.providers.cloudflare import CloudflareProvider - provider = CloudflareProvider( - account_id='your-account-id', - gateway_id='your-gateway-id', - api_key='your-openai-api-key', - cf_aig_authorization='your-gateway-token' + model = OpenAIChatModel( + 'anthropic/claude-3-5-sonnet', + provider=CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + api_key='your-openai-api-key', + cf_aig_authorization='your-gateway-token', + ), ) - agent = Agent('anthropic/claude-3-5-sonnet', provider=provider) + agent = Agent(model) ``` 3. Stored keys mode (use API keys stored in Cloudflare dashboard): ```python from pydantic_ai import Agent + from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.providers.cloudflare import CloudflareProvider - # Requires authenticated gateway - API keys are stored in your Cloudflare dashboard - # Set use_gateway_keys=True and provide cf_aig_authorization - provider = CloudflareProvider( - account_id='your-account-id', - gateway_id='your-gateway-id', - cf_aig_authorization='your-gateway-token', - use_gateway_keys=True + model = OpenAIChatModel( + 'openai/gpt-4o', + provider=CloudflareProvider( + account_id='your-account-id', + gateway_id='your-gateway-id', + cf_aig_authorization='your-gateway-token', + use_gateway_keys=True, + ), ) - agent = Agent('openai/gpt-4o', provider=provider) + agent = Agent(model) ``` """ From 3ba51cd540e9f3eaa9204c1e2e4254c7feaa0333 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Sun, 19 Oct 2025 19:01:52 -0500 Subject: [PATCH 4/6] fix test coverage for the cloudflare provider --- tests/providers/test_cloudflare.py | 70 ++++++++++++++++-------------- 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/tests/providers/test_cloudflare.py b/tests/providers/test_cloudflare.py index 58f007e924..ec02d86c0d 100644 --- a/tests/providers/test_cloudflare.py +++ b/tests/providers/test_cloudflare.py @@ -5,13 +5,13 @@ from pytest_mock import MockerFixture from pydantic_ai import Agent +from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer from pydantic_ai.exceptions import UserError from pydantic_ai.profiles.anthropic import anthropic_model_profile from pydantic_ai.profiles.cohere import cohere_model_profile from pydantic_ai.profiles.deepseek import deepseek_model_profile from pydantic_ai.profiles.google import GoogleJsonSchemaTransformer, google_model_profile from pydantic_ai.profiles.grok import grok_model_profile -from pydantic_ai.profiles.meta import meta_model_profile from pydantic_ai.profiles.mistral import mistral_model_profile from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, openai_model_profile @@ -99,7 +99,11 @@ def test_cloudflare_pass_openai_client() -> None: assert provider.client == openai_client -def test_cloudflare_provider_model_profile(mocker: MockerFixture): +def test_cloudflare_provider_model_profile(mocker: MockerFixture, env: TestEnv): + # Set dummy API keys so we can use real GroqProvider and CerebrasProvider + env.set('GROQ_API_KEY', 'test-groq-key') + env.set('CEREBRAS_API_KEY', 'test-cerebras-key') + provider = CloudflareProvider(account_id='test-account-id', gateway_id='test-gateway-id', api_key='api-key') ns = 'pydantic_ai.providers.cloudflare' @@ -113,30 +117,8 @@ def test_cloudflare_provider_model_profile(mocker: MockerFixture): mistral_mock = mocker.patch(f'{ns}.mistral_model_profile', wraps=mistral_model_profile) openai_mock = mocker.patch(f'{ns}.openai_model_profile', wraps=openai_model_profile) - # Mock GroqProvider and CerebrasProvider since they handle complex model profiling internally - groq_provider_instance = mocker.MagicMock() - cerebras_provider_instance = mocker.MagicMock() - - # Configure the mocks to return appropriate profiles - def groq_profile_func(name: str): - if name.lower().startswith('llama'): - return meta_model_profile(name) - elif name.lower().startswith('gemma'): - return google_model_profile(name) - else: - return openai_model_profile(name) - - def cerebras_profile_func(name: str): - if name.lower().startswith('llama'): - return meta_model_profile(name) - else: - return openai_model_profile(name) - - groq_provider_instance.model_profile.side_effect = groq_profile_func - cerebras_provider_instance.model_profile.side_effect = cerebras_profile_func - - mocker.patch(f'{ns}.GroqProvider', return_value=groq_provider_instance) - mocker.patch(f'{ns}.CerebrasProvider', return_value=cerebras_provider_instance) + # Use real GroqProvider and CerebrasProvider - they don't make API calls for model_profile() + # We just need dummy API keys which are set via env vars above # Test openai provider profile = provider.model_profile('openai/gpt-4o') @@ -186,15 +168,17 @@ def cerebras_profile_func(name: str): assert profile is not None assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer - # Test groq provider with llama model (should delegate to GroqProvider) + # Test groq provider with llama model (delegates to GroqProvider which returns meta profile) + # meta_model_profile uses InlineDefsJsonSchemaTransformer profile = provider.model_profile('groq/llama-3.3-70b-versatile') - groq_provider_instance.model_profile.assert_called_with('llama-3.3-70b-versatile') assert profile is not None + assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer - # Test groq provider with gemma model (should delegate to GroqProvider) + # Test groq provider with gemma model (delegates to GroqProvider which returns google profile) + # google_model_profile uses GoogleJsonSchemaTransformer profile = provider.model_profile('groq/gemma-7b-it') - groq_provider_instance.model_profile.assert_called_with('gemma-7b-it') assert profile is not None + assert profile.json_schema_transformer == GoogleJsonSchemaTransformer # Test perplexity provider (uses OpenAI-compatible API) profile = provider.model_profile('perplexity/llama-3.1-sonar-small-128k-online') @@ -208,10 +192,17 @@ def cerebras_profile_func(name: str): assert profile is not None assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer - # Test cerebras provider with llama model (should delegate to CerebrasProvider) + # Test cerebras provider with llama model (delegates to CerebrasProvider which returns meta profile) + # meta_model_profile uses InlineDefsJsonSchemaTransformer, wrapped by CerebrasProvider's OpenAIModelProfile profile = provider.model_profile('cerebras/llama3.1-8b') - cerebras_provider_instance.model_profile.assert_called_with('llama3.1-8b') assert profile is not None + assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer + + # Test cerebras provider with qwen model (delegates to CerebrasProvider which returns qwen profile) + # qwen_model_profile uses InlineDefsJsonSchemaTransformer, wrapped by CerebrasProvider's OpenAIModelProfile + profile = provider.model_profile('cerebras/qwen3.5-8b') + assert profile is not None + assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer def test_cloudflare_with_http_client(): @@ -317,7 +308,7 @@ async def test_cloudflare_stored_keys_strips_auth_header(): # Get the http_client from the AsyncOpenAI client (accessing private attribute for testing) http_client: httpx.AsyncClient = provider.client._client # type: ignore[attr-defined] - # Create a test request with Authorization header + # Test 1: Request WITH Authorization header - should be stripped request = http_client.build_request('POST', 'https://example.com', headers={'Authorization': 'Bearer test'}) # Trigger the request event hooks @@ -327,6 +318,19 @@ async def test_cloudflare_stored_keys_strips_auth_header(): # Verify Authorization header was removed by the event hook assert 'authorization' not in request.headers + # Test 2: Request WITHOUT Authorization header - should pass through unchanged (covers line 192 else branch) + request_no_auth = http_client.build_request( + 'POST', 'https://example.com', headers={'Content-Type': 'application/json'} + ) + + # Trigger the request event hooks + for hook in http_client.event_hooks.get('request', []): + await hook(request_no_auth) + + # Verify other headers are preserved and no error occurred + assert 'content-type' in request_no_auth.headers + assert 'authorization' not in request_no_auth.headers + def test_cloudflare_documented_patterns(): """Test the exact usage patterns from the documentation work correctly. From d472bee46953da3596a323ab81faa8f1514810a2 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Sun, 19 Oct 2025 19:15:56 -0500 Subject: [PATCH 5/6] chore: retry CI From 2bce1ac244602d7102177061fcd03edde6c007a4 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Tue, 21 Oct 2025 20:25:21 -0500 Subject: [PATCH 6/6] Address all PR review comments for Cloudflare AI Gateway provider MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit implements all 9 review comments from @DouweM: 1. Add cloudflare: model name shorthand support - Added 'cloudflare' to OpenAI-compatible providers in models/__init__.py - Added 10 representative Cloudflare model names to KnownModelName - Added shorthand usage example to docs 2-3. Extract Groq/Cerebras model profiling to top-level functions - Created groq_provider_model_profile() in providers/groq.py - Created cerebras_provider_model_profile() in providers/cerebras.py - Updated CloudflareProvider to use direct imports instead of instantiation 4. Add perplexity_model_profile function - Created profiles/perplexity.py with perplexity_model_profile() - Updated CloudflareProvider to use it 5. Remove AI-generated comment from cloudflare.py 6. Fix HTTP client monkeypatching - Use empty string for api_key in CF-managed keys mode - Removed _create_stored_keys_client() method (~26 lines) - Leverages OpenAI SDK's built-in behavior 7. Rename cf_aig_authorization → gateway_auth_token - Updated parameter name throughout codebase - HTTP header stays 'cf-aig-authorization' (Cloudflare API requirement) 8. Remove use_gateway_keys parameter - Made CF-managed keys mode detection implicit - Updated overload signatures - Simplified initialization logic - Deleted 2 obsolete tests Also updated terminology from "BYOK/stored keys" to "user-managed/CF-managed" for clarity and to avoid confusion with Cloudflare's BYOK feature. All tests passing (16 Cloudflare tests, 100% coverage maintained). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- docs/models/openai.md | 19 ++- .../pydantic_ai/models/__init__.py | 11 ++ pydantic_ai_slim/pydantic_ai/models/openai.py | 2 + .../pydantic_ai/profiles/perplexity.py | 8 + .../pydantic_ai/providers/cerebras.py | 51 +++--- .../pydantic_ai/providers/cloudflare.py | 152 +++++------------- .../pydantic_ai/providers/groq.py | 49 +++--- tests/models/test_model_names.py | 13 ++ tests/providers/test_cloudflare.py | 97 +++-------- tests/test_cli.py | 1 + tests/test_examples.py | 2 + 11 files changed, 172 insertions(+), 233 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/profiles/perplexity.py diff --git a/docs/models/openai.md b/docs/models/openai.md index 6235803b85..a626167f15 100644 --- a/docs/models/openai.md +++ b/docs/models/openai.md @@ -448,7 +448,21 @@ To use [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/), f !!! note This provider uses Cloudflare's [unified API endpoint](https://developers.cloudflare.com/ai-gateway/usage/chat-completion/) for routing requests to multiple AI providers. For the full list of supported providers, see [Cloudflare's documentation](https://developers.cloudflare.com/ai-gateway/usage/chat-completion/#supported-providers). -You can set the `CLOUDFLARE_ACCOUNT_ID`, `CLOUDFLARE_GATEWAY_ID`, and optionally `CLOUDFLARE_AI_GATEWAY_AUTH` environment variables and use [`CloudflareProvider`][pydantic_ai.providers.cloudflare.CloudflareProvider]: +You can set the `CLOUDFLARE_ACCOUNT_ID`, `CLOUDFLARE_GATEWAY_ID`, and optionally `CLOUDFLARE_AI_GATEWAY_AUTH` environment variables and use the `cloudflare:` model name prefix: + +```python test="skip - requires actual API keys" +from pydantic_ai import Agent + +# Set via environment or in code: +# CLOUDFLARE_ACCOUNT_ID='your-account-id' +# CLOUDFLARE_GATEWAY_ID='your-gateway-id' +# OPENAI_API_KEY='your-openai-api-key' + +agent = Agent('cloudflare:openai/gpt-4o') +... +``` + +Or use [`CloudflareProvider`][pydantic_ai.providers.cloudflare.CloudflareProvider] directly: ```python from pydantic_ai import Agent @@ -479,8 +493,7 @@ model = OpenAIChatModel( provider=CloudflareProvider( account_id='your-account-id', gateway_id='your-gateway-id', - cf_aig_authorization='your-gateway-token', - use_gateway_keys=True, + gateway_auth_token='your-gateway-token', ), ) agent = Agent(model) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 0520486be7..d6ad517107 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -127,6 +127,16 @@ 'cerebras:qwen-3-32b', 'cerebras:qwen-3-coder-480b', 'cerebras:qwen-3-235b-a22b-thinking-2507', + 'cloudflare:anthropic/claude-3-5-sonnet', + 'cloudflare:cohere/command-r-plus', + 'cloudflare:deepseek/deepseek-chat', + 'cloudflare:google/gemini-2.0-flash', + 'cloudflare:groq/llama-3.3-70b-versatile', + 'cloudflare:mistral/mistral-large-latest', + 'cloudflare:openai/gpt-4o', + 'cloudflare:perplexity/llama-3.1-sonar-small-128k-online', + 'cloudflare:workers-ai/@cf/meta/llama-3.1-8b-instruct', + 'cloudflare:xai/grok-2-1212', 'cohere:c4ai-aya-expanse-32b', 'cohere:c4ai-aya-expanse-8b', 'cohere:command', @@ -682,6 +692,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 'azure', 'deepseek', 'cerebras', + 'cloudflare', 'fireworks', 'github', 'grok', diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index c0d7644670..7c63da5a3c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -272,6 +272,7 @@ def __init__( 'azure', 'deepseek', 'cerebras', + 'cloudflare', 'fireworks', 'github', 'grok', @@ -329,6 +330,7 @@ def __init__( 'azure', 'deepseek', 'cerebras', + 'cloudflare', 'fireworks', 'github', 'grok', diff --git a/pydantic_ai_slim/pydantic_ai/profiles/perplexity.py b/pydantic_ai_slim/pydantic_ai/profiles/perplexity.py new file mode 100644 index 0000000000..3a30e9ffbe --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/profiles/perplexity.py @@ -0,0 +1,8 @@ +from __future__ import annotations as _annotations + +from . import ModelProfile + + +def perplexity_model_profile(model_name: str) -> ModelProfile | None: + """Get the model profile for a Perplexity model.""" + return None diff --git a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py index 267cf41b8c..4891eeddd3 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py +++ b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py @@ -23,6 +23,35 @@ ) from _import_error +def cerebras_provider_model_profile(model_name: str) -> ModelProfile | None: + """Get the model profile for a model routed through Cerebras provider. + + This function handles model profiling for models that use Cerebras's API, + and applies Cerebras-specific settings like unsupported model parameters. + """ + prefix_to_profile = {'llama': meta_model_profile, 'qwen': qwen_model_profile, 'gpt-oss': harmony_model_profile} + + profile = None + for prefix, profile_func in prefix_to_profile.items(): + model_name = model_name.lower() + if model_name.startswith(prefix): + profile = profile_func(model_name) + + # According to https://inference-docs.cerebras.ai/resources/openai#currently-unsupported-openai-features, + # Cerebras doesn't support some model settings. + unsupported_model_settings = ( + 'frequency_penalty', + 'logit_bias', + 'presence_penalty', + 'parallel_tool_calls', + 'service_tier', + ) + return OpenAIModelProfile( + json_schema_transformer=OpenAIJsonSchemaTransformer, + openai_unsupported_model_settings=unsupported_model_settings, + ).update(profile) + + class CerebrasProvider(Provider[AsyncOpenAI]): """Provider for Cerebras API.""" @@ -39,27 +68,7 @@ def client(self) -> AsyncOpenAI: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: - prefix_to_profile = {'llama': meta_model_profile, 'qwen': qwen_model_profile, 'gpt-oss': harmony_model_profile} - - profile = None - for prefix, profile_func in prefix_to_profile.items(): - model_name = model_name.lower() - if model_name.startswith(prefix): - profile = profile_func(model_name) - - # According to https://inference-docs.cerebras.ai/resources/openai#currently-unsupported-openai-features, - # Cerebras doesn't support some model settings. - unsupported_model_settings = ( - 'frequency_penalty', - 'logit_bias', - 'presence_penalty', - 'parallel_tool_calls', - 'service_tier', - ) - return OpenAIModelProfile( - json_schema_transformer=OpenAIJsonSchemaTransformer, - openai_unsupported_model_settings=unsupported_model_settings, - ).update(profile) + return cerebras_provider_model_profile(model_name) @overload def __init__(self) -> None: ... diff --git a/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py b/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py index 6c347f019a..1bd2bc2c6d 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py +++ b/pydantic_ai_slim/pydantic_ai/providers/cloudflare.py @@ -15,10 +15,11 @@ from pydantic_ai.profiles.grok import grok_model_profile from pydantic_ai.profiles.mistral import mistral_model_profile from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile +from pydantic_ai.profiles.perplexity import perplexity_model_profile from pydantic_ai.providers import Provider -from .cerebras import CerebrasProvider -from .groq import GroqProvider +from .cerebras import cerebras_provider_model_profile +from .groq import groq_provider_model_profile try: from openai import AsyncOpenAI @@ -29,23 +30,6 @@ ) from _import_error -def _groq_model_profile_cloudflare(model_name: str) -> ModelProfile | None: - """Get the model profile for Groq models routed through Cloudflare's unified API. - - Cloudflare routes to Groq's OpenAI-compatible endpoint, so we use prefix matching - similar to the native GroqProvider to determine the appropriate profile. - """ - return GroqProvider().model_profile(model_name) - - -def _cerebras_model_profile_cloudflare(model_name: str) -> ModelProfile | None: - """Get the model profile for Cerebras models routed through Cloudflare's unified API. - - Similar to the native CerebrasProvider, use prefix matching to determine profiles. - """ - return CerebrasProvider().model_profile(model_name) - - class CloudflareProvider(Provider[AsyncOpenAI]): """Provider for Cloudflare AI Gateway API. @@ -61,11 +45,11 @@ class CloudflareProvider(Provider[AsyncOpenAI]): This provider looks for these environment variables if they are not provided as parameters: - account_id: `CLOUDFLARE_ACCOUNT_ID` - gateway_id: `CLOUDFLARE_GATEWAY_ID` - - cf_aig_authorization: `CLOUDFLARE_AI_GATEWAY_AUTH` (optional) + - gateway_auth_token: `CLOUDFLARE_AI_GATEWAY_AUTH` (optional) There are three usage modes: - 1. BYOK with unauthenticated gateway (bring your own API key): + 1. User-managed keys with unauthenticated gateway: ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIChatModel @@ -82,7 +66,7 @@ class CloudflareProvider(Provider[AsyncOpenAI]): agent = Agent(model) ``` - 2. BYOK with authenticated gateway (API key + gateway authentication): + 2. User-managed keys with authenticated gateway (API key + gateway authentication): ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIChatModel @@ -94,13 +78,13 @@ class CloudflareProvider(Provider[AsyncOpenAI]): account_id='your-account-id', gateway_id='your-gateway-id', api_key='your-openai-api-key', - cf_aig_authorization='your-gateway-token', + gateway_auth_token='your-gateway-token', ), ) agent = Agent(model) ``` - 3. Stored keys mode (use API keys stored in Cloudflare dashboard): + 3. CF-managed keys mode (use API keys stored in Cloudflare dashboard): ```python from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIChatModel @@ -111,8 +95,7 @@ class CloudflareProvider(Provider[AsyncOpenAI]): provider=CloudflareProvider( account_id='your-account-id', gateway_id='your-gateway-id', - cf_aig_authorization='your-gateway-token', - use_gateway_keys=True, + gateway_auth_token='your-gateway-token', ), ) agent = Agent(model) @@ -143,19 +126,17 @@ def model_profile(self, model_name: str) -> ModelProfile | None: provider_to_profile = { 'anthropic': anthropic_model_profile, 'openai': openai_model_profile, - 'groq': _groq_model_profile_cloudflare, + 'groq': groq_provider_model_profile, 'mistral': mistral_model_profile, 'cohere': cohere_model_profile, 'deepseek': deepseek_model_profile, - # NOTE: this would be the first support for perplexity in pydantic-ai - # to remove this, the equivalent test in tests/providers/test_cloudflare.py::test_cloudflare_provider_model_profile would need to be removed - 'perplexity': openai_model_profile, # Perplexity uses OpenAI-compatible API + 'perplexity': perplexity_model_profile, 'workers-ai': openai_model_profile, # Cloudflare Workers AI uses OpenAI-compatible API 'workersai': openai_model_profile, # Alternative naming 'google-ai-studio': google_model_profile, 'grok': grok_model_profile, 'xai': grok_model_profile, # xai is an alias for grok - 'cerebras': _cerebras_model_profile_cloudflare, + 'cerebras': cerebras_provider_model_profile, } profile = None @@ -167,8 +148,6 @@ def model_profile(self, model_name: str) -> ModelProfile | None: if provider in provider_to_profile: profile = provider_to_profile[provider](model_name) - # If provider is not recognized, profile remains None and we fall back to OpenAI-compatible behavior. - # This matches VercelProvider's behavior of silently supporting unknown providers through the unified API. # As CloudflareProvider is always used with OpenAIChatModel, which used to unconditionally use OpenAIJsonSchemaTransformer, # we need to maintain that behavior unless json_schema_transformer is set explicitly @@ -176,44 +155,16 @@ def model_profile(self, model_name: str) -> ModelProfile | None: json_schema_transformer=OpenAIJsonSchemaTransformer, ).update(profile) - @staticmethod - def _create_stored_keys_client(base_client: httpx.AsyncClient) -> httpx.AsyncClient: - """Create an HTTP client that strips the Authorization header for stored keys mode. - - When using Cloudflare's stored keys feature (API keys stored in the dashboard), - the Authorization header must NOT be sent. If sent, it takes precedence over - the stored keys, breaking the feature. - - This wraps the base client with an event hook that removes the Authorization header. - """ - - async def strip_auth_header(request: httpx.Request) -> None: - """Remove Authorization header so Cloudflare uses stored keys from dashboard.""" - if 'authorization' in request.headers: - del request.headers['authorization'] - - # Merge event hooks - preserve any existing hooks and add our strip_auth_header hook - existing_hooks = base_client.event_hooks - new_request_hooks = list(existing_hooks.get('request', [])) - new_request_hooks.append(strip_auth_header) - - merged_hooks = dict(existing_hooks) - merged_hooks['request'] = new_request_hooks - - # Create new client based on the base client's configuration - base_client._event_hooks = merged_hooks # type: ignore[attr-defined] - return base_client - - # Scenario 1: BYOK with unauthenticated gateway (api_key required) + # Scenario 1: User-managed keys with unauthenticated gateway (api_key required) @overload def __init__(self, *, account_id: str, gateway_id: str, api_key: str) -> None: ... @overload def __init__(self, *, account_id: str, gateway_id: str, api_key: str, http_client: httpx.AsyncClient) -> None: ... - # Scenario 2: BYOK with authenticated gateway (api_key + cf_aig_authorization) + # Scenario 2: User-managed keys with authenticated gateway (api_key + gateway_auth_token) @overload - def __init__(self, *, account_id: str, gateway_id: str, api_key: str, cf_aig_authorization: str) -> None: ... + def __init__(self, *, account_id: str, gateway_id: str, api_key: str, gateway_auth_token: str) -> None: ... @overload def __init__( @@ -222,15 +173,13 @@ def __init__( account_id: str, gateway_id: str, api_key: str, - cf_aig_authorization: str, + gateway_auth_token: str, http_client: httpx.AsyncClient, ) -> None: ... - # Scenario 3: Stored keys with authenticated gateway (use_gateway_keys=True, cf_aig_authorization required) + # Scenario 3: CF-managed keys with authenticated gateway (no api_key, gateway_auth_token required) @overload - def __init__( - self, *, account_id: str, gateway_id: str, cf_aig_authorization: str, use_gateway_keys: bool = True - ) -> None: ... + def __init__(self, *, account_id: str, gateway_id: str, gateway_auth_token: str) -> None: ... @overload def __init__( @@ -238,9 +187,8 @@ def __init__( *, account_id: str, gateway_id: str, - cf_aig_authorization: str, + gateway_auth_token: str, http_client: httpx.AsyncClient, - use_gateway_keys: bool = True, ) -> None: ... # Advanced: Pre-configured OpenAI client @@ -253,10 +201,9 @@ def __init__( account_id: str | None = None, gateway_id: str | None = None, api_key: str | None = None, - cf_aig_authorization: str | None = None, + gateway_auth_token: str | None = None, openai_client: AsyncOpenAI | None = None, http_client: httpx.AsyncClient | None = None, - use_gateway_keys: bool = False, ) -> None: """Initialize the Cloudflare AI Gateway provider. @@ -264,22 +211,18 @@ def __init__( account_id: Your Cloudflare account ID. Can also be set via CLOUDFLARE_ACCOUNT_ID environment variable. gateway_id: Your Cloudflare AI Gateway ID. Can also be set via CLOUDFLARE_GATEWAY_ID environment variable. api_key: The API key for the upstream provider (OpenAI, Anthropic, etc.). - - Required for BYOK (bring your own key) mode (default) - - Do NOT provide when use_gateway_keys=True (conflicts with stored keys mode) + - Required for user-managed mode + - Omit this (along with providing gateway_auth_token) to use CF-managed keys mode - Optional when using the openai_client parameter (pre-configured client) - cf_aig_authorization: Authorization token for authenticated gateways. - - Required when use_gateway_keys=True (stored keys mode) - - Optional for BYOK mode (provides additional gateway authentication) + gateway_auth_token: Authorization token for authenticated gateways. + - Required for CF-managed keys mode (when api_key is omitted) + - Optional for user-managed mode (provides additional gateway authentication) - Can also be set via CLOUDFLARE_AI_GATEWAY_AUTH environment variable openai_client: Optional pre-configured AsyncOpenAI client for advanced use cases. http_client: Optional HTTP client to use for requests. - use_gateway_keys: Whether to use API keys stored in your Cloudflare dashboard (default: False). - - Set to True to use stored keys mode (requires cf_aig_authorization) - - When True, do not provide api_key (they are mutually exclusive) - - When False (default), you must provide api_key for BYOK mode Raises: - UserError: If use_gateway_keys=True and api_key is also provided (conflicting configuration). + UserError: If configuration is invalid (e.g., neither api_key nor CF-managed keys mode is configured). """ account_id = account_id or os.getenv('CLOUDFLARE_ACCOUNT_ID') gateway_id = gateway_id or os.getenv('CLOUDFLARE_GATEWAY_ID') @@ -296,29 +239,20 @@ def __init__( 'or pass it via `CloudflareProvider(gateway_id=...)` to use the Cloudflare provider.' ) - cf_aig_authorization = cf_aig_authorization or os.getenv('CLOUDFLARE_AI_GATEWAY_AUTH') - - if use_gateway_keys: - # Stored keys mode requires authenticated gateway - if cf_aig_authorization is None: - raise UserError( - 'When use_gateway_keys=True, you must provide cf_aig_authorization.\n' - 'Stored keys (API keys stored in Cloudflare dashboard) require an authenticated gateway.' - ) - # Can't use both stored keys and provide your own api_key - if api_key is not None: - raise UserError( - 'When use_gateway_keys=True, do not provide an api_key.\n' - 'use_gateway_keys=True means using API keys stored in your Cloudflare dashboard,\n' - 'which is incompatible with providing your own api_key (BYOK mode).' - ) - # Use placeholder for AsyncOpenAI (required by AsyncOpenAI client) but we'll strip the Authorization header - api_key = 'stored-keys-placeholder' + gateway_auth_token = gateway_auth_token or os.getenv('CLOUDFLARE_AI_GATEWAY_AUTH') + + # Detect CF-managed keys mode: no api_key provided + gateway_auth_token present + no pre-configured client + use_cf_managed_keys = api_key is None and gateway_auth_token is not None and openai_client is None + + if use_cf_managed_keys: + # CF-managed keys mode: use API keys stored in Cloudflare dashboard + # Use empty string for AsyncOpenAI - this prevents the Authorization header from being sent + api_key = '' elif api_key is None and openai_client is None: - # Not using stored keys, so api_key is required (unless using pre-configured openai_client) + # Not using CF-managed keys, so api_key is required (unless using pre-configured openai_client) raise UserError( - 'When use_gateway_keys=False (the default), you must provide an api_key for BYOK mode.\n' - 'To use API keys stored in your Cloudflare dashboard, set use_gateway_keys=True and provide cf_aig_authorization.' + 'You must provide an api_key for user-managed mode.\n' + 'To use API keys stored in your Cloudflare dashboard (CF-managed), omit api_key and provide gateway_auth_token instead.' ) self._base_url = f'https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/compat' @@ -328,23 +262,17 @@ def __init__( 'x-title': 'pydantic-ai', } - if cf_aig_authorization: - default_headers['cf-aig-authorization'] = cf_aig_authorization + if gateway_auth_token: + default_headers['cf-aig-authorization'] = gateway_auth_token if openai_client is not None: self._client = openai_client elif http_client is not None: - # If user provided http_client and we're in stored keys mode, we need to wrap it - if use_gateway_keys: - http_client = self._create_stored_keys_client(http_client) self._client = AsyncOpenAI( base_url=self._base_url, api_key=api_key, http_client=http_client, default_headers=default_headers ) else: http_client = cached_async_http_client(provider='cloudflare') - # In stored keys mode, wrap the client to strip Authorization header - if use_gateway_keys: - http_client = self._create_stored_keys_client(http_client) self._client = AsyncOpenAI( base_url=self._base_url, api_key=api_key, http_client=http_client, default_headers=default_headers ) diff --git a/pydantic_ai_slim/pydantic_ai/providers/groq.py b/pydantic_ai_slim/pydantic_ai/providers/groq.py index f0e5c5b53b..56604ef8ca 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/groq.py +++ b/pydantic_ai_slim/pydantic_ai/providers/groq.py @@ -44,6 +44,34 @@ def meta_groq_model_profile(model_name: str) -> ModelProfile | None: return meta_model_profile(model_name) +def groq_provider_model_profile(model_name: str) -> ModelProfile | None: + """Get the model profile for a model routed through Groq provider. + + This function handles model profiling for models that use Groq's API, + including various model families like Llama, Gemma, Qwen, etc. + """ + prefix_to_profile = { + 'llama': meta_model_profile, + 'meta-llama/': meta_groq_model_profile, + 'gemma': google_model_profile, + 'qwen': qwen_model_profile, + 'deepseek': deepseek_model_profile, + 'mistral': mistral_model_profile, + 'moonshotai/': groq_moonshotai_model_profile, + 'compound-': groq_model_profile, + 'openai/': openai_model_profile, + } + + for prefix, profile_func in prefix_to_profile.items(): + model_name = model_name.lower() + if model_name.startswith(prefix): + if prefix.endswith('/'): + model_name = model_name[len(prefix) :] + return profile_func(model_name) + + return None + + class GroqProvider(Provider[AsyncGroq]): """Provider for Groq API.""" @@ -60,26 +88,7 @@ def client(self) -> AsyncGroq: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: - prefix_to_profile = { - 'llama': meta_model_profile, - 'meta-llama/': meta_groq_model_profile, - 'gemma': google_model_profile, - 'qwen': qwen_model_profile, - 'deepseek': deepseek_model_profile, - 'mistral': mistral_model_profile, - 'moonshotai/': groq_moonshotai_model_profile, - 'compound-': groq_model_profile, - 'openai/': openai_model_profile, - } - - for prefix, profile_func in prefix_to_profile.items(): - model_name = model_name.lower() - if model_name.startswith(prefix): - if prefix.endswith('/'): - model_name = model_name[len(prefix) :] - return profile_func(model_name) - - return None + return groq_provider_model_profile(model_name) @overload def __init__(self, *, groq_client: AsyncGroq | None = None) -> None: ... diff --git a/tests/models/test_model_names.py b/tests/models/test_model_names.py index b27aa2d8c2..28c58b56ab 100644 --- a/tests/models/test_model_names.py +++ b/tests/models/test_model_names.py @@ -70,6 +70,18 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: openai_names = [f'openai:{n}' for n in get_model_names(OpenAIModelName)] bedrock_names = [f'bedrock:{n}' for n in get_model_names(BedrockModelName)] deepseek_names = ['deepseek:deepseek-chat', 'deepseek:deepseek-reasoner'] + cloudflare_names = [ + 'cloudflare:anthropic/claude-3-5-sonnet', + 'cloudflare:cohere/command-r-plus', + 'cloudflare:deepseek/deepseek-chat', + 'cloudflare:google/gemini-2.0-flash', + 'cloudflare:groq/llama-3.3-70b-versatile', + 'cloudflare:mistral/mistral-large-latest', + 'cloudflare:openai/gpt-4o', + 'cloudflare:perplexity/llama-3.1-sonar-small-128k-online', + 'cloudflare:workers-ai/@cf/meta/llama-3.1-8b-instruct', + 'cloudflare:xai/grok-2-1212', + ] huggingface_names = [f'huggingface:{n}' for n in get_model_names(HuggingFaceModelName)] heroku_names = get_heroku_model_names() cerebras_names = get_cerebras_model_names() @@ -86,6 +98,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: + openai_names + bedrock_names + deepseek_names + + cloudflare_names + huggingface_names + heroku_names + cerebras_names diff --git a/tests/providers/test_cloudflare.py b/tests/providers/test_cloudflare.py index ec02d86c0d..42564cae0f 100644 --- a/tests/providers/test_cloudflare.py +++ b/tests/providers/test_cloudflare.py @@ -14,6 +14,7 @@ from pydantic_ai.profiles.grok import grok_model_profile from pydantic_ai.profiles.mistral import mistral_model_profile from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, openai_model_profile +from pydantic_ai.profiles.perplexity import perplexity_model_profile from ..conftest import TestEnv, try_import @@ -72,17 +73,17 @@ def test_cloudflare_provider_from_env(env: TestEnv) -> None: assert provider.client.api_key == 'env-api-key' -def test_cloudflare_provider_with_cf_aig_authorization(): +def test_cloudflare_provider_with_gateway_auth_token(): provider = CloudflareProvider( account_id='test-account-id', gateway_id='test-gateway-id', api_key='api-key', - cf_aig_authorization='gateway-token', + gateway_auth_token='gateway-token', ) assert provider.client.default_headers['cf-aig-authorization'] == 'gateway-token' -def test_cloudflare_provider_cf_aig_authorization_from_env(env: TestEnv) -> None: +def test_cloudflare_provider_gateway_auth_token_from_env(env: TestEnv) -> None: env.set('CLOUDFLARE_ACCOUNT_ID', 'test-account-id') env.set('CLOUDFLARE_GATEWAY_ID', 'test-gateway-id') env.set('CLOUDFLARE_AI_GATEWAY_AUTH', 'env-gateway-token') @@ -116,6 +117,7 @@ def test_cloudflare_provider_model_profile(mocker: MockerFixture, env: TestEnv): grok_mock = mocker.patch(f'{ns}.grok_model_profile', wraps=grok_model_profile) mistral_mock = mocker.patch(f'{ns}.mistral_model_profile', wraps=mistral_model_profile) openai_mock = mocker.patch(f'{ns}.openai_model_profile', wraps=openai_model_profile) + perplexity_mock = mocker.patch(f'{ns}.perplexity_model_profile', wraps=perplexity_model_profile) # Use real GroqProvider and CerebrasProvider - they don't make API calls for model_profile() # We just need dummy API keys which are set via env vars above @@ -180,9 +182,9 @@ def test_cloudflare_provider_model_profile(mocker: MockerFixture, env: TestEnv): assert profile is not None assert profile.json_schema_transformer == GoogleJsonSchemaTransformer - # Test perplexity provider (uses OpenAI-compatible API) + # Test perplexity provider (currently returns None, falls back to OpenAI-compatible) profile = provider.model_profile('perplexity/llama-3.1-sonar-small-128k-online') - openai_mock.assert_called_with('llama-3.1-sonar-small-128k-online') + perplexity_mock.assert_called_with('llama-3.1-sonar-small-128k-online') assert profile is not None assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer @@ -240,96 +242,38 @@ def test_cloudflare_default_headers(): def test_cloudflare_provider_stored_keys(): - """Test stored keys mode - API keys stored in Cloudflare dashboard (requires authenticated gateway).""" + """Test CF-managed keys mode - API keys stored in Cloudflare dashboard (requires authenticated gateway).""" provider = CloudflareProvider( account_id='test-account-id', gateway_id='test-gateway-id', - cf_aig_authorization='gateway-token', - use_gateway_keys=True, + gateway_auth_token='gateway-token', ) - # api_key is set to placeholder for AsyncOpenAI, but Authorization header will be stripped - assert provider.client.api_key == 'stored-keys-placeholder' + # api_key is set to empty string for AsyncOpenAI to prevent Authorization header + assert provider.client.api_key == '' assert provider.base_url == 'https://gateway.ai.cloudflare.com/v1/test-account-id/test-gateway-id/compat' assert provider.client.default_headers['cf-aig-authorization'] == 'gateway-token' def test_cloudflare_provider_missing_credentials(): - """Test that error is raised when api_key is missing and use_gateway_keys=False.""" + """Test that error is raised when api_key is missing and not in CF-managed keys mode.""" with pytest.raises( UserError, - match=re.escape('When use_gateway_keys=False (the default), you must provide an api_key for BYOK mode.'), + match=re.escape('You must provide an api_key for user-managed mode.'), ): CloudflareProvider(account_id='test-account-id', gateway_id='test-gateway-id') # type: ignore[call-overload] -def test_cloudflare_provider_use_gateway_keys_with_api_key_conflict(): - """Test that error is raised when both use_gateway_keys=True and api_key are provided.""" - with pytest.raises( - UserError, - match=re.escape('When use_gateway_keys=True, do not provide an api_key.'), - ): - CloudflareProvider( # type: ignore[call-overload] - account_id='test-account-id', - gateway_id='test-gateway-id', - cf_aig_authorization='gateway-token', - api_key='sk-test', - use_gateway_keys=True, - ) - - -def test_cloudflare_provider_use_gateway_keys_without_auth(): - """Test that error is raised when use_gateway_keys=True but cf_aig_authorization is missing.""" - with pytest.raises( - UserError, - match=re.escape('When use_gateway_keys=True, you must provide cf_aig_authorization.'), - ): - CloudflareProvider( - account_id='test-account-id', - gateway_id='test-gateway-id', - use_gateway_keys=True, # type: ignore[call-overload] - ) - - -async def test_cloudflare_stored_keys_strips_auth_header(): - """Test that Authorization header is stripped in stored keys mode so Cloudflare uses stored keys.""" - import httpx - - # Create a custom http_client to pass to the provider - custom_client = httpx.AsyncClient() - +def test_cloudflare_stored_keys_no_auth_header(): + """Test that Authorization header is not sent in CF-managed keys mode (empty api_key).""" provider = CloudflareProvider( account_id='test-account-id', gateway_id='test-gateway-id', - cf_aig_authorization='gateway-token', - http_client=custom_client, - use_gateway_keys=True, - ) - - # Get the http_client from the AsyncOpenAI client (accessing private attribute for testing) - http_client: httpx.AsyncClient = provider.client._client # type: ignore[attr-defined] - - # Test 1: Request WITH Authorization header - should be stripped - request = http_client.build_request('POST', 'https://example.com', headers={'Authorization': 'Bearer test'}) - - # Trigger the request event hooks - for hook in http_client.event_hooks.get('request', []): - await hook(request) - - # Verify Authorization header was removed by the event hook - assert 'authorization' not in request.headers - - # Test 2: Request WITHOUT Authorization header - should pass through unchanged (covers line 192 else branch) - request_no_auth = http_client.build_request( - 'POST', 'https://example.com', headers={'Content-Type': 'application/json'} + gateway_auth_token='gateway-token', ) - # Trigger the request event hooks - for hook in http_client.event_hooks.get('request', []): - await hook(request_no_auth) - - # Verify other headers are preserved and no error occurred - assert 'content-type' in request_no_auth.headers - assert 'authorization' not in request_no_auth.headers + # In CF-managed keys mode, api_key is empty string which prevents OpenAI SDK from adding Authorization header + assert provider.client.api_key == '' + assert provider.client.default_headers['cf-aig-authorization'] == 'gateway-token' def test_cloudflare_documented_patterns(): @@ -358,8 +302,7 @@ def test_cloudflare_documented_patterns(): provider=CloudflareProvider( account_id='your-account-id', gateway_id='your-gateway-id', - cf_aig_authorization='your-gateway-token', - use_gateway_keys=True, + gateway_auth_token='your-gateway-token', ), ) agent = Agent(model) diff --git a/tests/test_cli.py b/tests/test_cli.py index e95ff09141..125423ca24 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -138,6 +138,7 @@ def test_list_models(capfd: CaptureFixture[str]): 'anthropic', 'bedrock', 'cerebras', + 'cloudflare', 'google-vertex', 'google-gla', 'groq', diff --git a/tests/test_examples.py b/tests/test_examples.py index 87649e44b3..6e32feb532 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -170,6 +170,8 @@ def print(self, *args: Any, **kwargs: Any) -> None: env.set('AWS_DEFAULT_REGION', 'us-east-1') env.set('VERCEL_AI_GATEWAY_API_KEY', 'testing') env.set('CEREBRAS_API_KEY', 'testing') + env.set('CLOUDFLARE_ACCOUNT_ID', 'testing') + env.set('CLOUDFLARE_GATEWAY_ID', 'testing') env.set('NEBIUS_API_KEY', 'testing') env.set('HEROKU_INFERENCE_KEY', 'testing') env.set('FIREWORKS_API_KEY', 'testing')