Skip to content

Commit d2c766b

Browse files
authored
feat: integrate anthropic with tool call block (#20100)
1 parent b22fc1f commit d2c766b

File tree

6 files changed

+217
-72
lines changed

6 files changed

+217
-72
lines changed

llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py

Lines changed: 67 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
Set,
1414
Tuple,
1515
Union,
16+
cast,
1617
)
17-
18+
from llama_index.core.llms.utils import parse_partial_json
1819
from llama_index.core.base.llms.types import (
1920
ChatMessage,
2021
ChatResponse,
@@ -23,6 +24,7 @@
2324
LLMMetadata,
2425
MessageRole,
2526
ContentBlock,
27+
ToolCallBlock,
2628
)
2729
from llama_index.core.base.llms.types import TextBlock as LITextBlock
2830
from llama_index.core.base.llms.types import CitationBlock as LICitationBlock
@@ -35,7 +37,6 @@
3537
llm_completion_callback,
3638
)
3739
from llama_index.core.llms.function_calling import FunctionCallingLLM, ToolSelection
38-
from llama_index.core.llms.utils import parse_partial_json
3940
from llama_index.core.types import BaseOutputParser, PydanticProgramMode
4041
from llama_index.core.utils import Tokenizer
4142
from llama_index.llms.anthropic.utils import (
@@ -44,6 +45,7 @@
4445
is_anthropic_prompt_caching_supported_model,
4546
is_function_calling_model,
4647
messages_to_anthropic_messages,
48+
update_tool_calls,
4749
)
4850

4951
import anthropic
@@ -351,8 +353,7 @@ def _completion_response_from_chat_response(
351353

352354
def _get_blocks_and_tool_calls_and_thinking(
353355
self, response: Any
354-
) -> Tuple[List[ContentBlock], List[Dict[str, Any]], List[Dict[str, Any]]]:
355-
tool_calls = []
356+
) -> Tuple[List[ContentBlock], List[Dict[str, Any]]]:
356357
blocks: List[ContentBlock] = []
357358
citations: List[TextCitation] = []
358359
tracked_citations: Set[str] = set()
@@ -392,9 +393,15 @@ def _get_blocks_and_tool_calls_and_thinking(
392393
)
393394
)
394395
elif isinstance(content_block, ToolUseBlock):
395-
tool_calls.append(content_block.model_dump())
396+
blocks.append(
397+
ToolCallBlock(
398+
tool_call_id=content_block.id,
399+
tool_kwargs=cast(Dict[str, Any] | str, content_block.input),
400+
tool_name=content_block.name,
401+
)
402+
)
396403

397-
return blocks, tool_calls, [x.model_dump() for x in citations]
404+
return blocks, [x.model_dump() for x in citations]
398405

399406
@llm_chat_callback()
400407
def chat(
@@ -412,17 +419,12 @@ def chat(
412419
**all_kwargs,
413420
)
414421

415-
blocks, tool_calls, citations = self._get_blocks_and_tool_calls_and_thinking(
416-
response
417-
)
422+
blocks, citations = self._get_blocks_and_tool_calls_and_thinking(response)
418423

419424
return AnthropicChatResponse(
420425
message=ChatMessage(
421426
role=MessageRole.ASSISTANT,
422427
blocks=blocks,
423-
additional_kwargs={
424-
"tool_calls": tool_calls,
425-
},
426428
),
427429
citations=citations,
428430
raw=dict(response),
@@ -536,13 +538,18 @@ def gen() -> Generator[AnthropicChatResponse, None, None]:
536538
else:
537539
tool_calls_to_send = cur_tool_calls
538540

541+
for tool_call in tool_calls_to_send:
542+
tc = ToolCallBlock(
543+
tool_call_id=tool_call.id,
544+
tool_name=tool_call.name,
545+
tool_kwargs=cast(Dict[str, Any] | str, tool_call.input),
546+
)
547+
update_tool_calls(content, tc)
548+
539549
yield AnthropicChatResponse(
540550
message=ChatMessage(
541551
role=role,
542552
blocks=content,
543-
additional_kwargs={
544-
"tool_calls": [t.dict() for t in tool_calls_to_send]
545-
},
546553
),
547554
citations=cur_citations,
548555
delta=content_delta,
@@ -560,13 +567,23 @@ def gen() -> Generator[AnthropicChatResponse, None, None]:
560567
content.append(cur_block)
561568
cur_block = None
562569

570+
if cur_tool_call is not None:
571+
tool_calls_to_send = [*cur_tool_calls, cur_tool_call]
572+
else:
573+
tool_calls_to_send = cur_tool_calls
574+
575+
for tool_call in tool_calls_to_send:
576+
tc = ToolCallBlock(
577+
tool_call_id=tool_call.id,
578+
tool_name=tool_call.name,
579+
tool_kwargs=cast(Dict[str, Any] | str, tool_call.input),
580+
)
581+
update_tool_calls(content, tc)
582+
563583
yield AnthropicChatResponse(
564584
message=ChatMessage(
565585
role=role,
566586
blocks=content,
567-
additional_kwargs={
568-
"tool_calls": [t.dict() for t in tool_calls_to_send]
569-
},
570587
),
571588
citations=cur_citations,
572589
delta="",
@@ -604,17 +621,12 @@ async def achat(
604621
**all_kwargs,
605622
)
606623

607-
blocks, tool_calls, citations = self._get_blocks_and_tool_calls_and_thinking(
608-
response
609-
)
624+
blocks, citations = self._get_blocks_and_tool_calls_and_thinking(response)
610625

611626
return AnthropicChatResponse(
612627
message=ChatMessage(
613628
role=MessageRole.ASSISTANT,
614629
blocks=blocks,
615-
additional_kwargs={
616-
"tool_calls": tool_calls,
617-
},
618630
),
619631
citations=citations,
620632
raw=dict(response),
@@ -728,13 +740,18 @@ async def gen() -> ChatResponseAsyncGen:
728740
else:
729741
tool_calls_to_send = cur_tool_calls
730742

743+
for tool_call in tool_calls_to_send:
744+
tc = ToolCallBlock(
745+
tool_call_id=tool_call.id,
746+
tool_name=tool_call.name,
747+
tool_kwargs=cast(Dict[str, Any] | str, tool_call.input),
748+
)
749+
update_tool_calls(content, tc)
750+
731751
yield AnthropicChatResponse(
732752
message=ChatMessage(
733753
role=role,
734754
blocks=content,
735-
additional_kwargs={
736-
"tool_calls": [t.dict() for t in tool_calls_to_send]
737-
},
738755
),
739756
citations=cur_citations,
740757
delta=content_delta,
@@ -752,13 +769,23 @@ async def gen() -> ChatResponseAsyncGen:
752769
content.append(cur_block)
753770
cur_block = None
754771

772+
if cur_tool_call is not None:
773+
tool_calls_to_send = [*cur_tool_calls, cur_tool_call]
774+
else:
775+
tool_calls_to_send = cur_tool_calls
776+
777+
for tool_call in tool_calls_to_send:
778+
tc = ToolCallBlock(
779+
tool_call_id=tool_call.id,
780+
tool_name=tool_call.name,
781+
tool_kwargs=cast(Dict[str, Any] | str, tool_call.input),
782+
)
783+
update_tool_calls(content, tc)
784+
755785
yield AnthropicChatResponse(
756786
message=ChatMessage(
757787
role=role,
758788
blocks=content,
759-
additional_kwargs={
760-
"tool_calls": [t.dict() for t in tool_calls_to_send]
761-
},
762789
),
763790
citations=cur_citations,
764791
delta="",
@@ -867,7 +894,11 @@ def get_tool_calls_from_response(
867894
**kwargs: Any,
868895
) -> List[ToolSelection]:
869896
"""Predict and call the tool."""
870-
tool_calls = response.message.additional_kwargs.get("tool_calls", [])
897+
tool_calls = [
898+
block
899+
for block in response.message.blocks
900+
if isinstance(block, ToolCallBlock)
901+
]
871902

872903
if len(tool_calls) < 1:
873904
if error_on_no_tool_call:
@@ -879,24 +910,16 @@ def get_tool_calls_from_response(
879910

880911
tool_selections = []
881912
for tool_call in tool_calls:
882-
if (
883-
"input" not in tool_call
884-
or "id" not in tool_call
885-
or "name" not in tool_call
886-
):
887-
raise ValueError("Invalid tool call.")
888-
if tool_call["type"] != "tool_use":
889-
raise ValueError("Invalid tool type. Unsupported by Anthropic")
890913
argument_dict = (
891-
json.loads(tool_call["input"])
892-
if isinstance(tool_call["input"], str)
893-
else tool_call["input"]
914+
json.loads(tool_call.tool_kwargs)
915+
if isinstance(tool_call.tool_kwargs, str)
916+
else tool_call.tool_kwargs
894917
)
895918

896919
tool_selections.append(
897920
ToolSelection(
898-
tool_id=tool_call["id"],
899-
tool_name=tool_call["name"],
921+
tool_id=tool_call.tool_call_id or "",
922+
tool_name=tool_call.tool_name,
900923
tool_kwargs=argument_dict,
901924
)
902925
)

llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
CitationBlock,
1717
ThinkingBlock,
1818
ContentBlock,
19+
ToolCallBlock,
1920
)
2021

2122
from anthropic.types import (
@@ -24,6 +25,7 @@
2425
DocumentBlockParam,
2526
ThinkingBlockParam,
2627
ImageBlockParam,
28+
ToolUseBlockParam,
2729
CacheControlEphemeralParam,
2830
Base64PDFSourceParam,
2931
)
@@ -207,6 +209,7 @@ def blocks_to_anthropic_blocks(
207209
) -> List[AnthropicContentBlock]:
208210
anthropic_blocks: List[AnthropicContentBlock] = []
209211
global_cache_control: Optional[CacheControlEphemeralParam] = None
212+
unique_tool_calls = []
210213

211214
if kwargs.get("cache_control"):
212215
global_cache_control = CacheControlEphemeralParam(**kwargs["cache_control"])
@@ -269,6 +272,19 @@ def blocks_to_anthropic_blocks(
269272
if global_cache_control:
270273
anthropic_blocks[-1]["cache_control"] = global_cache_control
271274

275+
elif isinstance(block, ToolCallBlock):
276+
unique_tool_calls.append((block.tool_call_id, block.tool_name))
277+
anthropic_blocks.append(
278+
ToolUseBlockParam(
279+
id=block.tool_call_id or "",
280+
input=block.tool_kwargs,
281+
name=block.tool_name,
282+
type="tool_use",
283+
)
284+
)
285+
if global_cache_control:
286+
anthropic_blocks[-1]["cache_control"] = global_cache_control
287+
272288
elif isinstance(block, CachePoint):
273289
if len(anthropic_blocks) > 0:
274290
anthropic_blocks[-1]["cache_control"] = CacheControlEphemeralParam(
@@ -282,20 +298,25 @@ def blocks_to_anthropic_blocks(
282298
else:
283299
raise ValueError(f"Unsupported block type: {type(block)}")
284300

301+
# keep this code for compatibility with older chat histories
285302
tool_calls = kwargs.get("tool_calls", [])
286303
for tool_call in tool_calls:
287-
assert "id" in tool_call
288-
assert "input" in tool_call
289-
assert "name" in tool_call
290-
291-
anthropic_blocks.append(
292-
ToolUseBlockParam(
293-
id=tool_call["id"],
294-
input=tool_call["input"],
295-
name=tool_call["name"],
296-
type="tool_use",
297-
)
298-
)
304+
try:
305+
assert "id" in tool_call
306+
assert "input" in tool_call
307+
assert "name" in tool_call
308+
309+
if (tool_call["id"], tool_call["name"]) not in unique_tool_calls:
310+
anthropic_blocks.append(
311+
ToolUseBlockParam(
312+
id=tool_call["id"],
313+
input=tool_call["input"],
314+
name=tool_call["name"],
315+
type="tool_use",
316+
)
317+
)
318+
except AssertionError:
319+
continue
299320

300321
return anthropic_blocks
301322

@@ -359,9 +380,15 @@ def messages_to_anthropic_messages(
359380

360381

361382
def force_single_tool_call(response: ChatResponse) -> None:
    """
    Trim ``response`` in place so its message carries at most one tool call.

    The first ``ToolCallBlock`` found in ``response.message.blocks`` is kept at
    its original position; every later ``ToolCallBlock`` is dropped. Blocks of
    any other type are left untouched and keep their relative order.

    Args:
        response: Chat response whose ``message.blocks`` list is inspected and,
            when it holds more than one tool call, replaced with a trimmed list.

    Returns:
        None. The response is mutated in place.

    """
    trimmed = []
    kept_first_tool_call = False
    for block in response.message.blocks:
        if isinstance(block, ToolCallBlock):
            if kept_first_tool_call:
                # Any tool call after the first one is discarded.
                continue
            kept_first_tool_call = True
        trimmed.append(block)

    # Only reassign when something was actually dropped, so messages with zero
    # or one tool call keep their original ``blocks`` list object.
    if len(trimmed) != len(response.message.blocks):
        response.message.blocks = trimmed
365392

366393

367394
# Anthropic models that support prompt caching
@@ -400,6 +427,33 @@ def force_single_tool_call(response: ChatResponse) -> None:
400427
)
401428

402429

430+
def update_tool_calls(blocks: list[ContentBlock], tool_call: ToolCallBlock) -> None:
    """
    Insert or refresh ``tool_call`` inside ``blocks``, mutating the list in place.

    Tool calls are matched on ``tool_call_id``:

    * no ``ToolCallBlock`` in ``blocks`` has the same id -> ``tool_call`` is
      appended to the end of the list;
    * a block with the same id already carries equal ``tool_kwargs`` -> the
      list is left unchanged;
    * otherwise the first block with the same id is replaced by ``tool_call``
      at the same position.

    Args:
        blocks: Content blocks accumulated so far; may hold any block types.
        tool_call: The tool call to add, or to use as the updated version of an
            existing entry.

    Returns:
        None. ``blocks`` is modified in place.

    """
    # Single scan: remember the position of every tool call sharing this id.
    same_id = [
        (index, block)
        for index, block in enumerate(blocks)
        if isinstance(block, ToolCallBlock)
        and block.tool_call_id == tool_call.tool_call_id
    ]

    if not same_id:
        blocks.append(tool_call)
        return

    # An identical entry is already present, so there is nothing to update.
    if any(block.tool_kwargs == tool_call.tool_kwargs for _, block in same_id):
        return

    # Same id but different arguments: overwrite the first occurrence so the
    # block keeps its position in the list.
    first_index, _ = same_id[0]
    blocks[first_index] = tool_call
455+
456+
403457
def is_anthropic_prompt_caching_supported_model(model: str) -> bool:
404458
"""
405459
Check if the given Anthropic model supports prompt caching.

llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ dev = [
2727

2828
[project]
2929
name = "llama-index-llms-anthropic"
30-
version = "0.9.7"
30+
version = "0.10.0"
3131
description = "llama-index llms anthropic integration"
3232
authors = [{name = "Your Name", email = "[email protected]"}]
3333
requires-python = ">=3.9,<4.0"
3434
readme = "README.md"
3535
license = "MIT"
3636
dependencies = [
3737
"anthropic[bedrock, vertex]>=0.69.0",
38-
"llama-index-core>=0.14.3,<0.15",
38+
"llama-index-core>=0.14.5,<0.15",
3939
]
4040

4141
[tool.codespell]

0 commit comments

Comments
 (0)