From 7457d9521b4363113bdf0bcc39ea0b0504d3e649 Mon Sep 17 00:00:00 2001 From: noeliecherrier Date: Wed, 15 Oct 2025 12:38:29 +0200 Subject: [PATCH 1/2] feat(mistralai): remove tenacity retries for embeddings --- .../langchain_mistralai/embeddings.py | 22 ------------------- .../integration_tests/test_embeddings.py | 5 ++--- 2 files changed, 2 insertions(+), 25 deletions(-) diff --git a/libs/partners/mistralai/langchain_mistralai/embeddings.py b/libs/partners/mistralai/langchain_mistralai/embeddings.py index 26093d15c20c7..f333bfaf73337 100644 --- a/libs/partners/mistralai/langchain_mistralai/embeddings.py +++ b/libs/partners/mistralai/langchain_mistralai/embeddings.py @@ -16,7 +16,6 @@ SecretStr, model_validator, ) -from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed from tokenizers import Tokenizer # type: ignore[import] from typing_extensions import Self @@ -57,13 +56,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings): api_key: SecretStr | None The API key for the MistralAI API. If not provided, it will be read from the environment variable `MISTRAL_API_KEY`. - max_retries: int - The number of times to retry a request if it fails. timeout: int The number of seconds to wait for a response before timing out. - wait_time: int - The number of seconds to wait before retrying a request in case of 429 - error. max_concurrent_requests: int The maximum number of concurrent requests to make to the Mistral API. @@ -133,9 +127,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings): default_factory=secret_from_env("MISTRAL_API_KEY", default=""), ) endpoint: str = "https://api.mistral.ai/v1/" - max_retries: int = 5 timeout: int = 120 - wait_time: int = 30 max_concurrent_requests: int = 64 tokenizer: Tokenizer = Field(default=None) @@ -225,13 +217,6 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: try: batch_responses = [] - @retry( - retry=retry_if_exception_type( - (httpx.TimeoutException, httpx.HTTPStatusError) - ), - wait=wait_fixed(self.wait_time), - stop=stop_after_attempt(self.max_retries), - ) def _embed_batch(batch: list[str]) -> Response: response = self.client.post( url="/embeddings", @@ -266,13 +251,6 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """ try: - @retry( - retry=retry_if_exception_type( - (httpx.TimeoutException, httpx.HTTPStatusError) - ), - wait=wait_fixed(self.wait_time), - stop=stop_after_attempt(self.max_retries), - ) async def _aembed_batch(batch: list[str]) -> Response: response = await self.async_client.post( url="/embeddings", diff --git a/libs/partners/mistralai/tests/integration_tests/test_embeddings.py b/libs/partners/mistralai/tests/integration_tests/test_embeddings.py index 3ef91728e2fd7..383af9dc0a6a4 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_embeddings.py +++ b/libs/partners/mistralai/tests/integration_tests/test_embeddings.py @@ -4,7 +4,6 @@ import httpx import pytest -import tenacity from langchain_mistralai import MistralAIEmbeddings @@ -38,14 +37,14 @@ async def test_mistralai_embedding_documents_async() -> None: async def test_mistralai_embedding_documents_http_error_async() -> None: """Test MistralAI embeddings for documents.""" documents = ["foo bar", "test document"] - embedding = MistralAIEmbeddings(max_retries=0) + embedding = MistralAIEmbeddings() mock_response = httpx.Response( status_code=400, request=httpx.Request("POST", url=embedding.async_client.base_url), ) with ( patch.object(embedding.async_client, "post", return_value=mock_response), - pytest.raises(tenacity.RetryError), + pytest.raises(httpx.HTTPStatusError), ): await embedding.aembed_documents(documents) From aa48b389a5e04b9fc9d38e351793c1f71fb41743 Mon Sep 17 00:00:00 2001 From: noeliecherrier Date: Mon, 20 Oct 2025 17:04:15 +0200 Subject: [PATCH 2/2] feat(mistralai): enable tenacity retries opt-out --- .../langchain_mistralai/embeddings.py | 19 ++++++++++++++++++- .../integration_tests/test_embeddings.py | 18 +++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/libs/partners/mistralai/langchain_mistralai/embeddings.py b/libs/partners/mistralai/langchain_mistralai/embeddings.py index 9675090489644..f8fd08dc72d74 100644 --- a/libs/partners/mistralai/langchain_mistralai/embeddings.py +++ b/libs/partners/mistralai/langchain_mistralai/embeddings.py @@ -1,7 +1,7 @@ import asyncio import logging import warnings -from collections.abc import Iterable +from collections.abc import Callable, Iterable import httpx from httpx import Response @@ -16,6 +16,7 @@ SecretStr, model_validator, ) +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed from tokenizers import Tokenizer # type: ignore[import] from typing_extensions import Self @@ -133,7 +134,9 @@ class MistralAIEmbeddings(BaseModel, Embeddings): default_factory=secret_from_env("MISTRAL_API_KEY", default=""), ) endpoint: str = "https://api.mistral.ai/v1/" + max_retries: int | None = 5 timeout: int = 120 + wait_time: int | None = 30 max_concurrent_requests: int = 64 tokenizer: Tokenizer = Field(default=None) @@ -210,6 +213,18 @@ def _get_batches(self, texts: list[str]) -> Iterable[list[str]]: if batch: yield batch + def _retry(self, func: Callable) -> Callable: + if self.max_retries is None or self.wait_time is None: + return func + + return retry( + retry=retry_if_exception_type( + (httpx.TimeoutException, httpx.HTTPStatusError) + ), + wait=wait_fixed(self.wait_time), + stop=stop_after_attempt(self.max_retries), + )(func) + def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed a list of document texts. @@ -223,6 +238,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: try: batch_responses = [] + @self._retry def _embed_batch(batch: list[str]) -> Response: response = self.client.post( url="/embeddings", @@ -257,6 +273,7 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """ try: + @self._retry async def _aembed_batch(batch: list[str]) -> Response: response = await self.async_client.post( url="/embeddings", diff --git a/libs/partners/mistralai/tests/integration_tests/test_embeddings.py b/libs/partners/mistralai/tests/integration_tests/test_embeddings.py index 383af9dc0a6a4..b5e0f2787cc9f 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_embeddings.py +++ b/libs/partners/mistralai/tests/integration_tests/test_embeddings.py @@ -4,6 +4,7 @@ import httpx import pytest +import tenacity from langchain_mistralai import MistralAIEmbeddings @@ -34,10 +35,25 @@ async def test_mistralai_embedding_documents_async() -> None: assert len(output[0]) == 1024 +async def test_mistralai_embedding_documents_tenacity_error_async() -> None: + """Test MistralAI embeddings for documents.""" + documents = ["foo bar", "test document"] + embedding = MistralAIEmbeddings(max_retries=0) + mock_response = httpx.Response( + status_code=400, + request=httpx.Request("POST", url=embedding.async_client.base_url), + ) + with ( + patch.object(embedding.async_client, "post", return_value=mock_response), + pytest.raises(tenacity.RetryError), + ): + await embedding.aembed_documents(documents) + + async def test_mistralai_embedding_documents_http_error_async() -> None: """Test MistralAI embeddings for documents.""" documents = ["foo bar", "test document"] - embedding = MistralAIEmbeddings() + embedding = MistralAIEmbeddings(max_retries=None) mock_response = httpx.Response( status_code=400, request=httpx.Request("POST", url=embedding.async_client.base_url),