Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 69 additions & 3 deletions libs/partners/deepseek/langchain_deepseek/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,16 @@
default_factory=from_env("DEEPSEEK_API_BASE", default=DEFAULT_API_BASE),
)
"""DeepSeek API base URL"""
strict: bool | None = Field(
default=None,
description=(
"Whether to enable strict mode for function calling. "
"When enabled, uses the Beta API endpoint and ensures "
"outputs strictly comply with the defined JSON schema."
),
)

model_config = ConfigDict(populate_by_name=True)

@property
def _llm_type(self) -> str:
"""Return type of chat model."""
Expand All @@ -198,16 +205,22 @@
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate necessary environment vars and client params."""
if self.api_base == DEFAULT_API_BASE and not (
# Use Beta API if strict mode is enabled
api_base = self.api_base
if self.strict and self.api_base == DEFAULT_API_BASE:
api_base = "https://api.deepseek.com/beta"

if api_base == DEFAULT_API_BASE and not (
self.api_key and self.api_key.get_secret_value()
):
msg = "If using default api base, DEEPSEEK_API_KEY must be set."
raise ValueError(msg)

client_params: dict = {
k: v
for k, v in {
"api_key": self.api_key.get_secret_value() if self.api_key else None,
"base_url": self.api_base,
"base_url": api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
Expand All @@ -229,6 +242,59 @@
self.async_client = self.root_async_client.chat.completions
return self

def bind_tools(
self,
tools: list,
*,
tool_choice: str | dict | None = None,
strict: bool | None = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tools to the model with optional strict mode.

Args:
tools: A list of tool definitions or Pydantic models.
tool_choice: Which tool the model should use.
strict: Whether to enable strict mode for these tools.
If not provided, uses the instance's strict setting.
**kwargs: Additional arguments to pass to the parent method.

Returns:
A Runnable that will call the model with the bound tools.
"""
# Use instance strict setting if not explicitly provided
use_strict = strict if strict is not None else self.strict

# If strict mode is enabled, add strict: true to each tool
if use_strict:
formatted_tools = []
for tool in tools:
# Convert to OpenAI format
from langchain_core.utils.function_calling import convert_to_openai_tool

Check failure on line 273 in libs/partners/deepseek/langchain_deepseek/chat_models.py

View workflow job for this annotation

GitHub Actions / lint (libs/partners/deepseek, 3.14) / Python 3.14

Ruff (PLC0415)

langchain_deepseek/chat_models.py:273:17: PLC0415 `import` should be at the top-level of a file

Check failure on line 273 in libs/partners/deepseek/langchain_deepseek/chat_models.py

View workflow job for this annotation

GitHub Actions / lint (libs/partners/deepseek, 3.10) / Python 3.10

Ruff (PLC0415)

langchain_deepseek/chat_models.py:273:17: PLC0415 `import` should be at the top-level of a file

if not isinstance(tool, dict):
tool_dict = convert_to_openai_tool(tool)
else:
tool_dict = tool.copy()

# Add strict: true to the function definition
if "function" in tool_dict:
tool_dict["function"]["strict"] = True

formatted_tools.append(tool_dict)

tools = formatted_tools

# Add strict to kwargs if it's being used
if use_strict is not None:
kwargs["strict"] = use_strict

return super().bind_tools(
tools,
tool_choice=tool_choice,
**kwargs,
)

def _get_request_payload(
self,
input_: LanguageModelInput,
Expand Down
82 changes: 82 additions & 0 deletions libs/partners/deepseek/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,85 @@ def test_create_chat_result_with_model_provider_multiple_generations(
assert (
generation.message.response_metadata.get("model_provider") == "deepseek"
)


class TestChatDeepSeekStrictMode:
"""Test strict mode functionality."""

def test_strict_mode_uses_beta_api(self) -> None:
"""Test that strict mode switches to Beta API endpoint."""
model = ChatDeepSeek(
model=MODEL_NAME,
api_key=SecretStr("test-key"),
strict=True,
)

# Check that the client uses the beta endpoint
assert str(model.root_client.base_url) == "https://api.deepseek.com/beta/"

def test_strict_mode_disabled_uses_default_api(self) -> None:
"""Test that without strict mode, default API is used."""
model = ChatDeepSeek(
model=MODEL_NAME,
api_key=SecretStr("test-key"),
strict=False,
)

# Check that the client uses the default endpoint
assert str(model.root_client.base_url) == "https://api.deepseek.com/v1/"

def test_strict_mode_none_uses_default_api(self) -> None:
"""Test that strict=None uses default API."""
model = ChatDeepSeek(
model=MODEL_NAME,
api_key=SecretStr("test-key"),
)

# Check that the client uses the default endpoint
assert str(model.root_client.base_url) == "https://api.deepseek.com/v1/"

def test_bind_tools_with_strict_mode(self) -> None:
"""Test that bind_tools adds strict to tool definitions."""
from pydantic import BaseModel, Field

class GetWeather(BaseModel):
"""Get the current weather in a given location."""
location: str = Field(..., description="The city and state") # pyright: ignore[reportUndefinedVariable]

model = ChatDeepSeek(
model=MODEL_NAME,
api_key=SecretStr("test-key"),
strict=True,
)

# Bind tools
model_with_tools = model.bind_tools([GetWeather])

# Check that tools were bound
assert 'tools' in model_with_tools.kwargs

# Verify that tools have strict property set
tools = model_with_tools.kwargs['tools']
assert len(tools) > 0
assert tools[0]['function']['strict'] is True
def test_bind_tools_override_strict(self) -> None:
"""Test that bind_tools can override instance strict setting."""
from pydantic import BaseModel, Field

class GetWeather(BaseModel):
"""Get the current weather in a given location."""
location: str = Field(..., description="The city and state")

model = ChatDeepSeek(
model=MODEL_NAME,
api_key=SecretStr("test-key"),
strict=False,
)

# Override with strict=True in bind_tools
model_with_tools = model.bind_tools([GetWeather], strict=True)

# Check that strict was passed to kwargs
assert 'tools' in model_with_tools.kwargs
tools = model_with_tools.kwargs['tools']
assert tools[0]['function']['strict'] is True
Loading