diff --git a/libs/partners/deepseek/langchain_deepseek/chat_models.py b/libs/partners/deepseek/langchain_deepseek/chat_models.py index d7bfa3d6b28d6..da1d1e5312307 100644 --- a/libs/partners/deepseek/langchain_deepseek/chat_models.py +++ b/libs/partners/deepseek/langchain_deepseek/chat_models.py @@ -173,9 +173,16 @@ class Joke(BaseModel): default_factory=from_env("DEEPSEEK_API_BASE", default=DEFAULT_API_BASE), ) """DeepSeek API base URL""" + strict: bool | None = Field( + default=None, + description=( + "Whether to enable strict mode for function calling. " + "When enabled, uses the Beta API endpoint and ensures " + "outputs strictly comply with the defined JSON schema." + ), + ) model_config = ConfigDict(populate_by_name=True) - @property def _llm_type(self) -> str: """Return type of chat model.""" @@ -198,16 +205,22 @@ def _get_ls_params( @model_validator(mode="after") def validate_environment(self) -> Self: """Validate necessary environment vars and client params.""" - if self.api_base == DEFAULT_API_BASE and not ( + # Use Beta API if strict mode is enabled + api_base = self.api_base + if self.strict and self.api_base == DEFAULT_API_BASE: + api_base = "https://api.deepseek.com/beta" + + if api_base == DEFAULT_API_BASE and not ( self.api_key and self.api_key.get_secret_value() ): msg = "If using default api base, DEEPSEEK_API_KEY must be set." raise ValueError(msg) + client_params: dict = { k: v for k, v in { "api_key": self.api_key.get_secret_value() if self.api_key else None, - "base_url": self.api_base, + "base_url": api_base, "timeout": self.request_timeout, "max_retries": self.max_retries, "default_headers": self.default_headers, @@ -229,6 +242,59 @@ def validate_environment(self) -> Self: self.async_client = self.root_async_client.chat.completions return self + def bind_tools( + self, + tools: list, + *, + tool_choice: str | dict | None = None, + strict: bool | None = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tools to the model with optional strict mode. + + Args: + tools: A list of tool definitions or Pydantic models. + tool_choice: Which tool the model should use. + strict: Whether to enable strict mode for these tools. + If not provided, uses the instance's strict setting. + **kwargs: Additional arguments to pass to the parent method. + + Returns: + A Runnable that will call the model with the bound tools. + """ + # Use instance strict setting if not explicitly provided + use_strict = strict if strict is not None else self.strict + + # If strict mode is enabled, add strict: true to each tool + if use_strict: + formatted_tools = [] + for tool in tools: + # Convert to OpenAI format + from langchain_core.utils.function_calling import convert_to_openai_tool + + if not isinstance(tool, dict): + tool_dict = convert_to_openai_tool(tool) + else: + tool_dict = tool.copy() + + # Add strict: true to the function definition + if "function" in tool_dict: + tool_dict["function"]["strict"] = True + + formatted_tools.append(tool_dict) + + tools = formatted_tools + + # Add strict to kwargs if it's being used + if use_strict is not None: + kwargs["strict"] = use_strict + + return super().bind_tools( + tools, + tool_choice=tool_choice, + **kwargs, + ) + def _get_request_payload( self, input_: LanguageModelInput, diff --git a/libs/partners/deepseek/tests/unit_tests/test_chat_models.py b/libs/partners/deepseek/tests/unit_tests/test_chat_models.py index 31be8ab98c160..dbf9495c4f491 100644 --- a/libs/partners/deepseek/tests/unit_tests/test_chat_models.py +++ b/libs/partners/deepseek/tests/unit_tests/test_chat_models.py @@ -311,3 +311,85 @@ def test_create_chat_result_with_model_provider_multiple_generations( assert ( generation.message.response_metadata.get("model_provider") == "deepseek" ) + + +class TestChatDeepSeekStrictMode: + """Test strict mode functionality.""" + + def test_strict_mode_uses_beta_api(self) -> None: + """Test that strict mode switches to Beta API endpoint.""" + model = ChatDeepSeek( + model=MODEL_NAME, + api_key=SecretStr("test-key"), + strict=True, + ) + + # Check that the client uses the beta endpoint + assert str(model.root_client.base_url) == "https://api.deepseek.com/beta/" + + def test_strict_mode_disabled_uses_default_api(self) -> None: + """Test that without strict mode, default API is used.""" + model = ChatDeepSeek( + model=MODEL_NAME, + api_key=SecretStr("test-key"), + strict=False, + ) + + # Check that the client uses the default endpoint + assert str(model.root_client.base_url) == "https://api.deepseek.com/v1/" + + def test_strict_mode_none_uses_default_api(self) -> None: + """Test that strict=None uses default API.""" + model = ChatDeepSeek( + model=MODEL_NAME, + api_key=SecretStr("test-key"), + ) + + # Check that the client uses the default endpoint + assert str(model.root_client.base_url) == "https://api.deepseek.com/v1/" + + def test_bind_tools_with_strict_mode(self) -> None: + """Test that bind_tools adds strict to tool definitions.""" + from pydantic import BaseModel, Field + + class GetWeather(BaseModel): + """Get the current weather in a given location.""" + location: str = Field(..., description="The city and state") # pyright: ignore[reportUndefinedVariable] + + model = ChatDeepSeek( + model=MODEL_NAME, + api_key=SecretStr("test-key"), + strict=True, + ) + + # Bind tools + model_with_tools = model.bind_tools([GetWeather]) + + # Check that tools were bound + assert 'tools' in model_with_tools.kwargs + + # Verify that tools have strict property set + tools = model_with_tools.kwargs['tools'] + assert len(tools) > 0 + assert tools[0]['function']['strict'] is True + def test_bind_tools_override_strict(self) -> None: + """Test that bind_tools can override instance strict setting.""" + from pydantic import BaseModel, Field + + class GetWeather(BaseModel): + """Get the current weather in a given location.""" + location: str = Field(..., description="The city and state") + + model = ChatDeepSeek( + model=MODEL_NAME, + api_key=SecretStr("test-key"), + strict=False, + ) + + # Override with strict=True in bind_tools + model_with_tools = model.bind_tools([GetWeather], strict=True) + + # Check that strict was passed to kwargs + assert 'tools' in model_with_tools.kwargs + tools = model_with_tools.kwargs['tools'] + assert tools[0]['function']['strict'] is True