Fix: Enable max_retries Parameter in ChatMistralAI Class #30448

Merged
4 changes: 2 additions & 2 deletions libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -457,9 +457,9 @@ def completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
        """Use tenacity to retry the completion call."""
        # retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        # @retry_decorator
        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            if "stream" not in kwargs:
                kwargs["stream"] = False
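Note on the change above: it simply re-enables the tenacity-based retry decorator that had been commented out, which is what lets the max_retries field take effect. For readers unfamiliar with the pattern, the following is a minimal sketch of what a decorator factory of this kind typically looks like. The helper name _create_retry_decorator_sketch, the retried exception types, and the wait bounds are illustrative assumptions; the real _create_retry_decorator in langchain_mistralai also wires the run manager callbacks into the retry loop and may make different choices.

import logging
from typing import Any, Callable

import httpx
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


def _create_retry_decorator_sketch(llm: Any) -> Callable[[Any], Any]:
    # Hypothetical stand-in for _create_retry_decorator: stop after
    # llm.max_retries attempts, back off exponentially between attempts, and
    # retry only on transport-level httpx errors, re-raising the last error.
    return retry(
        reraise=True,
        stop=stop_after_attempt(llm.max_retries),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_exception_type(httpx.RequestError),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )

In this sketch, stop_after_attempt(llm.max_retries) is what ties the user-facing max_retries parameter to the number of attempts actually made; the real helper presumably does the equivalent, which is why re-enabling the decorator fixes the parameter.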
@@ -1,9 +1,12 @@
"""Test ChatMistral chat model."""

import json
import logging
import time
from typing import Any, Optional

import pytest
from httpx import ReadTimeout
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
@@ -296,3 +299,39 @@ class Person(BaseModel):
        acc = chunk if acc is None else acc + chunk
    assert acc.content != ""
    assert "tool_calls" not in acc.additional_kwargs


def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
    """Test that retry parameters are honored in ChatMistralAI."""
    # Create a model with intentionally short timeout and multiple retries
    mistral = ChatMistralAI(
        timeout=1,  # Very short timeout to trigger timeouts
        max_retries=3,  # Should retry 3 times
    )

    # Simple test input that should take longer than 1 second to process
    test_input = "Write a 2 sentence story about a cat"

    # Measure start time
    t0 = time.time()

    try:
        # Try to get a response
        response = mistral.invoke(test_input)

        # If successful, validate the response
        elapsed_time = time.time() - t0
        logging.info(f"Request succeeded in {elapsed_time:.2f} seconds")
        # Check that we got a valid response
        assert response.content
        assert isinstance(response.content, str)
        assert "cat" in response.content.lower()

    except ReadTimeout:
        elapsed_time = time.time() - t0
        logging.info(f"Request timed out after {elapsed_time:.2f} seconds")
        assert elapsed_time >= 3.0
        pytest.skip("Test timed out as expected with short timeout")
    except Exception as e:
        logging.error(f"Unexpected exception: {e}")
        raise
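A note on the test above: with timeout=1 and max_retries=3, a request that times out on every attempt should keep the client busy for at least roughly three seconds of wall time, which is what the elapsed_time >= 3.0 assertion relies on. The caplog fixture is accepted but never asserted on; if the retry decorator logs each retry at WARNING level (an assumption about the implementation, for example via tenacity's before_sleep_log), the number of attempts could be checked from the captured records along these lines:

def test_retry_attempts_are_logged(caplog: pytest.LogCaptureFixture) -> None:
    """Hypothetical companion test: count retries via captured log records.

    Assumes the retry decorator emits one WARNING-level "retry" message before
    each retry; adjust the level and the message filter to match the actual
    implementation before relying on this.
    """
    caplog.set_level(logging.WARNING)
    mistral = ChatMistralAI(timeout=1, max_retries=3)

    try:
        mistral.invoke("Write a 2 sentence story about a cat")
    except ReadTimeout:
        pass

    retry_records = [
        record for record in caplog.records if "retry" in record.getMessage().lower()
    ]
    # With max_retries=3 there are at most two waits (and therefore at most
    # two retry log lines) between the three attempts.
    assert len(retry_records) <= 2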
46 changes: 45 additions & 1 deletion libs/partners/mistralai/tests/unit_tests/test_chat_models.py
@@ -2,8 +2,9 @@

import os
from typing import Any, AsyncGenerator, Dict, Generator, List, cast
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import httpx
import pytest
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.messages import (
@@ -270,3 +271,46 @@ def test_extra_kwargs() -> None:
    # Test that if provided twice it errors
    with pytest.raises(ValueError):
        ChatMistralAI(model="my-model", foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]


def test_retry_with_failure_then_success() -> None:
    """Test that the retry mechanism works correctly when the
    first request fails and the second succeeds."""
    # Create a real ChatMistralAI instance
    chat = ChatMistralAI(max_retries=3)

    # Set up the actual retry mechanism (not just mocking it)
    # We'll track how many times the function is called
    call_count = 0

    def mock_post(*args: Any, **kwargs: Any) -> MagicMock:
        nonlocal call_count
        call_count += 1

        if call_count == 1:
            raise httpx.RequestError("Connection error", request=MagicMock())

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": "Hello!",
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 1,
                "completion_tokens": 1,
                "total_tokens": 2,
            },
        }
        return mock_response

    with patch.object(chat.client, "post", side_effect=mock_post):
        result = chat.invoke("Hello")
        assert result.content == "Hello!"
        assert call_count == 2, f"Expected 2 calls, but got {call_count}"
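An alternative way to exercise the same failure-then-success path, without patching the post method directly, is to hand the model an httpx.Client backed by httpx.MockTransport. The sketch below assumes that the client attribute of ChatMistralAI is a plain httpx.Client that can be replaced after construction; verify that against the installed version before adopting it. Raising httpx.RequestError from the transport handler mirrors the mocked failure above, so the retry decorator is exercised on the same exception type.

def test_retry_with_mock_transport() -> None:
    """Hypothetical variant of the test above using httpx.MockTransport.

    Assumes ChatMistralAI exposes its HTTP client as a replaceable
    httpx.Client on the client attribute.
    """
    call_count = 0

    def handler(request: httpx.Request) -> httpx.Response:
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            # First attempt fails at the transport level; the retry decorator
            # should catch this and try again.
            raise httpx.RequestError("Connection error", request=request)
        return httpx.Response(
            200,
            json={
                "choices": [
                    {
                        "message": {"role": "assistant", "content": "Hello!"},
                        "finish_reason": "stop",
                    }
                ],
                "usage": {
                    "prompt_tokens": 1,
                    "completion_tokens": 1,
                    "total_tokens": 2,
                },
            },
        )

    chat = ChatMistralAI(max_retries=3)
    # base_url is only set so that relative request paths resolve; MockTransport
    # intercepts every request regardless of the URL it targets.
    chat.client = httpx.Client(
        transport=httpx.MockTransport(handler),
        base_url="https://api.mistral.ai/v1",
    )

    result = chat.invoke("Hello")
    assert result.content == "Hello!"
    assert call_count == 2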