diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py
index a7deb8b547126..224830fc16fad 100644
--- a/libs/partners/mistralai/langchain_mistralai/chat_models.py
+++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -457,9 +457,9 @@ def completion_with_retry(
         self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
     ) -> Any:
         """Use tenacity to retry the completion call."""
-        # retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
+        retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
 
-        # @retry_decorator
+        @retry_decorator
         def _completion_with_retry(**kwargs: Any) -> Any:
             if "stream" not in kwargs:
                 kwargs["stream"] = False
diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
index d3592ef32fdce..e0c58d25b67d6 100644
--- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py
@@ -1,9 +1,12 @@
 """Test ChatMistral chat model."""
 
 import json
+import logging
+import time
 from typing import Any, Optional
 
 import pytest
+from httpx import ReadTimeout
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -296,3 +299,48 @@ class Person(BaseModel):
         acc = chunk if acc is None else acc + chunk
     assert acc.content != ""
     assert "tool_calls" not in acc.additional_kwargs
+
+
+def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
+    """Test that retry parameters are honored in ChatMistralAI.
+
+    The model is invoked with ``timeout=1`` and ``max_retries=3``. Two
+    outcomes are accepted: a response whose text mentions "cat", or a
+    ``ReadTimeout`` raised at least 3 seconds after the call started, in which
+    case the test is skipped instead of failed.
+    """
+    # Create a model with intentionally short timeout and multiple retries
+    mistral = ChatMistralAI(
+        timeout=1,  # Very short timeout to trigger timeouts
+        max_retries=3,  # Should retry 3 times
+    )
+
+    # Simple test input that should take longer than 1 second to process
+    test_input = "Write a 2 sentence story about a cat"
+
+    # Measure start time
+    t0 = time.time()
+
+    try:
+        # Try to get a response
+        response = mistral.invoke(test_input)
+
+        # If successful, validate the response
+        elapsed_time = time.time() - t0
+        logging.info(f"Request succeeded in {elapsed_time:.2f} seconds")
+        # Check that we got a valid response
+        assert response.content
+        assert isinstance(response.content, str)
+        assert "cat" in response.content.lower()
+
+    except ReadTimeout:
+        elapsed_time = time.time() - t0
+        logging.info(f"Request timed out after {elapsed_time:.2f} seconds")
+        # NOTE(review): 3.0 presumably reflects 3 attempts x 1 s timeout, i.e.
+        # the timeout surfacing only after every retry was used - confirm.
+        assert elapsed_time >= 3.0
+        pytest.skip("Test timed out as expected with short timeout")
+    except Exception as e:
+        # Anything else (including a failed assertion above) is logged, re-raised
+        logging.error(f"Unexpected exception: {e}")
+        raise
diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py
index 4dc251832e7a8..6a94f431cfddb 100644
--- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py
@@ -2,8 +2,9 @@
 
 import os
 from typing import Any, AsyncGenerator, Dict, Generator, List, cast
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
+import httpx
 import pytest
 from langchain_core.callbacks.base import BaseCallbackHandler
 from langchain_core.messages import (
@@ -270,3 +271,53 @@ def test_extra_kwargs() -> None:
     # Test that if provided twice it errors
     with pytest.raises(ValueError):
         ChatMistralAI(model="my-model", foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
+
+
+def test_retry_with_failure_then_success() -> None:
+    """Test that a failed first request is retried and the second succeeds.
+
+    ``chat.client.post`` is patched so that its first call raises
+    ``httpx.RequestError`` and every later call returns a mocked 200 response.
+    The test passes when ``invoke`` returns that response's content after
+    exactly two calls.
+    """
+    # Create a real ChatMistralAI instance
+    chat = ChatMistralAI(max_retries=3)
+
+    # Set up the actual retry mechanism (not just mocking it)
+    # We'll track how many times the function is called
+    call_count = 0
+
+    def mock_post(*args: Any, **kwargs: Any) -> MagicMock:
+        nonlocal call_count
+        call_count += 1
+
+        # Fail only the first attempt so the retry path is exercised
+        if call_count == 1:
+            raise httpx.RequestError("Connection error", request=MagicMock())
+
+        # Later attempts get a minimal successful response payload
+        mock_response = MagicMock()
+        mock_response.status_code = 200
+        mock_response.json.return_value = {
+            "choices": [
+                {
+                    "message": {
+                        "role": "assistant",
+                        "content": "Hello!",
+                    },
+                    "finish_reason": "stop",
+                }
+            ],
+            "usage": {
+                "prompt_tokens": 1,
+                "completion_tokens": 1,
+                "total_tokens": 2,
+            },
+        }
+        return mock_response
+
+    with patch.object(chat.client, "post", side_effect=mock_post):
+        result = chat.invoke("Hello")
+        assert result.content == "Hello!"
+        assert call_count == 2, f"Expected 2 calls, but got {call_count}"