Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
TextBlock,
ImageBlock,
ThinkingBlock,
ToolCallBlock,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager
Expand Down Expand Up @@ -53,7 +54,7 @@
)

from mistralai import Mistral
from mistralai.models import ToolCall
from mistralai.models import ToolCall, FunctionCall
from mistralai.models import (
Messages,
AssistantMessage,
Expand Down Expand Up @@ -102,6 +103,8 @@ def to_mistral_chunks(content_blocks: Sequence[ContentBlock]) -> Sequence[Conten
image_url=f"data:{image_mimetype};base64,{base_64_str}"
)
)
elif isinstance(content_block, ToolCallBlock):
pass
else:
raise ValueError(f"Unsupported content block type {type(content_block)}")
return content_chunks
Expand All @@ -112,7 +115,33 @@ def to_mistral_chatmessage(
) -> List[Messages]:
new_messages = []
for m in messages:
tool_calls = m.additional_kwargs.get("tool_calls")
unique_tool_calls = []
tool_calls_li = [
block for block in m.blocks if isinstance(block, ToolCallBlock)
]
tool_calls = []
for tool_call_li in tool_calls_li:
tool_calls.append(
ToolCall(
id=tool_call_li.tool_call_id,
function=FunctionCall(
name=tool_call_li.tool_name,
arguments=tool_call_li.tool_kwargs,
),
)
)
unique_tool_calls.append(
(tool_call_li.tool_call_id, tool_call_li.tool_name)
)
# try with legacy tool calls for compatibility with older chat histories
if len(m.additional_kwargs.get("tool_calls", [])) > 0:
tcs = m.additional_kwargs.get("tool_calls", [])
for tc in tcs:
if (
isinstance(tc, ToolCall)
and (tc.id, tc.function.name) not in unique_tool_calls
):
tool_calls.append(tc)
Comment on lines +137 to +144
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check for duplicates :))

chunks = to_mistral_chunks(m.blocks)
if m.role == MessageRole.USER:
new_messages.append(UserMessage(content=chunks))
Expand All @@ -135,9 +164,15 @@ def to_mistral_chatmessage(


def force_single_tool_call(response: ChatResponse) -> None:
    """Mutate ``response`` in place so it carries at most one tool call.

    Mistral's function-calling flow expects a single tool call per assistant
    turn; when the model emitted several, keep only the first
    ``ToolCallBlock`` while preserving every non-tool-call block (text,
    thinking, images) in their original order.

    Args:
        response: The chat response whose ``message.blocks`` may contain
            multiple ``ToolCallBlock`` entries.

    Returns:
        None. The function mutates ``response.message.blocks`` in place.
    """
    tool_calls = [
        block for block in response.message.blocks if isinstance(block, ToolCallBlock)
    ]
    if len(tool_calls) > 1:
        # Rebuild the block list: all non-tool-call blocks first, then the
        # single surviving (first) tool call appended at the end.
        response.message.blocks = [
            block
            for block in response.message.blocks
            if not isinstance(block, ToolCallBlock)
        ] + [tool_calls[0]]


class MistralAI(FunctionCallingLLM):
Expand Down Expand Up @@ -296,17 +331,29 @@ def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
**kwargs,
}

def _separate_thinking(
    self, response: Union[str, List[ContentChunk]]
) -> Tuple[str, str]:
    """Split a model response into ``(thinking_text, answer_text)``.

    Accepts either a plain string or a list of Mistral content chunks.
    Chunked input is first flattened: the text of every ``TextChunk``
    nested inside each ``ThinkChunk`` is concatenated (newline-separated)
    into a single searchable string; non-think chunks are ignored here.

    Args:
        response: Raw message content from the Mistral API — either a
            string or a sequence of ``ContentChunk`` objects.

    Returns:
        A ``(thinking, remainder)`` tuple. Either element may be an empty
        string: a matched complete thinking section yields the captured
        thinking plus the content with that section stripped; a thinking
        section that has started but not closed (mid-stream) yields the
        opening match and an empty remainder; no thinking at all yields
        ``("", content)``.
    """
    # Flatten chunked content into one string so the regexes below can
    # operate uniformly on both input shapes.
    content = ""
    if isinstance(response, str):
        content = response
    else:
        for chunk in response:
            if isinstance(chunk, ThinkChunk):
                for c in chunk.thinking:
                    if isinstance(c, TextChunk):
                        content += c.text + "\n"

    # Complete thinking section present: return it and the text minus it.
    match = THINKING_REGEX.search(content)
    if match:
        return match.group(1), content.replace(match.group(0), "")

    # Thinking opened but not yet closed (streaming case): no answer yet.
    match = THINKING_START_REGEX.search(content)
    if match:
        return match.group(0), ""

    return "", content

@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
Expand All @@ -315,34 +362,51 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
messages = to_mistral_chatmessage(messages)
all_kwargs = self._get_all_kwargs(**kwargs)
response = self._client.chat.complete(messages=messages, **all_kwargs)
blocks: List[TextBlock | ThinkingBlock] = []
blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []

additional_kwargs = {}
if self.model in MISTRAL_AI_REASONING_MODELS:
thinking_txt, response_txt = self._separate_thinking(
response.choices[0].message.content
response.choices[0].message.content or []
)
if thinking_txt:
blocks.append(ThinkingBlock(content=thinking_txt))

response_txt_think_show = ""
if response.choices[0].message.content:
if isinstance(response.choices[0].message.content, str):
response_txt_think_show = response.choices[0].message.content
else:
for chunk in response.choices[0].message.content:
if isinstance(chunk, TextBlock):
response_txt_think_show += chunk.text + "\n"
if isinstance(chunk, ThinkChunk):
for c in chunk.thinking:
if isinstance(c, TextChunk):
response_txt_think_show += c.text + "\n"

response_txt = (
response_txt
if not self.show_thinking
else response.choices[0].message.content
response_txt if not self.show_thinking else response_txt_think_show
)
else:
response_txt = response.choices[0].message.content

blocks.append(TextBlock(text=response_txt))
tool_calls = response.choices[0].message.tool_calls
if tool_calls is not None:
additional_kwargs["tool_calls"] = tool_calls
for tool_call in tool_calls:
if isinstance(tool_call, ToolCall):
blocks.append(
ToolCallBlock(
tool_call_id=tool_call.id,
tool_kwargs=tool_call.function.arguments,
tool_name=tool_call.function.name,
)
)

return ChatResponse(
message=ChatMessage(
role=MessageRole.ASSISTANT,
blocks=blocks,
additional_kwargs=additional_kwargs,
),
raw=dict(response),
)
Expand All @@ -367,18 +431,39 @@ def stream_chat(

def gen() -> ChatResponseGen:
content = ""
blocks: List[TextBlock | ThinkingBlock] = []
blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []
for chunk in response:
delta = chunk.data.choices[0].delta
role = delta.role or MessageRole.ASSISTANT

# NOTE: Unlike openAI, we are directly injecting the tool calls
additional_kwargs = {}
if delta.tool_calls:
additional_kwargs["tool_calls"] = delta.tool_calls
for tool_call in delta.tool_calls:
if isinstance(tool_call, ToolCall):
blocks.append(
ToolCallBlock(
tool_call_id=tool_call.id,
tool_name=tool_call.function.name,
tool_kwargs=tool_call.function.arguments,
)
)

content_delta = delta.content or ""
content += content_delta
content_delta_str = ""
if isinstance(content_delta, str):
content_delta_str = content_delta
else:
for chunk in content_delta:
if isinstance(chunk, TextChunk):
content_delta_str += chunk.text + "\n"
elif isinstance(chunk, ThinkChunk):
for c in chunk.thinking:
if isinstance(c, TextChunk):
content_delta_str += c.text + "\n"
else:
continue

content += content_delta_str

# decide whether to include thinking in deltas/responses
if self.model in MISTRAL_AI_REASONING_MODELS:
Expand All @@ -392,15 +477,14 @@ def gen() -> ChatResponseGen:
# If thinking hasn't ended, don't include it in the delta
if thinking_txt is None and not self.show_thinking:
content_delta = ""
blocks.append(TextBlock(text=content))
blocks.append(TextBlock(text=content))

yield ChatResponse(
message=ChatMessage(
role=role,
blocks=blocks,
additional_kwargs=additional_kwargs,
),
delta=content_delta,
delta=content_delta_str,
raw=chunk,
)

Expand All @@ -425,19 +509,30 @@ async def achat(
messages=messages, **all_kwargs
)

blocks: List[TextBlock | ThinkingBlock] = []
blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []
additional_kwargs = {}
if self.model in MISTRAL_AI_REASONING_MODELS:
thinking_txt, response_txt = self._separate_thinking(
response.choices[0].message.content
response.choices[0].message.content or []
)
if thinking_txt:
blocks.append(ThinkingBlock(content=thinking_txt))

response_txt_think_show = ""
if response.choices[0].message.content:
if isinstance(response.choices[0].message.content, str):
response_txt_think_show = response.choices[0].message.content
else:
for chunk in response.choices[0].message.content:
if isinstance(chunk, TextBlock):
response_txt_think_show += chunk.text + "\n"
if isinstance(chunk, ThinkChunk):
for c in chunk.thinking:
if isinstance(c, TextChunk):
response_txt_think_show += c.text + "\n"

response_txt = (
response_txt
if not self.show_thinking
else response.choices[0].message.content
response_txt if not self.show_thinking else response_txt_think_show
)
else:
response_txt = response.choices[0].message.content
Expand All @@ -446,7 +541,25 @@ async def achat(

tool_calls = response.choices[0].message.tool_calls
if tool_calls is not None:
additional_kwargs["tool_calls"] = tool_calls
for tool_call in tool_calls:
if isinstance(tool_call, ToolCall):
blocks.append(
ToolCallBlock(
tool_call_id=tool_call.id,
tool_kwargs=tool_call.function.arguments,
tool_name=tool_call.function.name,
)
)
else:
if isinstance(tool_call[1], (str, dict)):
blocks.append(
ToolCallBlock(
tool_kwargs=tool_call[1], tool_name=tool_call[0]
)
)
additional_kwargs["tool_calls"] = (
tool_calls # keep this to avoid tool calls loss if tool call does not fall within the validation scenarios above
)

return ChatResponse(
message=ChatMessage(
Expand Down Expand Up @@ -477,17 +590,38 @@ async def astream_chat(

async def gen() -> ChatResponseAsyncGen:
content = ""
blocks: List[ThinkingBlock | TextBlock] = []
blocks: List[ThinkingBlock | TextBlock | ToolCallBlock] = []
async for chunk in response:
delta = chunk.data.choices[0].delta
role = delta.role or MessageRole.ASSISTANT
# NOTE: Unlike openAI, we are directly injecting the tool calls
additional_kwargs = {}
if delta.tool_calls:
additional_kwargs["tool_calls"] = delta.tool_calls
for tool_call in delta.tool_calls:
if isinstance(tool_call, ToolCall):
blocks.append(
ToolCallBlock(
tool_call_id=tool_call.id,
tool_name=tool_call.function.name,
tool_kwargs=tool_call.function.arguments,
)
)

content_delta = delta.content or ""
content += content_delta
content_delta_str = ""
if isinstance(content_delta, str):
content_delta_str = content_delta
else:
for chunk in content_delta:
if isinstance(chunk, TextChunk):
content_delta_str += chunk.text + "\n"
elif isinstance(chunk, ThinkChunk):
for c in chunk.thinking:
if isinstance(c, TextChunk):
content_delta_str += c.text + "\n"
else:
continue

content += content_delta_str

# decide whether to include thinking in deltas/responses
if self.model in MISTRAL_AI_REASONING_MODELS:
Expand All @@ -501,15 +635,14 @@ async def gen() -> ChatResponseAsyncGen:
if thinking_txt is None and not self.show_thinking:
content_delta = ""

blocks.append(TextBlock(text=content))
blocks.append(TextBlock(text=content))

yield ChatResponse(
message=ChatMessage(
role=role,
blocks=blocks,
additional_kwargs=additional_kwargs,
),
delta=content_delta,
delta=content_delta_str,
raw=chunk,
)

Expand Down Expand Up @@ -570,7 +703,11 @@ def get_tool_calls_from_response(
error_on_no_tool_call: bool = True,
) -> List[ToolSelection]:
"""Predict and call the tool."""
tool_calls = response.message.additional_kwargs.get("tool_calls", [])
tool_calls = [
block
for block in response.message.blocks
if isinstance(block, ToolCallBlock)
]

if len(tool_calls) < 1:
if error_on_no_tool_call:
Expand All @@ -582,15 +719,15 @@ def get_tool_calls_from_response(

tool_selections = []
for tool_call in tool_calls:
if not isinstance(tool_call, ToolCall):
raise ValueError("Invalid tool_call object")

argument_dict = json.loads(tool_call.function.arguments)
if isinstance(tool_call.tool_kwargs, str):
argument_dict = json.loads(tool_call.tool_kwargs)
else:
argument_dict = tool_call.tool_kwargs

tool_selections.append(
ToolSelection(
tool_id=tool_call.id,
tool_name=tool_call.function.name,
tool_id=tool_call.tool_call_id or "",
tool_name=tool_call.tool_name,
tool_kwargs=argument_dict,
)
)
Expand Down
Loading
Loading