
Commit f556585

feat: mistralai integration with tool call block (#20103)
1 parent d2c766b commit f556585
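
This commit moves MistralAI tool calls out of message.additional_kwargs["tool_calls"] and into typed ToolCallBlock entries on message.blocks, across chat, stream_chat, achat, and astream_chat. A minimal consumer-side sketch of the new surface (not part of the diff; it assumes the FunctionCallingLLM helpers chat_with_tools and get_tool_calls_from_response from llama-index-core and a valid MISTRAL_API_KEY in the environment):

```python
from llama_index.core.tools import FunctionTool
from llama_index.llms.mistralai import MistralAI


def add(a: int, b: int) -> int:
    """Add two integers and return the sum."""
    return a + b


add_tool = FunctionTool.from_defaults(fn=add)
llm = MistralAI(model="mistral-large-latest")  # reads MISTRAL_API_KEY from env

response = llm.chat_with_tools(tools=[add_tool], user_msg="What is 3 + 4?")

# After this commit, tool calls ride along as ToolCallBlock entries in
# response.message.blocks; get_tool_calls_from_response (updated below)
# reads them from there instead of additional_kwargs.
for selection in llm.get_tool_calls_from_response(
    response, error_on_no_tool_call=False
):
    print(selection.tool_name, selection.tool_kwargs)
```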

File tree

4 files changed: +210 additions, -63 deletions


llama-index-integrations/llms/llama-index-llms-mistralai/llama_index/llms/mistralai/base.py

Lines changed: 181 additions & 44 deletions
@@ -25,6 +25,7 @@
     TextBlock,
     ImageBlock,
     ThinkingBlock,
+    ToolCallBlock,
 )
 from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.callbacks import CallbackManager
@@ -53,7 +54,7 @@
 )

 from mistralai import Mistral
-from mistralai.models import ToolCall
+from mistralai.models import ToolCall, FunctionCall
 from mistralai.models import (
     Messages,
     AssistantMessage,
@@ -102,6 +103,8 @@ def to_mistral_chunks(content_blocks: Sequence[ContentBlock]) -> Sequence[Conten
                     image_url=f"data:{image_mimetype};base64,{base_64_str}"
                 )
             )
+        elif isinstance(content_block, ToolCallBlock):
+            pass
         else:
             raise ValueError(f"Unsupported content block type {type(content_block)}")
     return content_chunks
@@ -112,7 +115,33 @@ def to_mistral_chatmessage(
 ) -> List[Messages]:
     new_messages = []
     for m in messages:
-        tool_calls = m.additional_kwargs.get("tool_calls")
+        unique_tool_calls = []
+        tool_calls_li = [
+            block for block in m.blocks if isinstance(block, ToolCallBlock)
+        ]
+        tool_calls = []
+        for tool_call_li in tool_calls_li:
+            tool_calls.append(
+                ToolCall(
+                    id=tool_call_li.tool_call_id,
+                    function=FunctionCall(
+                        name=tool_call_li.tool_name,
+                        arguments=tool_call_li.tool_kwargs,
+                    ),
+                )
+            )
+            unique_tool_calls.append(
+                (tool_call_li.tool_call_id, tool_call_li.tool_name)
+            )
+        # try with legacy tool calls for compatibility with older chat histories
+        if len(m.additional_kwargs.get("tool_calls", [])) > 0:
+            tcs = m.additional_kwargs.get("tool_calls", [])
+            for tc in tcs:
+                if (
+                    isinstance(tc, ToolCall)
+                    and (tc.id, tc.function.name) not in unique_tool_calls
+                ):
+                    tool_calls.append(tc)
         chunks = to_mistral_chunks(m.blocks)
         if m.role == MessageRole.USER:
             new_messages.append(UserMessage(content=chunks))
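
To make the round-trip concrete, here is a hedged sketch of what the new conversion does with a message that carries a tool-call block (the import path for the block types is an assumption; they ship with llama-index-core):

```python
from llama_index.core.base.llms.types import ChatMessage, MessageRole, ToolCallBlock

# An assistant message that carries a tool call as a typed content block.
msg = ChatMessage(
    role=MessageRole.ASSISTANT,
    blocks=[
        ToolCallBlock(
            tool_call_id="call_1",
            tool_name="add",
            tool_kwargs={"a": 3, "b": 4},
        )
    ],
)

# to_mistral_chatmessage([msg]) now builds a mistralai
# ToolCall(id="call_1", function=FunctionCall(name="add", arguments={...}))
# from the block, then merges any legacy additional_kwargs["tool_calls"]
# entries whose (id, name) pair was not already seen.
```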
@@ -135,9 +164,15 @@ def to_mistral_chatmessage(


 def force_single_tool_call(response: ChatResponse) -> None:
-    tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+    tool_calls = [
+        block for block in response.message.blocks if isinstance(block, ToolCallBlock)
+    ]
     if len(tool_calls) > 1:
-        response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
+        response.message.blocks = [
+            block
+            for block in response.message.blocks
+            if not isinstance(block, ToolCallBlock)
+        ] + [tool_calls[0]]


 class MistralAI(FunctionCallingLLM):
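
A small behavior sketch for the updated helper (illustrative only; import paths are assumed): with several tool-call blocks on a message, only the first survives, and non-tool-call blocks are untouched.

```python
from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    MessageRole,
    TextBlock,
    ToolCallBlock,
)
from llama_index.llms.mistralai.base import force_single_tool_call

resp = ChatResponse(
    message=ChatMessage(
        role=MessageRole.ASSISTANT,
        blocks=[
            TextBlock(text="calling two tools"),
            ToolCallBlock(tool_call_id="a", tool_name="add", tool_kwargs={}),
            ToolCallBlock(tool_call_id="b", tool_name="subtract", tool_kwargs={}),
        ],
    )
)
force_single_tool_call(resp)

# Only the first tool-call block remains; the TextBlock is preserved.
kept = [b for b in resp.message.blocks if isinstance(b, ToolCallBlock)]
assert len(kept) == 1 and kept[0].tool_call_id == "a"
```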
@@ -296,17 +331,29 @@ def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
             **kwargs,
         }

-    def _separate_thinking(self, response: str) -> Tuple[str, str]:
+    def _separate_thinking(
+        self, response: Union[str, List[ContentChunk]]
+    ) -> Tuple[str, str]:
         """Separate the thinking from the response."""
-        match = THINKING_REGEX.search(response)
+        content = ""
+        if isinstance(response, str):
+            content = response
+        else:
+            for chunk in response:
+                if isinstance(chunk, ThinkChunk):
+                    for c in chunk.thinking:
+                        if isinstance(c, TextChunk):
+                            content += c.text + "\n"
+
+        match = THINKING_REGEX.search(content)
         if match:
-            return match.group(1), response.replace(match.group(0), "")
+            return match.group(1), content.replace(match.group(0), "")

-        match = THINKING_START_REGEX.search(response)
+        match = THINKING_START_REGEX.search(content)
         if match:
             return match.group(0), ""

-        return "", response
+        return "", content

     @llm_chat_callback()
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
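
THINKING_REGEX and THINKING_START_REGEX are defined earlier in this file and are unchanged by this commit; the pattern below is only an assumption of their shape, to illustrate the split that _separate_thinking performs once the chunk list has been flattened to a string:

```python
import re

# Assumed shape of THINKING_REGEX, for illustration only.
THINKING_REGEX = re.compile(r"<think>\n(.*?)</think>\n", re.DOTALL)

content = "<think>\nadd 3 and 4\n</think>\nThe answer is 7."
match = THINKING_REGEX.search(content)
thinking, rest = match.group(1), content.replace(match.group(0), "")
print(repr(thinking))  # 'add 3 and 4\n'
print(repr(rest))      # 'The answer is 7.'
```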
@@ -315,34 +362,51 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         messages = to_mistral_chatmessage(messages)
         all_kwargs = self._get_all_kwargs(**kwargs)
         response = self._client.chat.complete(messages=messages, **all_kwargs)
-        blocks: List[TextBlock | ThinkingBlock] = []
+        blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []

-        additional_kwargs = {}
         if self.model in MISTRAL_AI_REASONING_MODELS:
             thinking_txt, response_txt = self._separate_thinking(
-                response.choices[0].message.content
+                response.choices[0].message.content or []
             )
             if thinking_txt:
                 blocks.append(ThinkingBlock(content=thinking_txt))

+            response_txt_think_show = ""
+            if response.choices[0].message.content:
+                if isinstance(response.choices[0].message.content, str):
+                    response_txt_think_show = response.choices[0].message.content
+                else:
+                    for chunk in response.choices[0].message.content:
+                        if isinstance(chunk, TextBlock):
+                            response_txt_think_show += chunk.text + "\n"
+                        if isinstance(chunk, ThinkChunk):
+                            for c in chunk.thinking:
+                                if isinstance(c, TextChunk):
+                                    response_txt_think_show += c.text + "\n"
+
             response_txt = (
-                response_txt
-                if not self.show_thinking
-                else response.choices[0].message.content
+                response_txt if not self.show_thinking else response_txt_think_show
             )
         else:
             response_txt = response.choices[0].message.content

         blocks.append(TextBlock(text=response_txt))
         tool_calls = response.choices[0].message.tool_calls
         if tool_calls is not None:
-            additional_kwargs["tool_calls"] = tool_calls
+            for tool_call in tool_calls:
+                if isinstance(tool_call, ToolCall):
+                    blocks.append(
+                        ToolCallBlock(
+                            tool_call_id=tool_call.id,
+                            tool_kwargs=tool_call.function.arguments,
+                            tool_name=tool_call.function.name,
+                        )
+                    )

         return ChatResponse(
             message=ChatMessage(
                 role=MessageRole.ASSISTANT,
                 blocks=blocks,
-                additional_kwargs=additional_kwargs,
             ),
             raw=dict(response),
         )
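
Downstream code can now dispatch on the block types chat() emits (per the code above: an optional ThinkingBlock, then the TextBlock, then any ToolCallBlocks). An illustrative consumer, with assumed import paths:

```python
from llama_index.core.base.llms.types import (
    ChatResponse,
    TextBlock,
    ThinkingBlock,
    ToolCallBlock,
)


def summarize_blocks(response: ChatResponse) -> None:
    """Print a one-line summary for each block on a chat response."""
    for block in response.message.blocks:
        if isinstance(block, ThinkingBlock):
            print("thinking:", (block.content or "")[:60])
        elif isinstance(block, TextBlock):
            print("text:", block.text[:60])
        elif isinstance(block, ToolCallBlock):
            print("tool call:", block.tool_name, block.tool_kwargs)
```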
@@ -367,18 +431,39 @@ def stream_chat(

         def gen() -> ChatResponseGen:
             content = ""
-            blocks: List[TextBlock | ThinkingBlock] = []
+            blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []
             for chunk in response:
                 delta = chunk.data.choices[0].delta
                 role = delta.role or MessageRole.ASSISTANT

                 # NOTE: Unlike openAI, we are directly injecting the tool calls
-                additional_kwargs = {}
                 if delta.tool_calls:
-                    additional_kwargs["tool_calls"] = delta.tool_calls
+                    for tool_call in delta.tool_calls:
+                        if isinstance(tool_call, ToolCall):
+                            blocks.append(
+                                ToolCallBlock(
+                                    tool_call_id=tool_call.id,
+                                    tool_name=tool_call.function.name,
+                                    tool_kwargs=tool_call.function.arguments,
+                                )
+                            )

                 content_delta = delta.content or ""
-                content += content_delta
+                content_delta_str = ""
+                if isinstance(content_delta, str):
+                    content_delta_str = content_delta
+                else:
+                    for chunk in content_delta:
+                        if isinstance(chunk, TextChunk):
+                            content_delta_str += chunk.text + "\n"
+                        elif isinstance(chunk, ThinkChunk):
+                            for c in chunk.thinking:
+                                if isinstance(c, TextChunk):
+                                    content_delta_str += c.text + "\n"
+                        else:
+                            continue
+
+                content += content_delta_str

                 # decide whether to include thinking in deltas/responses
                 if self.model in MISTRAL_AI_REASONING_MODELS:
@@ -392,15 +477,14 @@ def gen() -> ChatResponseGen:
                     # If thinking hasn't ended, don't include it in the delta
                     if thinking_txt is None and not self.show_thinking:
                         content_delta = ""
-                    blocks.append(TextBlock(text=content))
+                blocks.append(TextBlock(text=content))

                 yield ChatResponse(
                     message=ChatMessage(
                         role=role,
                         blocks=blocks,
-                        additional_kwargs=additional_kwargs,
                     ),
-                    delta=content_delta,
+                    delta=content_delta_str,
                     raw=chunk,
                 )

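For streaming, the blocks list accumulates across chunks, so the last ChatResponse yielded carries every ToolCallBlock seen in the deltas. A hedged consumer sketch (reusing an llm and message list set up as in the first example):

```python
from llama_index.core.base.llms.types import ToolCallBlock


def stream_and_collect_tool_calls(llm, messages):
    """Drain a stream_chat generator, printing text deltas, and return the
    tool-call blocks accumulated on the final response."""
    last_response = None
    for last_response in llm.stream_chat(messages):
        print(last_response.delta or "", end="", flush=True)
    if last_response is None:
        return []
    return [
        b for b in last_response.message.blocks if isinstance(b, ToolCallBlock)
    ]
```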
@@ -425,19 +509,30 @@ async def achat(
             messages=messages, **all_kwargs
         )

-        blocks: List[TextBlock | ThinkingBlock] = []
+        blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []
         additional_kwargs = {}
         if self.model in MISTRAL_AI_REASONING_MODELS:
             thinking_txt, response_txt = self._separate_thinking(
-                response.choices[0].message.content
+                response.choices[0].message.content or []
             )
             if thinking_txt:
                 blocks.append(ThinkingBlock(content=thinking_txt))

+            response_txt_think_show = ""
+            if response.choices[0].message.content:
+                if isinstance(response.choices[0].message.content, str):
+                    response_txt_think_show = response.choices[0].message.content
+                else:
+                    for chunk in response.choices[0].message.content:
+                        if isinstance(chunk, TextBlock):
+                            response_txt_think_show += chunk.text + "\n"
+                        if isinstance(chunk, ThinkChunk):
+                            for c in chunk.thinking:
+                                if isinstance(c, TextChunk):
+                                    response_txt_think_show += c.text + "\n"
+
             response_txt = (
-                response_txt
-                if not self.show_thinking
-                else response.choices[0].message.content
+                response_txt if not self.show_thinking else response_txt_think_show
             )
         else:
             response_txt = response.choices[0].message.content
@@ -446,7 +541,25 @@ async def achat(

         tool_calls = response.choices[0].message.tool_calls
         if tool_calls is not None:
-            additional_kwargs["tool_calls"] = tool_calls
+            for tool_call in tool_calls:
+                if isinstance(tool_call, ToolCall):
+                    blocks.append(
+                        ToolCallBlock(
+                            tool_call_id=tool_call.id,
+                            tool_kwargs=tool_call.function.arguments,
+                            tool_name=tool_call.function.name,
+                        )
+                    )
+                else:
+                    if isinstance(tool_call[1], (str, dict)):
+                        blocks.append(
+                            ToolCallBlock(
+                                tool_kwargs=tool_call[1], tool_name=tool_call[0]
+                            )
+                        )
+            additional_kwargs["tool_calls"] = (
+                tool_calls  # keep this to avoid tool calls loss if tool call does not fall within the validation scenarios above
+            )

         return ChatResponse(
             message=ChatMessage(
477590

478591
async def gen() -> ChatResponseAsyncGen:
479592
content = ""
480-
blocks: List[ThinkingBlock | TextBlock] = []
593+
blocks: List[ThinkingBlock | TextBlock | ToolCallBlock] = []
481594
async for chunk in response:
482595
delta = chunk.data.choices[0].delta
483596
role = delta.role or MessageRole.ASSISTANT
484597
# NOTE: Unlike openAI, we are directly injecting the tool calls
485-
additional_kwargs = {}
486598
if delta.tool_calls:
487-
additional_kwargs["tool_calls"] = delta.tool_calls
599+
for tool_call in delta.tool_calls:
600+
if isinstance(tool_call, ToolCall):
601+
blocks.append(
602+
ToolCallBlock(
603+
tool_call_id=tool_call.id,
604+
tool_name=tool_call.function.name,
605+
tool_kwargs=tool_call.function.arguments,
606+
)
607+
)
488608

489609
content_delta = delta.content or ""
490-
content += content_delta
610+
content_delta_str = ""
611+
if isinstance(content_delta, str):
612+
content_delta_str = content_delta
613+
else:
614+
for chunk in content_delta:
615+
if isinstance(chunk, TextChunk):
616+
content_delta_str += chunk.text + "\n"
617+
elif isinstance(chunk, ThinkChunk):
618+
for c in chunk.thinking:
619+
if isinstance(c, TextChunk):
620+
content_delta_str += c.text + "\n"
621+
else:
622+
continue
623+
624+
content += content_delta_str
491625

492626
# decide whether to include thinking in deltas/responses
493627
if self.model in MISTRAL_AI_REASONING_MODELS:
@@ -501,15 +635,14 @@ async def gen() -> ChatResponseAsyncGen:
                     if thinking_txt is None and not self.show_thinking:
                         content_delta = ""

-                    blocks.append(TextBlock(text=content))
+                blocks.append(TextBlock(text=content))

                 yield ChatResponse(
                     message=ChatMessage(
                         role=role,
                         blocks=blocks,
-                        additional_kwargs=additional_kwargs,
                     ),
-                    delta=content_delta,
+                    delta=content_delta_str,
                     raw=chunk,
                 )

@@ -570,7 +703,11 @@ def get_tool_calls_from_response(
         error_on_no_tool_call: bool = True,
     ) -> List[ToolSelection]:
         """Predict and call the tool."""
-        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+        tool_calls = [
+            block
+            for block in response.message.blocks
+            if isinstance(block, ToolCallBlock)
+        ]

         if len(tool_calls) < 1:
             if error_on_no_tool_call:
@@ -582,15 +719,15 @@

         tool_selections = []
         for tool_call in tool_calls:
-            if not isinstance(tool_call, ToolCall):
-                raise ValueError("Invalid tool_call object")
-
-            argument_dict = json.loads(tool_call.function.arguments)
+            if isinstance(tool_call.tool_kwargs, str):
+                argument_dict = json.loads(tool_call.tool_kwargs)
+            else:
+                argument_dict = tool_call.tool_kwargs

             tool_selections.append(
                 ToolSelection(
-                    tool_id=tool_call.id,
-                    tool_name=tool_call.function.name,
+                    tool_id=tool_call.tool_call_id or "",
+                    tool_name=tool_call.tool_name,
                     tool_kwargs=argument_dict,
                 )
             )