diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
index 7cf54d7682..709afff32b 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py
@@ -25,6 +25,7 @@
     TextBlock,
     ImageBlock,
     ThinkingBlock,
+    ToolCallBlock,
 )
 from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.callbacks import CallbackManager
@@ -53,7 +54,7 @@
 )
 
 from mistralai import Mistral
-from mistralai.models import ToolCall
+from mistralai.models import ToolCall, FunctionCall
 from mistralai.models import (
     Messages,
     AssistantMessage,
@@ -102,6 +103,8 @@ def to_mistral_chunks(content_blocks: Sequence[ContentBlock]) -> Sequence[Conten
                     image_url=f"data:{image_mimetype};base64,{base_64_str}"
                 )
             )
+        elif isinstance(content_block, ToolCallBlock):
+            pass
         else:
             raise ValueError(f"Unsupported content block type {type(content_block)}")
     return content_chunks
@@ -112,7 +115,33 @@ def to_mistral_chatmessage(
 ) -> List[Messages]:
     new_messages = []
     for m in messages:
-        tool_calls = m.additional_kwargs.get("tool_calls")
+        unique_tool_calls = []
+        tool_calls_li = [
+            block for block in m.blocks if isinstance(block, ToolCallBlock)
+        ]
+        tool_calls = []
+        for tool_call_li in tool_calls_li:
+            tool_calls.append(
+                ToolCall(
+                    id=tool_call_li.tool_call_id,
+                    function=FunctionCall(
+                        name=tool_call_li.tool_name,
+                        arguments=tool_call_li.tool_kwargs,
+                    ),
+                )
+            )
+            unique_tool_calls.append(
+                (tool_call_li.tool_call_id, tool_call_li.tool_name)
+            )
+        # fall back to legacy tool calls for compatibility with older chat histories
+        if len(m.additional_kwargs.get("tool_calls", [])) > 0:
+            tcs = m.additional_kwargs.get("tool_calls", [])
+            for tc in tcs:
+                if (
+                    isinstance(tc, ToolCall)
+                    and (tc.id, tc.function.name) not in unique_tool_calls
+                ):
+                    tool_calls.append(tc)
         chunks = to_mistral_chunks(m.blocks)
         if m.role == MessageRole.USER:
             new_messages.append(UserMessage(content=chunks))
@@ -135,9 +164,15 @@
 
 
 def force_single_tool_call(response: ChatResponse) -> None:
-    tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+    tool_calls = [
+        block for block in response.message.blocks if isinstance(block, ToolCallBlock)
+    ]
     if len(tool_calls) > 1:
-        response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
+        response.message.blocks = [
+            block
+            for block in response.message.blocks
+            if not isinstance(block, ToolCallBlock)
+        ] + [tool_calls[0]]
 
 
 class MistralAI(FunctionCallingLLM):
@@ -296,17 +331,29 @@ def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
             **kwargs,
         }
 
-    def _separate_thinking(self, response: str) -> Tuple[str, str]:
+    def _separate_thinking(
+        self, response: Union[str, List[ContentChunk]]
+    ) -> Tuple[str, str]:
         """Separate the thinking from the response."""
-        match = THINKING_REGEX.search(response)
+        content = ""
+        if isinstance(response, str):
+            content = response
+        else:
+            for chunk in response:
+                if isinstance(chunk, ThinkChunk):
+                    for c in chunk.thinking:
+                        if isinstance(c, TextChunk):
+                            content += c.text + "\n"
+
+        match = THINKING_REGEX.search(content)
         if match:
-            return match.group(1), response.replace(match.group(0), "")
+            return match.group(1), content.replace(match.group(0), "")
 
-        match = THINKING_START_REGEX.search(response)
+        match = THINKING_START_REGEX.search(content)
         if match:
             return match.group(0), ""
 
-        return "", response
+        return "", content
 
     @llm_chat_callback()
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
@@ -315,20 +362,30 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         messages = to_mistral_chatmessage(messages)
         all_kwargs = self._get_all_kwargs(**kwargs)
         response = self._client.chat.complete(messages=messages, **all_kwargs)
-        blocks: List[TextBlock | ThinkingBlock] = []
+        blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []
 
-        additional_kwargs = {}
         if self.model in MISTRAL_AI_REASONING_MODELS:
             thinking_txt, response_txt = self._separate_thinking(
-                response.choices[0].message.content
+                response.choices[0].message.content or []
             )
             if thinking_txt:
                 blocks.append(ThinkingBlock(content=thinking_txt))
 
+            response_txt_think_show = ""
+            if response.choices[0].message.content:
+                if isinstance(response.choices[0].message.content, str):
+                    response_txt_think_show = response.choices[0].message.content
+                else:
+                    for chunk in response.choices[0].message.content:
+                        if isinstance(chunk, TextChunk):
+                            response_txt_think_show += chunk.text + "\n"
+                        if isinstance(chunk, ThinkChunk):
+                            for c in chunk.thinking:
+                                if isinstance(c, TextChunk):
+                                    response_txt_think_show += c.text + "\n"
+
             response_txt = (
-                response_txt
-                if not self.show_thinking
-                else response.choices[0].message.content
+                response_txt if not self.show_thinking else response_txt_think_show
             )
         else:
             response_txt = response.choices[0].message.content
@@ -336,13 +393,20 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         blocks.append(TextBlock(text=response_txt))
         tool_calls = response.choices[0].message.tool_calls
         if tool_calls is not None:
-            additional_kwargs["tool_calls"] = tool_calls
+            for tool_call in tool_calls:
+                if isinstance(tool_call, ToolCall):
+                    blocks.append(
+                        ToolCallBlock(
+                            tool_call_id=tool_call.id,
+                            tool_kwargs=tool_call.function.arguments,
+                            tool_name=tool_call.function.name,
+                        )
+                    )
 
         return ChatResponse(
             message=ChatMessage(
                 role=MessageRole.ASSISTANT,
                 blocks=blocks,
-                additional_kwargs=additional_kwargs,
             ),
             raw=dict(response),
         )
@@ -367,18 +431,39 @@ def stream_chat(
 
         def gen() -> ChatResponseGen:
             content = ""
-            blocks: List[TextBlock | ThinkingBlock] = []
+            blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []
             for chunk in response:
                 delta = chunk.data.choices[0].delta
                 role = delta.role or MessageRole.ASSISTANT
                 # NOTE: Unlike openAI, we are directly injecting the tool calls
-                additional_kwargs = {}
                 if delta.tool_calls:
-                    additional_kwargs["tool_calls"] = delta.tool_calls
+                    for tool_call in delta.tool_calls:
+                        if isinstance(tool_call, ToolCall):
+                            blocks.append(
+                                ToolCallBlock(
+                                    tool_call_id=tool_call.id,
+                                    tool_name=tool_call.function.name,
+                                    tool_kwargs=tool_call.function.arguments,
+                                )
+                            )
 
                 content_delta = delta.content or ""
-                content += content_delta
+                content_delta_str = ""
+                if isinstance(content_delta, str):
+                    content_delta_str = content_delta
+                else:
+                    for content_chunk in content_delta:
+                        if isinstance(content_chunk, TextChunk):
+                            content_delta_str += content_chunk.text + "\n"
+                        elif isinstance(content_chunk, ThinkChunk):
+                            for c in content_chunk.thinking:
+                                if isinstance(c, TextChunk):
+                                    content_delta_str += c.text + "\n"
+                        else:
+                            continue
+
+                content += content_delta_str
 
                 # decide whether to include thinking in deltas/responses
                 if self.model in MISTRAL_AI_REASONING_MODELS:
@@ -392,15 +477,14 @@ def gen() -> ChatResponseGen:
                     # If thinking hasn't ended, don't include it in the delta
                     if thinking_txt is None and not self.show_thinking:
-                        content_delta = ""
+                        content_delta_str = ""
 
-                blocks.append(TextBlock(text=content))
+                blocks.append(TextBlock(text=content))
 
                 yield ChatResponse(
                     message=ChatMessage(
                         role=role,
                         blocks=blocks,
-                        additional_kwargs=additional_kwargs,
                     ),
-                    delta=content_delta,
+                    delta=content_delta_str,
                     raw=chunk,
                 )
@@ -425,19 +509,30 @@ async def achat(
             messages=messages, **all_kwargs
         )
 
-        blocks: List[TextBlock | ThinkingBlock] = []
+        blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []
         additional_kwargs = {}
         if self.model in MISTRAL_AI_REASONING_MODELS:
             thinking_txt, response_txt = self._separate_thinking(
-                response.choices[0].message.content
+                response.choices[0].message.content or []
            )
             if thinking_txt:
                 blocks.append(ThinkingBlock(content=thinking_txt))
 
+            response_txt_think_show = ""
+            if response.choices[0].message.content:
+                if isinstance(response.choices[0].message.content, str):
+                    response_txt_think_show = response.choices[0].message.content
+                else:
+                    for chunk in response.choices[0].message.content:
+                        if isinstance(chunk, TextChunk):
+                            response_txt_think_show += chunk.text + "\n"
+                        if isinstance(chunk, ThinkChunk):
+                            for c in chunk.thinking:
+                                if isinstance(c, TextChunk):
+                                    response_txt_think_show += c.text + "\n"
+
             response_txt = (
-                response_txt
-                if not self.show_thinking
-                else response.choices[0].message.content
+                response_txt if not self.show_thinking else response_txt_think_show
             )
         else:
             response_txt = response.choices[0].message.content
@@ -446,7 +541,25 @@ async def achat(
 
         tool_calls = response.choices[0].message.tool_calls
         if tool_calls is not None:
-            additional_kwargs["tool_calls"] = tool_calls
+            for tool_call in tool_calls:
+                if isinstance(tool_call, ToolCall):
+                    blocks.append(
+                        ToolCallBlock(
+                            tool_call_id=tool_call.id,
+                            tool_kwargs=tool_call.function.arguments,
+                            tool_name=tool_call.function.name,
+                        )
+                    )
+                else:
+                    if isinstance(tool_call[1], (str, dict)):
+                        blocks.append(
+                            ToolCallBlock(
+                                tool_kwargs=tool_call[1], tool_name=tool_call[0]
+                            )
+                        )
+            additional_kwargs["tool_calls"] = (
+                tool_calls  # keep these so tool calls that don't match the cases above aren't lost
+            )
 
         return ChatResponse(
             message=ChatMessage(
@@ -477,17 +590,38 @@ async def astream_chat(
 
         async def gen() -> ChatResponseAsyncGen:
             content = ""
-            blocks: List[ThinkingBlock | TextBlock] = []
+            blocks: List[ThinkingBlock | TextBlock | ToolCallBlock] = []
             async for chunk in response:
                 delta = chunk.data.choices[0].delta
                 role = delta.role or MessageRole.ASSISTANT
                 # NOTE: Unlike openAI, we are directly injecting the tool calls
-                additional_kwargs = {}
                 if delta.tool_calls:
-                    additional_kwargs["tool_calls"] = delta.tool_calls
+                    for tool_call in delta.tool_calls:
+                        if isinstance(tool_call, ToolCall):
+                            blocks.append(
+                                ToolCallBlock(
+                                    tool_call_id=tool_call.id,
+                                    tool_name=tool_call.function.name,
+                                    tool_kwargs=tool_call.function.arguments,
+                                )
+                            )
 
                 content_delta = delta.content or ""
-                content += content_delta
+                content_delta_str = ""
+                if isinstance(content_delta, str):
+                    content_delta_str = content_delta
+                else:
+                    for content_chunk in content_delta:
+                        if isinstance(content_chunk, TextChunk):
+                            content_delta_str += content_chunk.text + "\n"
+                        elif isinstance(content_chunk, ThinkChunk):
+                            for c in content_chunk.thinking:
+                                if isinstance(c, TextChunk):
+                                    content_delta_str += c.text + "\n"
+                        else:
+                            continue
+
+                content += content_delta_str
 
                 # decide whether to include thinking in deltas/responses
                 if self.model in MISTRAL_AI_REASONING_MODELS:
@@ -501,15 +635,14 @@ async def gen() -> ChatResponseAsyncGen:
                     if thinking_txt is None and not self.show_thinking:
-                        content_delta = ""
+                        content_delta_str = ""
 
-                blocks.append(TextBlock(text=content))
+                blocks.append(TextBlock(text=content))
 
                 yield ChatResponse(
                     message=ChatMessage(
                         role=role,
                         blocks=blocks,
-                        additional_kwargs=additional_kwargs,
                     ),
-                    delta=content_delta,
+                    delta=content_delta_str,
                     raw=chunk,
                 )
@@ -570,7 +703,11 @@ def get_tool_calls_from_response(
         error_on_no_tool_call: bool = True,
     ) -> List[ToolSelection]:
         """Predict and call the tool."""
-        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+        tool_calls = [
+            block
+            for block in response.message.blocks
+            if isinstance(block, ToolCallBlock)
+        ]
 
         if len(tool_calls) < 1:
             if error_on_no_tool_call:
@@ -582,15 +719,15 @@
 
         tool_selections = []
         for tool_call in tool_calls:
-            if not isinstance(tool_call, ToolCall):
-                raise ValueError("Invalid tool_call object")
-
-            argument_dict = json.loads(tool_call.function.arguments)
+            if isinstance(tool_call.tool_kwargs, str):
+                argument_dict = json.loads(tool_call.tool_kwargs)
+            else:
+                argument_dict = tool_call.tool_kwargs
 
             tool_selections.append(
                 ToolSelection(
-                    tool_id=tool_call.id,
-                    tool_name=tool_call.function.name,
+                    tool_id=tool_call.tool_call_id or "",
+                    tool_name=tool_call.tool_name,
                     tool_kwargs=argument_dict,
                 )
             )
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml
index 920504b4d5..94c8d0d868 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/pyproject.toml
@@ -26,15 +26,15 @@ dev = [
 
 [project]
 name = "llama-index-llms-mistralai"
-version = "0.8.2"
+version = "0.9.0"
 description = "llama-index llms mistral ai integration"
 authors = [{name = "Your Name", email = "you@example.com"}]
 requires-python = ">=3.9,<4.0"
 readme = "README.md"
 license = "MIT"
 dependencies = [
-    "mistralai>=1.8.2",
-    "llama-index-core>=0.14.3,<0.15",
+    "mistralai>=1.9.11",
+    "llama-index-core>=0.14.5,<0.15",
 ]
 
 [tool.codespell]
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py b/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py
index ce4ed450aa..afaf5d3315 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/tests/test_llms_mistral.py
@@ -4,11 +4,16 @@
 from pathlib import Path
 from unittest.mock import patch
 
-from mistralai import ToolCall, ImageURLChunk, TextChunk, ThinkChunk
+from mistralai import ImageURLChunk, TextChunk, ThinkChunk
 import pytest
 
 from llama_index.core.base.llms.base import BaseLLM
-from llama_index.core.llms import ChatMessage, ImageBlock, TextBlock
+from llama_index.core.base.llms.types import (
+    ChatMessage,
+    ImageBlock,
+    TextBlock,
+    ToolCallBlock,
+)
 from llama_index.core.base.llms.types import ThinkingBlock
 from llama_index.core.tools import FunctionTool
 from llama_index.llms.mistralai import MistralAI
@@ -40,14 +45,13 @@ def test_tool_required():
         user_msg="What is the capital of France?",
         tool_required=True,
     )
-    additional_kwargs = result.message.additional_kwargs
-    assert "tool_calls" in additional_kwargs
-    tool_calls = additional_kwargs["tool_calls"]
+    tool_calls = [
+        block for block in result.message.blocks if isinstance(block, ToolCallBlock)
+    ]
     assert len(tool_calls) == 1
     tool_call = tool_calls[0]
-    assert isinstance(tool_call, ToolCall)
-    assert tool_call.function.name == "search_tool"
-    assert "query" in tool_call.function.arguments
+    assert tool_call.tool_name == "search_tool"
+    assert "query" in tool_call.tool_kwargs
 
 
 @patch("mistralai.Mistral")
@@ -184,3 +188,8 @@ def test_to_mistral_chunks(tmp_path: Path, image_url: str) -> None:
     )
     assert isinstance(thinking_chunks[1], TextChunk)
     assert thinking_chunks[1].text == "This is some text"
+    tool_blocks = [
+        ToolCallBlock(tool_call_id="1", tool_name="hello_world", tool_kwargs={})
+    ]
+    tool_chunks = to_mistral_chunks(tool_blocks)
+    assert len(tool_chunks) == 0
diff --git a/llama-index-integrations/llms/llama-index-llms-mistralai/uv.lock b/llama-index-integrations/llms/llama-index-llms-mistralai/uv.lock
index 82fce00b29..49e59c5033 100644
--- a/llama-index-integrations/llms/llama-index-llms-mistralai/uv.lock
+++ b/llama-index-integrations/llms/llama-index-llms-mistralai/uv.lock
@@ -1584,7 +1584,7 @@ wheels = [
 
 [[package]]
 name = "llama-index-core"
-version = "0.13.0"
+version = "0.14.5"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "aiohttp" },
@@ -1618,9 +1618,9 @@ dependencies = [
     { name = "typing-inspect" },
     { name = "wrapt" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/1a/0d/29865c82de51c2c1f263a61c267c336d35b14839fded66061a999dba1d40/llama_index_core-0.13.0.tar.gz", hash = "sha256:01fec50d3d807e3c3bc17a62ed1f5b93dad2205cda52f7d0c2d34cc6a6ab2b92", size = 7230599, upload-time = "2025-07-30T17:24:00.398Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/cf/42/e1de7d6a390dcd67b0754fd24e0d0acb56c1d0838a68e30671dd79fd5521/llama_index_core-0.14.5.tar.gz", hash = "sha256:913ebc3ad895d381eaab0f10dc405101c5bec5a70c09909ef2493ddc115f8552", size = 11578206, upload-time = "2025-10-15T19:10:09.746Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/f3/27/36a3ad19c8e2e7967ca1c9e469e5205901d2b6d299420e9aba4eafd90d8d/llama_index_core-0.13.0-py3-none-any.whl", hash = "sha256:46c14fc2a26b8f7618c2dd2daf6e430e3f94b1908474baee539f705c9c638348", size = 7573714, upload-time = "2025-07-30T17:23:52.355Z" },
+    { url = "https://files.pythonhosted.org/packages/0f/64/c02576991efcefd30a65971e87ece7494d6bbf3739b7bffeeb56c86b5a76/llama_index_core-0.14.5-py3-none-any.whl", hash = "sha256:5445aa322b83a9d48baa608c3b920df4f434ed5d461a61e6bccb36d99348bddf", size = 11919461, upload-time = "2025-10-15T19:10:06.92Z" },
 ]
 
 [[package]]
@@ -1638,7 +1638,7 @@ wheels = [
 
 [[package]]
 name = "llama-index-llms-mistralai"
-version = "0.7.1"
+version = "0.9.0"
 source = { editable = "." }
 dependencies = [
     { name = "llama-index-core" },
@@ -1670,7 +1670,7 @@ dev = [
 
 [package.metadata]
 requires-dist = [
-    { name = "llama-index-core", specifier = ">=0.13.0,<0.15" },
-    { name = "mistralai", specifier = ">=1.8.2" },
+    { name = "llama-index-core", specifier = ">=0.14.5,<0.15" },
+    { name = "mistralai", specifier = ">=1.9.11" },
 ]
 
@@ -1698,16 +1698,17 @@ dev = [
 
 [[package]]
 name = "llama-index-workflows"
-version = "1.2.0"
+version = "2.8.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "eval-type-backport", marker = "python_full_version < '3.10'" },
     { name = "llama-index-instrumentation" },
     { name = "pydantic" },
+    { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/26/9d/9dc7adc10d9976582bf50b074883986cb36b46f2fe45cf60550767300a29/llama_index_workflows-1.2.0.tar.gz", hash = "sha256:f6b19f01a340a1afb1d2fd2285c9dce346e304a3aae519e6103059f5afb2609f", size = 1019113, upload-time = "2025-07-23T18:32:47.86Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/ee/ab/105f9a5d596ad29b8734c8cb72ade3a6ce200b6e633dd5f79b130f729ab7/llama_index_workflows-2.8.1.tar.gz", hash = "sha256:914a6c927e2fc87a66426c05ffdcf8d1e06a513714d47f4c263d60bcb918180b", size = 4989071, upload-time = "2025-10-15T23:08:56.183Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/36/c1/5190f102a042d36a6a495de27510c2d6e3aca98f892895bfacdcf9109c1d/llama_index_workflows-1.2.0-py3-none-any.whl", hash = "sha256:5722a7ce137e00361025768789e7e77720cd66f855791050183a3c540b6e5b8c", size = 37463, upload-time = "2025-07-23T18:32:46.294Z" },
+    { url = "https://files.pythonhosted.org/packages/6e/de/12bdf7a625fa932f4bab38233b6f146ef29aa5df36010088b2d3f0d479e9/llama_index_workflows-2.8.1-py3-none-any.whl", hash = "sha256:1309911c2252cba4705f4e5ee485f734936a7c7dbbabb037e13072a3d2f01551", size = 61019, upload-time = "2025-10-15T23:08:55.17Z" },
 ]
 
 [[package]]
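
For reference, the consumer-visible change in this patch is that assistant tool calls now travel as ToolCallBlock entries on message.blocks instead of additional_kwargs["tool_calls"]. A minimal sketch of the new flow, assuming llama-index-core>=0.14.5, a valid MISTRAL_API_KEY in the environment, and a hypothetical `multiply` tool that is not part of this patch:

    from llama_index.core.base.llms.types import ToolCallBlock
    from llama_index.core.tools import FunctionTool
    from llama_index.llms.mistralai import MistralAI

    # Hypothetical example tool; any FunctionTool behaves the same way.
    def multiply(a: int, b: int) -> int:
        """Multiply two integers."""
        return a * b

    llm = MistralAI(model="mistral-large-latest")
    response = llm.chat_with_tools(
        tools=[FunctionTool.from_defaults(fn=multiply)],
        user_msg="What is 3 times 4?",
    )

    # Tool calls are now ToolCallBlock entries on message.blocks,
    # not additional_kwargs["tool_calls"].
    for block in response.message.blocks:
        if isinstance(block, ToolCallBlock):
            print(block.tool_call_id, block.tool_name, block.tool_kwargs)

    # get_tool_calls_from_response still returns ToolSelection objects,
    # now built from those blocks.
    selections = llm.get_tool_calls_from_response(
        response, error_on_no_tool_call=False
    )
    for s in selections:
        print(s.tool_name, s.tool_kwargs)

Messages built from older chat histories keep working: to_mistral_chatmessage falls back to any legacy ToolCall objects left in additional_kwargs, deduplicating against the block-based calls by (id, name).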