Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 45 additions & 71 deletions pydantic_ai_slim/pydantic_ai/messages.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why aren't we using the `mimetypes` stdlib module? `mimetypes.guess_type()` already parses URLs, and the current implementation doesn't handle things like case-insensitive file extensions.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Viicos Interestingly we already use that in DocumentUrl._infer_media_type, after checking a bunch of types ourselves :/

@fedexman Can you see if we can use mimetypes.guess_type() for all of these?

The method can be changed to just return str rather than XMediaType, as I don't think that type is used on any public fields.

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import base64
import hashlib
import mimetypes
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import KW_ONLY, dataclass, field, replace
Expand All @@ -10,6 +11,7 @@
from os import PathLike
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, cast, overload
from urllib.parse import urlparse

import pydantic
import pydantic_core
Expand All @@ -25,6 +27,26 @@
if TYPE_CHECKING:
from .models.instrumented import InstrumentationSettings

# Manually register extension -> MIME type mappings that the URL media-type
# inference relies on. Support for some of these extensions in the stdlib
# `mimetypes` tables varies across operating systems, so they are pinned here.
# NOTE(review): the module-level `mimetypes.add_type()` updates the interpreter's
# shared MIME database, so these registrations are visible to every other user of
# `mimetypes` in the process — consider a dedicated `mimetypes.MimeTypes` instance.
# Document types
mimetypes.add_type('text/markdown', '.mdx')
mimetypes.add_type('text/x-asciidoc', '.asciidoc')

# Video types
# NOTE(review): '.three_gp' mirrors the extension string used by the earlier
# hand-written check; the conventional extension is presumably '.3gp' — confirm
# whether that one needs registering as well.
mimetypes.add_type('video/3gpp', '.three_gp')
mimetypes.add_type('video/x-flv', '.flv')
mimetypes.add_type('video/x-matroska', '.mkv')
mimetypes.add_type('video/x-ms-wmv', '.wmv')

# Audio types
mimetypes.add_type('audio/flac', '.flac')
mimetypes.add_type('audio/mpeg', '.mp3')
mimetypes.add_type('audio/ogg', '.oga')
# Override stdlib entries that use an `x-` prefixed subtype with the unprefixed
# types listed in `AudioMediaType`.
mimetypes.add_type('audio/aac', '.aac')
mimetypes.add_type('audio/aiff', '.aiff')
mimetypes.add_type('audio/wav', '.wav')
Comment on lines +30 to +48
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will affect the global mimetypes db for the current interpreter. Let's instantiate an explicit MimeTypes object instead, and attach the additional types directly to it.

You can then use your instance's guess_type() directly instead of the module-level guess_type().



AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg', 'audio/ogg', 'audio/flac', 'audio/aiff', 'audio/aac']
ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
Expand Down Expand Up @@ -226,38 +248,27 @@ def __init__(
)
self.kind = kind

def _infer_media_type(self) -> VideoMediaType:
def _infer_media_type(self) -> str:
"""Return the media type of the video, based on the url."""
if self.url.endswith('.mkv'):
return 'video/x-matroska'
elif self.url.endswith('.mov'):
return 'video/quicktime'
elif self.url.endswith('.mp4'):
return 'video/mp4'
elif self.url.endswith('.webm'):
return 'video/webm'
elif self.url.endswith('.flv'):
return 'video/x-flv'
elif self.url.endswith(('.mpeg', '.mpg')):
return 'video/mpeg'
elif self.url.endswith('.wmv'):
return 'video/x-ms-wmv'
elif self.url.endswith('.three_gp'):
return 'video/3gpp'
# Assume that YouTube videos are mp4 because there would be no extension
# to infer from. This should not be a problem, as Gemini disregards media
# type for YouTube URLs.
elif self.is_youtube:
if self.is_youtube:
return 'video/mp4'
else:

mime_type, _ = guess_type(self.url)
if mime_type is None:
raise ValueError(
f'Could not infer media type from video URL: {self.url}. Explicitly provide a `media_type` instead.'
)
return mime_type

@property
def is_youtube(self) -> bool:
"""True if the URL has a YouTube domain."""
return self.url.startswith(('https://youtu.be/', 'https://youtube.com/', 'https://www.youtube.com/'))
parsed = urlparse(self.url)
hostname = parsed.hostname or ''
return hostname in ('youtu.be', 'youtube.com', 'www.youtube.com')

@property
def format(self) -> VideoFormat:
Expand Down Expand Up @@ -302,28 +313,18 @@ def __init__(
)
self.kind = kind

def _infer_media_type(self) -> AudioMediaType:
def _infer_media_type(self) -> str:
"""Return the media type of the audio file, based on the url.

References:
- Gemini: https://ai.google.dev/gemini-api/docs/audio#supported-formats
"""
if self.url.endswith('.mp3'):
return 'audio/mpeg'
if self.url.endswith('.wav'):
return 'audio/wav'
if self.url.endswith('.flac'):
return 'audio/flac'
if self.url.endswith('.oga'):
return 'audio/ogg'
if self.url.endswith('.aiff'):
return 'audio/aiff'
if self.url.endswith('.aac'):
return 'audio/aac'

raise ValueError(
f'Could not infer media type from audio URL: {self.url}. Explicitly provide a `media_type` instead.'
)
mime_type, _ = guess_type(self.url)
if mime_type is None:
raise ValueError(
f'Could not infer media type from audio URL: {self.url}. Explicitly provide a `media_type` instead.'
)
return mime_type

@property
def format(self) -> AudioFormat:
Expand Down Expand Up @@ -365,20 +366,14 @@ def __init__(
)
self.kind = kind

def _infer_media_type(self) -> ImageMediaType:
def _infer_media_type(self) -> str:
"""Return the media type of the image, based on the url."""
if self.url.endswith(('.jpg', '.jpeg')):
return 'image/jpeg'
elif self.url.endswith('.png'):
return 'image/png'
elif self.url.endswith('.gif'):
return 'image/gif'
elif self.url.endswith('.webp'):
return 'image/webp'
else:
mime_type, _ = guess_type(self.url)
if mime_type is None:
raise ValueError(
f'Could not infer media type from image URL: {self.url}. Explicitly provide a `media_type` instead.'
)
return mime_type

@property
def format(self) -> ImageFormat:
Expand Down Expand Up @@ -425,33 +420,12 @@ def __init__(

def _infer_media_type(self) -> str:
"""Return the media type of the document, based on the url."""
# Common document types are hardcoded here as mime-type support for these
# extensions varies across operating systems.
if self.url.endswith(('.md', '.mdx', '.markdown')):
return 'text/markdown'
elif self.url.endswith('.asciidoc'):
return 'text/x-asciidoc'
elif self.url.endswith('.txt'):
return 'text/plain'
elif self.url.endswith('.pdf'):
return 'application/pdf'
elif self.url.endswith('.rtf'):
return 'application/rtf'
elif self.url.endswith('.doc'):
return 'application/msword'
elif self.url.endswith('.docx'):
return 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
elif self.url.endswith('.xls'):
return 'application/vnd.ms-excel'
elif self.url.endswith('.xlsx'):
return 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'

type_, _ = guess_type(self.url)
if type_ is None:
mime_type, _ = guess_type(self.url)
if mime_type is None:
raise ValueError(
f'Could not infer media type from document URL: {self.url}. Explicitly provide a `media_type` instead.'
)
return type_
return mime_type

@property
def format(self) -> DocumentFormat:
Expand Down
72 changes: 72 additions & 0 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,78 @@ def test_video_url_invalid():
VideoUrl('foobar.potato').media_type


@pytest.mark.parametrize(
    'url,media_type,format',
    [
        pytest.param('https://example.com/video.mp4?query=param', 'video/mp4', 'mp4', id='mp4_with_query'),
        pytest.param(
            'https://example.com/video.webm?X-Amz-Algorithm=AWS4-HMAC-SHA256',
            'video/webm',
            'webm',
            id='webm_with_aws_params',
        ),
    ],
)
def test_video_url_with_query_parameters(url: str, media_type: str, format: str):
    """A query string (e.g. on a presigned URL) must not prevent VideoUrl media type inference."""
    video = VideoUrl(url)
    # The query string follows the file extension, so inference must look at the URL path only.
    assert video.media_type == media_type
    assert video.format == format


@pytest.mark.parametrize(
    'url,media_type,format',
    [
        pytest.param('https://example.com/audio.mp3?query=param', 'audio/mpeg', 'mp3', id='mp3_with_query'),
        pytest.param(
            'https://example.com/audio.wav?X-Amz-Algorithm=AWS4-HMAC-SHA256',
            'audio/wav',
            'wav',
            id='wav_with_aws_params',
        ),
    ],
)
def test_audio_url_with_query_parameters(url: str, media_type: str, format: str):
    """A query string (e.g. on a presigned URL) must not prevent AudioUrl media type inference."""
    audio = AudioUrl(url)
    # The query string follows the file extension, so inference must look at the URL path only.
    assert audio.media_type == media_type
    assert audio.format == format


@pytest.mark.parametrize(
    'url,media_type,format',
    [
        pytest.param('https://example.com/image.png?query=param', 'image/png', 'png', id='png_with_query'),
        pytest.param(
            'https://example.com/image.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256',
            'image/jpeg',
            'jpeg',
            id='jpg_with_aws_params',
        ),
    ],
)
def test_image_url_with_query_parameters(url: str, media_type: str, format: str):
    """A query string (e.g. on a presigned URL) must not prevent ImageUrl media type inference."""
    image = ImageUrl(url)
    # The query string follows the file extension, so inference must look at the URL path only.
    assert image.media_type == media_type
    assert image.format == format


def test_thinking_part_delta_apply_to_thinking_part_delta():
"""Test lines 768-775: Apply ThinkingPartDelta to another ThinkingPartDelta."""
original_delta = ThinkingPartDelta(
Expand Down
Loading