diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 9019b81931..a5ab2aa741 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -2,14 +2,16 @@ import base64 import hashlib +import mimetypes from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import KW_ONLY, dataclass, field, replace from datetime import datetime -from mimetypes import guess_type +from mimetypes import MimeTypes from os import PathLike from pathlib import Path from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, cast, overload +from urllib.parse import urlparse import pydantic import pydantic_core @@ -25,6 +27,27 @@ if TYPE_CHECKING: from .models.instrumented import InstrumentationSettings +_mime_types = MimeTypes(tuple(mimetypes.knownfiles)) +# Register manually MIME types that are not in the standard library or override standard ones +# Document types +_mime_types.add_type('text/markdown', '.mdx') +_mime_types.add_type('text/x-asciidoc', '.asciidoc') + +# Video types +_mime_types.add_type('video/3gpp', '.three_gp') +_mime_types.add_type('video/x-flv', '.flv') +_mime_types.add_type('video/x-matroska', '.mkv') +_mime_types.add_type('video/x-ms-wmv', '.wmv') + +# Audio types +_mime_types.add_type('audio/flac', '.flac') +_mime_types.add_type('audio/mpeg', '.mp3') +_mime_types.add_type('audio/ogg', '.oga') +# override stdlib mimetypes that use x- prefix with standard types +_mime_types.add_type('audio/aac', '.aac') +_mime_types.add_type('audio/aiff', '.aiff') +_mime_types.add_type('audio/wav', '.wav') + AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg', 'audio/ogg', 'audio/flac', 'audio/aiff', 'audio/aac'] ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp'] @@ -226,38 +249,27 @@ def __init__( ) self.kind = kind - def _infer_media_type(self) -> VideoMediaType: + def _infer_media_type(self) -> str: """Return the media type of the video, based on the url.""" - if self.url.endswith('.mkv'): - return 'video/x-matroska' - elif self.url.endswith('.mov'): - return 'video/quicktime' - elif self.url.endswith('.mp4'): - return 'video/mp4' - elif self.url.endswith('.webm'): - return 'video/webm' - elif self.url.endswith('.flv'): - return 'video/x-flv' - elif self.url.endswith(('.mpeg', '.mpg')): - return 'video/mpeg' - elif self.url.endswith('.wmv'): - return 'video/x-ms-wmv' - elif self.url.endswith('.three_gp'): - return 'video/3gpp' # Assume that YouTube videos are mp4 because there would be no extension # to infer from. This should not be a problem, as Gemini disregards media # type for YouTube URLs. - elif self.is_youtube: + if self.is_youtube: return 'video/mp4' - else: + + mime_type, _ = _mime_types.guess_type(self.url) + if mime_type is None: raise ValueError( f'Could not infer media type from video URL: {self.url}. Explicitly provide a `media_type` instead.' ) + return mime_type @property def is_youtube(self) -> bool: """True if the URL has a YouTube domain.""" - return self.url.startswith(('https://youtu.be/', 'https://youtube.com/', 'https://www.youtube.com/')) + parsed = urlparse(self.url) + hostname = parsed.hostname or '' + return hostname in ('youtu.be', 'youtube.com', 'www.youtube.com') @property def format(self) -> VideoFormat: @@ -302,28 +314,18 @@ def __init__( ) self.kind = kind - def _infer_media_type(self) -> AudioMediaType: + def _infer_media_type(self) -> str: """Return the media type of the audio file, based on the url. References: - Gemini: https://ai.google.dev/gemini-api/docs/audio#supported-formats """ - if self.url.endswith('.mp3'): - return 'audio/mpeg' - if self.url.endswith('.wav'): - return 'audio/wav' - if self.url.endswith('.flac'): - return 'audio/flac' - if self.url.endswith('.oga'): - return 'audio/ogg' - if self.url.endswith('.aiff'): - return 'audio/aiff' - if self.url.endswith('.aac'): - return 'audio/aac' - - raise ValueError( - f'Could not infer media type from audio URL: {self.url}. Explicitly provide a `media_type` instead.' - ) + mime_type, _ = _mime_types.guess_type(self.url) + if mime_type is None: + raise ValueError( + f'Could not infer media type from audio URL: {self.url}. Explicitly provide a `media_type` instead.' + ) + return mime_type @property def format(self) -> AudioFormat: @@ -365,20 +367,14 @@ def __init__( ) self.kind = kind - def _infer_media_type(self) -> ImageMediaType: + def _infer_media_type(self) -> str: """Return the media type of the image, based on the url.""" - if self.url.endswith(('.jpg', '.jpeg')): - return 'image/jpeg' - elif self.url.endswith('.png'): - return 'image/png' - elif self.url.endswith('.gif'): - return 'image/gif' - elif self.url.endswith('.webp'): - return 'image/webp' - else: + mime_type, _ = _mime_types.guess_type(self.url) + if mime_type is None: raise ValueError( f'Could not infer media type from image URL: {self.url}. Explicitly provide a `media_type` instead.' ) + return mime_type @property def format(self) -> ImageFormat: @@ -425,33 +421,12 @@ def __init__( def _infer_media_type(self) -> str: """Return the media type of the document, based on the url.""" - # Common document types are hardcoded here as mime-type support for these - # extensions varies across operating systems. - if self.url.endswith(('.md', '.mdx', '.markdown')): - return 'text/markdown' - elif self.url.endswith('.asciidoc'): - return 'text/x-asciidoc' - elif self.url.endswith('.txt'): - return 'text/plain' - elif self.url.endswith('.pdf'): - return 'application/pdf' - elif self.url.endswith('.rtf'): - return 'application/rtf' - elif self.url.endswith('.doc'): - return 'application/msword' - elif self.url.endswith('.docx'): - return 'application/vnd.openxmlformats-officedocument.wordprocessingml.document' - elif self.url.endswith('.xls'): - return 'application/vnd.ms-excel' - elif self.url.endswith('.xlsx'): - return 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' - - type_, _ = guess_type(self.url) - if type_ is None: + mime_type, _ = _mime_types.guess_type(self.url) + if mime_type is None: raise ValueError( f'Could not infer media type from document URL: {self.url}. Explicitly provide a `media_type` instead.' ) - return type_ + return mime_type @property def format(self) -> DocumentFormat: @@ -545,7 +520,7 @@ def from_path(cls, path: PathLike[str]) -> BinaryContent: path = Path(path) if not path.exists(): raise FileNotFoundError(f'File not found: {path}') - media_type, _ = guess_type(path) + media_type, _ = _mime_types.guess_type(path) if media_type is None: media_type = 'application/octet-stream' diff --git a/tests/test_messages.py b/tests/test_messages.py index 943d68fe8c..488d847b62 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -333,6 +333,78 @@ def test_video_url_invalid(): VideoUrl('foobar.potato').media_type +@pytest.mark.parametrize( + 'url,media_type,format', + [ + pytest.param( + 'https://example.com/video.mp4?query=param', + 'video/mp4', + 'mp4', + id='mp4_with_query', + ), + pytest.param( + 'https://example.com/video.webm?X-Amz-Algorithm=AWS4-HMAC-SHA256', + 'video/webm', + 'webm', + id='webm_with_aws_params', + ), + ], +) +def test_video_url_with_query_parameters(url: str, media_type: str, format: str): + """Test that VideoUrl correctly infers media type from URLs with query parameters (e.g., presigned URLs).""" + video_url = VideoUrl(url) + assert video_url.media_type == media_type + assert video_url.format == format + + +@pytest.mark.parametrize( + 'url,media_type,format', + [ + pytest.param( + 'https://example.com/audio.mp3?query=param', + 'audio/mpeg', + 'mp3', + id='mp3_with_query', + ), + pytest.param( + 'https://example.com/audio.wav?X-Amz-Algorithm=AWS4-HMAC-SHA256', + 'audio/wav', + 'wav', + id='wav_with_aws_params', + ), + ], +) +def test_audio_url_with_query_parameters(url: str, media_type: str, format: str): + """Test that AudioUrl correctly infers media type from URLs with query parameters (e.g., presigned URLs).""" + audio_url = AudioUrl(url) + assert audio_url.media_type == media_type + assert audio_url.format == format + + +@pytest.mark.parametrize( + 'url,media_type,format', + [ + pytest.param( + 'https://example.com/image.png?query=param', + 'image/png', + 'png', + id='png_with_query', + ), + pytest.param( + 'https://example.com/image.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256', + 'image/jpeg', + 'jpeg', + id='jpg_with_aws_params', + ), + ], +) +def test_image_url_with_query_parameters(url: str, media_type: str, format: str): + """Test that ImageUrl correctly infers media type from URLs with query parameters (e.g., presigned URLs).""" + image_url = ImageUrl(url) + assert image_url.media_type == media_type + assert image_url.format == format + + def test_thinking_part_delta_apply_to_thinking_part_delta(): """Test lines 768-775: Apply ThinkingPartDelta to another ThinkingPartDelta.""" original_delta = ThinkingPartDelta(