diff --git a/src/huggingface_hub/_webhooks_payload.py b/src/huggingface_hub/_webhooks_payload.py
index 288f4b08b9..1279297c8c 100644
--- a/src/huggingface_hub/_webhooks_payload.py
+++ b/src/huggingface_hub/_webhooks_payload.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Contains data structures to parse the webhooks payload."""
 
-from typing import List, Literal, Optional
+from typing import Any, List, Literal, Optional, Union
 
 from .utils import is_pydantic_available
 
@@ -32,6 +32,34 @@ def __init__(self, *args, **kwargs) -> None:
                 " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
             )
 
+        @classmethod
+        def model_json_schema(cls, *args, **kwargs) -> dict[str, Any]:
+            raise ImportError(
+                "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
+                " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
+            )
+
+        @classmethod
+        def schema(cls, *args, **kwargs) -> dict[str, Any]:
+            raise ImportError(
+                "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
+                " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
+            )
+
+        @classmethod
+        def model_validate_json(cls, json_data: Union[str, bytes, bytearray], *args, **kwargs) -> "BaseModel":
+            raise ImportError(
+                "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
+                " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
+            )
+
+        @classmethod
+        def parse_raw(cls, json_data: Union[str, bytes, bytearray], *args, **kwargs) -> "BaseModel":
+            raise ImportError(
+                "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
+                " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
+            )
+
 
 # This is an adaptation of the ReportV3 interface implemented in moon-landing. V0, V1 and V2 have been ignored as they
 # are not in used anymore. To keep in sync when format is updated in
diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py
index ed473e6d11..567adb1f61 100644
--- a/src/huggingface_hub/inference/_client.py
+++ b/src/huggingface_hub/inference/_client.py
@@ -37,11 +37,12 @@
 import re
 import time
 import warnings
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union, overload
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Type, Union, overload
 
 from requests import HTTPError
 from requests.structures import CaseInsensitiveDict
 
+from huggingface_hub._webhooks_payload import BaseModel
 from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
 from huggingface_hub.errors import BadRequestError, InferenceTimeoutError
 from huggingface_hub.inference._common import (
@@ -538,7 +539,7 @@ def chat_completion(
         max_tokens: Optional[int] = None,
         n: Optional[int] = None,
         presence_penalty: Optional[float] = None,
-        response_format: Optional[ChatCompletionInputGrammarType] = None,
+        response_format: Optional[Union[ChatCompletionInputGrammarType, Type[BaseModel]]] = None,
         seed: Optional[int] = None,
         stop: Optional[List[str]] = None,
         stream_options: Optional[ChatCompletionInputStreamOptions] = None,
@@ -590,8 +591,8 @@ def chat_completion(
             presence_penalty (`float`, *optional*):
                 Number between -2.0 and 2.0.
                 Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
-            response_format ([`ChatCompletionInputGrammarType`], *optional*):
-                Grammar constraints. Can be either a JSONSchema or a regex.
+            response_format ([`ChatCompletionInputGrammarType`] or `pydantic.BaseModel` class, *optional*):
+                Grammar constraints. Can be either a JSONSchema, a regex, or a Pydantic model class.
             seed (Optional[`int`], *optional*):
                 Seed for reproducible control flow. Defaults to None.
             stop (Optional[`str`], *optional*):
@@ -820,7 +821,7 @@ def chat_completion(
         )
         ```
 
-        Example using response_format:
+        Example using response_format (dict):
         ```py
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
@@ -850,7 +851,44 @@ def chat_completion(
         >>> response.choices[0].message.content
         '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
         ```
+
+        Example using response_format (pydantic):
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> from pydantic import BaseModel, conint
+        >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+        >>> messages = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": "I saw a puppy, a cat and a raccoon during my bike ride in the park. What did I see and when?",
+        ...     },
+        ... ]
+        >>> class ActivitySummary(BaseModel):
+        ...     location: str
+        ...     activity: str
+        ...     animals_seen: conint(ge=1, le=5)
+        ...     animals: list[str]
+        >>> response = client.chat_completion(
+        ...     messages=messages,
+        ...     response_format=ActivitySummary,
+        ...     max_tokens=500,
+        ... )
+        >>> response.choices[0].message.parsed
+        ActivitySummary(location='park', activity='bike ride', animals_seen=3, animals=['puppy', 'cat', 'raccoon'])
+        ```
         """
+        if isinstance(response_format, type) and issubclass(response_format, BaseModel):
+            response_model = response_format
+            # pydantic v2 uses `model_json_schema`; v1 only provides `schema`
+            response_format = ChatCompletionInputGrammarType(
+                type="json",
+                value=response_model.model_json_schema()
+                if hasattr(response_model, "model_json_schema")
+                else response_model.schema(),
+            )
+        else:
+            response_model = None
+
         model_url = self._resolve_chat_completion_url(model)
 
         # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
@@ -886,7 +924,20 @@ def chat_completion(
         if stream:
             return _stream_chat_completion_response(data)  # type: ignore[arg-type]
 
-        return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
+        chat_completion_output = ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
+        if response_model:
+            for choice in chat_completion_output.choices:
+                if choice.message.content:
+                    try:
+                        # pydantic v2 uses `model_validate_json`; v1 only provides `parse_raw`
+                        choice.message.parsed = (
+                            response_model.model_validate_json(choice.message.content)
+                            if hasattr(response_model, "model_validate_json")
+                            else response_model.parse_raw(choice.message.content)
+                        )
+                    except ValueError:
+                        choice.message.refusal = f"Failed to generate the response as a {response_model.__name__}"
+        return chat_completion_output
 
     def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
         # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
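Both the sync hunks above and the async hunks below rely on the same pydantic v1/v2 compatibility pattern: probe for the v2 method with `hasattr` and fall back to the v1 equivalent. For context, here is a minimal standalone sketch of that pattern; the helper names `pydantic_to_json_schema` and `pydantic_parse_json` are illustrative only and are not part of this diff:

```python
from typing import Any, Dict, Type

from pydantic import BaseModel


def pydantic_to_json_schema(model: Type[BaseModel]) -> Dict[str, Any]:
    """Export a model's JSON schema on both pydantic v2 and v1."""
    # pydantic v2 exposes `model_json_schema`; v1 only has `schema`.
    if hasattr(model, "model_json_schema"):
        return model.model_json_schema()
    return model.schema()


def pydantic_parse_json(model: Type[BaseModel], raw: str) -> BaseModel:
    """Validate a JSON string against a model on both pydantic v2 and v1."""
    # pydantic v2 exposes `model_validate_json`; v1 only has `parse_raw`.
    if hasattr(model, "model_validate_json"):
        return model.model_validate_json(raw)
    return model.parse_raw(raw)
```

Catching `ValueError` in the hunks above covers both versions, since pydantic's `ValidationError` subclasses `ValueError` in v1 and v2 alike.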
diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py
index 74888bc0b8..c017ee8145 100644
--- a/src/huggingface_hub/inference/_generated/_async_client.py
+++ b/src/huggingface_hub/inference/_generated/_async_client.py
@@ -24,10 +24,23 @@
 import re
 import time
 import warnings
-from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Set,
+    Type,
+    Union,
+    overload,
+)
 
 from requests.structures import CaseInsensitiveDict
 
+from huggingface_hub._webhooks_payload import BaseModel
 from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
 from huggingface_hub.errors import InferenceTimeoutError
 from huggingface_hub.inference._common import (
@@ -574,7 +587,7 @@ async def chat_completion(
         max_tokens: Optional[int] = None,
         n: Optional[int] = None,
         presence_penalty: Optional[float] = None,
-        response_format: Optional[ChatCompletionInputGrammarType] = None,
+        response_format: Optional[Union[ChatCompletionInputGrammarType, Type[BaseModel]]] = None,
         seed: Optional[int] = None,
         stop: Optional[List[str]] = None,
         stream_options: Optional[ChatCompletionInputStreamOptions] = None,
@@ -626,8 +639,8 @@ async def chat_completion(
             presence_penalty (`float`, *optional*):
                 Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
                 text so far, increasing the model's likelihood to talk about new topics.
-            response_format ([`ChatCompletionInputGrammarType`], *optional*):
-                Grammar constraints. Can be either a JSONSchema or a regex.
+            response_format ([`ChatCompletionInputGrammarType`] or `pydantic.BaseModel` class, *optional*):
+                Grammar constraints. Can be either a JSONSchema, a regex, or a Pydantic model class.
             seed (Optional[`int`], *optional*):
                 Seed for reproducible control flow. Defaults to None.
             stop (Optional[`str`], *optional*):
@@ -861,7 +874,7 @@ async def chat_completion(
         )
         ```
 
-        Example using response_format:
+        Example using response_format (dict):
         ```py
         # Must be run in an async context
         >>> from huggingface_hub import AsyncInferenceClient
         >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
@@ -892,7 +905,45 @@ async def chat_completion(
         >>> response.choices[0].message.content
         '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
         ```
+
+        Example using response_format (pydantic):
+        ```py
+        # Must be run in an async context
+        >>> from huggingface_hub import AsyncInferenceClient
+        >>> from pydantic import BaseModel, conint
+        >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
+        >>> messages = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": "I saw a puppy, a cat and a raccoon during my bike ride in the park. What did I see and when?",
+        ...     },
+        ... ]
+        >>> class ActivitySummary(BaseModel):
+        ...     location: str
+        ...     activity: str
+        ...     animals_seen: conint(ge=1, le=5)
+        ...     animals: list[str]
+        >>> response = await client.chat_completion(
+        ...     messages=messages,
+        ...     response_format=ActivitySummary,
+        ...     max_tokens=500,
+        ... )
+        >>> response.choices[0].message.parsed
+        ActivitySummary(location='park', activity='bike ride', animals_seen=3, animals=['puppy', 'cat', 'raccoon'])
+        ```
         """
+        if isinstance(response_format, type) and issubclass(response_format, BaseModel):
+            response_model = response_format
+            # pydantic v2 uses `model_json_schema`; v1 only provides `schema`
+            response_format = ChatCompletionInputGrammarType(
+                type="json",
+                value=response_model.model_json_schema()
+                if hasattr(response_model, "model_json_schema")
+                else response_model.schema(),
+            )
+        else:
+            response_model = None
+
         model_url = self._resolve_chat_completion_url(model)
 
         # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
@@ -928,7 +979,20 @@ async def chat_completion(
         if stream:
             return _async_stream_chat_completion_response(data)  # type: ignore[arg-type]
 
-        return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
+        chat_completion_output = ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
+        if response_model:
+            for choice in chat_completion_output.choices:
+                if choice.message.content:
+                    try:
+                        # pydantic v2 uses `model_validate_json`; v1 only provides `parse_raw`
+                        choice.message.parsed = (
+                            response_model.model_validate_json(choice.message.content)
+                            if hasattr(response_model, "model_validate_json")
+                            else response_model.parse_raw(choice.message.content)
+                        )
+                    except ValueError:
+                        choice.message.refusal = f"Failed to generate the response as a {response_model.__name__}"
+        return chat_completion_output
 
     def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
         # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
diff --git a/src/huggingface_hub/inference/_generated/types/chat_completion.py b/src/huggingface_hub/inference/_generated/types/chat_completion.py
index 7a1f297e4f..a15b6c4887 100644
--- a/src/huggingface_hub/inference/_generated/types/chat_completion.py
+++ b/src/huggingface_hub/inference/_generated/types/chat_completion.py
@@ -6,6 +6,8 @@
 from dataclasses import dataclass
 from typing import Any, List, Literal, Optional, Union
 
+from huggingface_hub._webhooks_payload import BaseModel
+
 from .base import BaseInferenceType
 
 
@@ -196,6 +198,8 @@ class ChatCompletionOutputMessage(BaseInferenceType):
     role: str
     content: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None
+    parsed: Optional[BaseModel] = None
+    refusal: Optional[str] = None
 
 
 @dataclass
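Taken together, these changes let callers pass a Pydantic class directly as `response_format` and read the validated object back from `choices[0].message.parsed`, with `refusal` set when validation fails. A usage sketch based on the docstring examples in this diff (model name and prompt are illustrative; assumes pydantic v2 and a reachable Inference API backend):

```python
from huggingface_hub import InferenceClient
from pydantic import BaseModel


class ActivitySummary(BaseModel):
    location: str
    activity: str
    animals_seen: int
    animals: list[str]


client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
response = client.chat_completion(
    messages=[{"role": "user", "content": "I saw a puppy, a cat and a raccoon during my bike ride in the park."}],
    response_format=ActivitySummary,  # the class is converted to a JSON-schema grammar internally
    max_tokens=500,
)

message = response.choices[0].message
if message.parsed is not None:
    # Validation succeeded: `parsed` is an ActivitySummary instance.
    print(message.parsed.animals_seen, message.parsed.animals)
else:
    # Validation failed: `refusal` explains why; `content` still holds the raw generated text.
    print(message.refusal, message.content)
```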