     TextBlock,
     ImageBlock,
     ThinkingBlock,
+    ToolCallBlock,
 )
 from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.callbacks import CallbackManager
 )

 from mistralai import Mistral
-from mistralai.models import ToolCall
+from mistralai.models import ToolCall, FunctionCall
 from mistralai.models import (
     Messages,
     AssistantMessage,
@@ -102,6 +103,8 @@ def to_mistral_chunks(content_blocks: Sequence[ContentBlock]) -> Sequence[Conten
                     image_url=f"data:{image_mimetype};base64,{base_64_str}"
                 )
             )
+        elif isinstance(content_block, ToolCallBlock):
+            pass
         else:
             raise ValueError(f"Unsupported content block type {type(content_block)}")
     return content_chunks
@@ -112,7 +115,33 @@ def to_mistral_chatmessage(
 ) -> List[Messages]:
     new_messages = []
     for m in messages:
-        tool_calls = m.additional_kwargs.get("tool_calls")
+        unique_tool_calls = []
+        tool_calls_li = [
+            block for block in m.blocks if isinstance(block, ToolCallBlock)
+        ]
+        tool_calls = []
+        for tool_call_li in tool_calls_li:
+            tool_calls.append(
+                ToolCall(
+                    id=tool_call_li.tool_call_id,
+                    function=FunctionCall(
+                        name=tool_call_li.tool_name,
+                        arguments=tool_call_li.tool_kwargs,
+                    ),
+                )
+            )
+            unique_tool_calls.append(
+                (tool_call_li.tool_call_id, tool_call_li.tool_name)
+            )
+        # try with legacy tool calls for compatibility with older chat histories
+        if len(m.additional_kwargs.get("tool_calls", [])) > 0:
+            tcs = m.additional_kwargs.get("tool_calls", [])
+            for tc in tcs:
+                if (
+                    isinstance(tc, ToolCall)
+                    and (tc.id, tc.function.name) not in unique_tool_calls
+                ):
+                    tool_calls.append(tc)
         chunks = to_mistral_chunks(m.blocks)
         if m.role == MessageRole.USER:
             new_messages.append(UserMessage(content=chunks))
@@ -135,9 +164,15 @@ def to_mistral_chatmessage(


 def force_single_tool_call(response: ChatResponse) -> None:
-    tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+    tool_calls = [
+        block for block in response.message.blocks if isinstance(block, ToolCallBlock)
+    ]
     if len(tool_calls) > 1:
-        response.message.additional_kwargs["tool_calls"] = [tool_calls[0]]
+        response.message.blocks = [
+            block
+            for block in response.message.blocks
+            if not isinstance(block, ToolCallBlock)
+        ] + [tool_calls[0]]


 class MistralAI(FunctionCallingLLM):
@@ -296,17 +331,29 @@ def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
             **kwargs,
         }

-    def _separate_thinking(self, response: str) -> Tuple[str, str]:
+    def _separate_thinking(
+        self, response: Union[str, List[ContentChunk]]
+    ) -> Tuple[str, str]:
         """Separate the thinking from the response."""
-        match = THINKING_REGEX.search(response)
+        content = ""
+        if isinstance(response, str):
+            content = response
+        else:
+            for chunk in response:
+                if isinstance(chunk, ThinkChunk):
+                    for c in chunk.thinking:
+                        if isinstance(c, TextChunk):
+                            content += c.text + "\n"
+
+        match = THINKING_REGEX.search(content)
         if match:
-            return match.group(1), response.replace(match.group(0), "")
+            return match.group(1), content.replace(match.group(0), "")

-        match = THINKING_START_REGEX.search(response)
+        match = THINKING_START_REGEX.search(content)
         if match:
             return match.group(0), ""

-        return "", response
+        return "", content

     @llm_chat_callback()
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
@@ -315,34 +362,51 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         messages = to_mistral_chatmessage(messages)
         all_kwargs = self._get_all_kwargs(**kwargs)
         response = self._client.chat.complete(messages=messages, **all_kwargs)
-        blocks: List[TextBlock | ThinkingBlock] = []
+        blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []

-        additional_kwargs = {}
         if self.model in MISTRAL_AI_REASONING_MODELS:
             thinking_txt, response_txt = self._separate_thinking(
-                response.choices[0].message.content
+                response.choices[0].message.content or []
             )
             if thinking_txt:
                 blocks.append(ThinkingBlock(content=thinking_txt))

+            response_txt_think_show = ""
+            if response.choices[0].message.content:
+                if isinstance(response.choices[0].message.content, str):
+                    response_txt_think_show = response.choices[0].message.content
+                else:
+                    for chunk in response.choices[0].message.content:
+                        if isinstance(chunk, TextBlock):
+                            response_txt_think_show += chunk.text + "\n"
+                        if isinstance(chunk, ThinkChunk):
+                            for c in chunk.thinking:
+                                if isinstance(c, TextChunk):
+                                    response_txt_think_show += c.text + "\n"
+
             response_txt = (
-                response_txt
-                if not self.show_thinking
-                else response.choices[0].message.content
+                response_txt if not self.show_thinking else response_txt_think_show
             )
         else:
             response_txt = response.choices[0].message.content

         blocks.append(TextBlock(text=response_txt))
         tool_calls = response.choices[0].message.tool_calls
         if tool_calls is not None:
-            additional_kwargs["tool_calls"] = tool_calls
+            for tool_call in tool_calls:
+                if isinstance(tool_call, ToolCall):
+                    blocks.append(
+                        ToolCallBlock(
+                            tool_call_id=tool_call.id,
+                            tool_kwargs=tool_call.function.arguments,
+                            tool_name=tool_call.function.name,
+                        )
+                    )

         return ChatResponse(
             message=ChatMessage(
                 role=MessageRole.ASSISTANT,
                 blocks=blocks,
-                additional_kwargs=additional_kwargs,
             ),
             raw=dict(response),
         )
@@ -367,18 +431,39 @@ def stream_chat(

         def gen() -> ChatResponseGen:
             content = ""
-            blocks: List[TextBlock | ThinkingBlock] = []
+            blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []
             for chunk in response:
                 delta = chunk.data.choices[0].delta
                 role = delta.role or MessageRole.ASSISTANT

                 # NOTE: Unlike openAI, we are directly injecting the tool calls
-                additional_kwargs = {}
                 if delta.tool_calls:
-                    additional_kwargs["tool_calls"] = delta.tool_calls
+                    for tool_call in delta.tool_calls:
+                        if isinstance(tool_call, ToolCall):
+                            blocks.append(
+                                ToolCallBlock(
+                                    tool_call_id=tool_call.id,
+                                    tool_name=tool_call.function.name,
+                                    tool_kwargs=tool_call.function.arguments,
+                                )
+                            )

                 content_delta = delta.content or ""
-                content += content_delta
+                content_delta_str = ""
+                if isinstance(content_delta, str):
+                    content_delta_str = content_delta
+                else:
+                    for chunk in content_delta:
+                        if isinstance(chunk, TextChunk):
+                            content_delta_str += chunk.text + "\n"
+                        elif isinstance(chunk, ThinkChunk):
+                            for c in chunk.thinking:
+                                if isinstance(c, TextChunk):
+                                    content_delta_str += c.text + "\n"
+                        else:
+                            continue
+
+                content += content_delta_str

                 # decide whether to include thinking in deltas/responses
                 if self.model in MISTRAL_AI_REASONING_MODELS:
@@ -392,15 +477,14 @@ def gen() -> ChatResponseGen:
                     # If thinking hasn't ended, don't include it in the delta
                     if thinking_txt is None and not self.show_thinking:
                         content_delta = ""
-                    blocks.append(TextBlock(text=content))
+                blocks.append(TextBlock(text=content))

                 yield ChatResponse(
                     message=ChatMessage(
                         role=role,
                         blocks=blocks,
-                        additional_kwargs=additional_kwargs,
                     ),
-                    delta=content_delta,
+                    delta=content_delta_str,
                     raw=chunk,
                 )

@@ -425,19 +509,30 @@ async def achat(
             messages=messages, **all_kwargs
         )

-        blocks: List[TextBlock | ThinkingBlock] = []
+        blocks: List[TextBlock | ThinkingBlock | ToolCallBlock] = []
         additional_kwargs = {}
         if self.model in MISTRAL_AI_REASONING_MODELS:
             thinking_txt, response_txt = self._separate_thinking(
-                response.choices[0].message.content
+                response.choices[0].message.content or []
             )
             if thinking_txt:
                 blocks.append(ThinkingBlock(content=thinking_txt))

+            response_txt_think_show = ""
+            if response.choices[0].message.content:
+                if isinstance(response.choices[0].message.content, str):
+                    response_txt_think_show = response.choices[0].message.content
+                else:
+                    for chunk in response.choices[0].message.content:
+                        if isinstance(chunk, TextBlock):
+                            response_txt_think_show += chunk.text + "\n"
+                        if isinstance(chunk, ThinkChunk):
+                            for c in chunk.thinking:
+                                if isinstance(c, TextChunk):
+                                    response_txt_think_show += c.text + "\n"
+
             response_txt = (
-                response_txt
-                if not self.show_thinking
-                else response.choices[0].message.content
+                response_txt if not self.show_thinking else response_txt_think_show
             )
         else:
             response_txt = response.choices[0].message.content
@@ -446,7 +541,25 @@ async def achat(

         tool_calls = response.choices[0].message.tool_calls
         if tool_calls is not None:
-            additional_kwargs["tool_calls"] = tool_calls
+            for tool_call in tool_calls:
+                if isinstance(tool_call, ToolCall):
+                    blocks.append(
+                        ToolCallBlock(
+                            tool_call_id=tool_call.id,
+                            tool_kwargs=tool_call.function.arguments,
+                            tool_name=tool_call.function.name,
+                        )
+                    )
+                else:
+                    if isinstance(tool_call[1], (str, dict)):
+                        blocks.append(
+                            ToolCallBlock(
+                                tool_kwargs=tool_call[1], tool_name=tool_call[0]
+                            )
+                        )
+            additional_kwargs["tool_calls"] = (
+                tool_calls  # keep this to avoid tool calls loss if tool call does not fall within the validation scenarios above
+            )

         return ChatResponse(
             message=ChatMessage(
@@ -477,17 +590,38 @@ async def astream_chat(

         async def gen() -> ChatResponseAsyncGen:
             content = ""
-            blocks: List[ThinkingBlock | TextBlock] = []
+            blocks: List[ThinkingBlock | TextBlock | ToolCallBlock] = []
             async for chunk in response:
                 delta = chunk.data.choices[0].delta
                 role = delta.role or MessageRole.ASSISTANT
                 # NOTE: Unlike openAI, we are directly injecting the tool calls
-                additional_kwargs = {}
                 if delta.tool_calls:
-                    additional_kwargs["tool_calls"] = delta.tool_calls
+                    for tool_call in delta.tool_calls:
+                        if isinstance(tool_call, ToolCall):
+                            blocks.append(
+                                ToolCallBlock(
+                                    tool_call_id=tool_call.id,
+                                    tool_name=tool_call.function.name,
+                                    tool_kwargs=tool_call.function.arguments,
+                                )
+                            )

                 content_delta = delta.content or ""
-                content += content_delta
+                content_delta_str = ""
+                if isinstance(content_delta, str):
+                    content_delta_str = content_delta
+                else:
+                    for chunk in content_delta:
+                        if isinstance(chunk, TextChunk):
+                            content_delta_str += chunk.text + "\n"
+                        elif isinstance(chunk, ThinkChunk):
+                            for c in chunk.thinking:
+                                if isinstance(c, TextChunk):
+                                    content_delta_str += c.text + "\n"
+                        else:
+                            continue
+
+                content += content_delta_str

                 # decide whether to include thinking in deltas/responses
                 if self.model in MISTRAL_AI_REASONING_MODELS:
@@ -501,15 +635,14 @@ async def gen() -> ChatResponseAsyncGen:
                     if thinking_txt is None and not self.show_thinking:
                         content_delta = ""

-                    blocks.append(TextBlock(text=content))
+                blocks.append(TextBlock(text=content))

                 yield ChatResponse(
                     message=ChatMessage(
                         role=role,
                         blocks=blocks,
-                        additional_kwargs=additional_kwargs,
                     ),
-                    delta=content_delta,
+                    delta=content_delta_str,
                     raw=chunk,
                 )

@@ -570,7 +703,11 @@ def get_tool_calls_from_response(
         error_on_no_tool_call: bool = True,
     ) -> List[ToolSelection]:
         """Predict and call the tool."""
-        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+        tool_calls = [
+            block
+            for block in response.message.blocks
+            if isinstance(block, ToolCallBlock)
+        ]

         if len(tool_calls) < 1:
             if error_on_no_tool_call:
@@ -582,15 +719,15 @@

         tool_selections = []
         for tool_call in tool_calls:
-            if not isinstance(tool_call, ToolCall):
-                raise ValueError("Invalid tool_call object")
-
-            argument_dict = json.loads(tool_call.function.arguments)
+            if isinstance(tool_call.tool_kwargs, str):
+                argument_dict = json.loads(tool_call.tool_kwargs)
+            else:
+                argument_dict = tool_call.tool_kwargs

             tool_selections.append(
                 ToolSelection(
-                    tool_id=tool_call.id,
-                    tool_name=tool_call.function.name,
+                    tool_id=tool_call.tool_call_id or "",
+                    tool_name=tool_call.tool_name,
                     tool_kwargs=argument_dict,
                 )
             )