Skip to content

Commit 9efd6a5

Browse files
Implemented a web search tool, provided by Anthropic
This web search tool is categorized as a `server tool`. related doc: https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/web-search-tool Signed-off-by: jonghoonpark <[email protected]>
1 parent 3919204 commit 9efd6a5

File tree

13 files changed

+548
-44
lines changed

13 files changed

+548
-44
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
4343
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
4444
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
45+
import org.springframework.ai.anthropic.api.tool.Tool;
4546
import org.springframework.ai.chat.messages.AssistantMessage;
4647
import org.springframework.ai.chat.messages.MessageType;
4748
import org.springframework.ai.chat.messages.ToolResponseMessage;
@@ -443,20 +444,32 @@ Prompt buildRequestPrompt(Prompt prompt) {
443444
this.defaultOptions.getToolCallbacks()));
444445
requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
445446
this.defaultOptions.getToolContext()));
447+
requestOptions.setServerTools(
448+
mergeServerTools(runtimeOptions.getServerTools(), this.defaultOptions.getServerTools()));
446449
}
447450
else {
448451
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
449452
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
450453
requestOptions.setToolNames(this.defaultOptions.getToolNames());
451454
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
452455
requestOptions.setToolContext(this.defaultOptions.getToolContext());
456+
requestOptions.setServerTools(this.defaultOptions.getServerTools());
453457
}
454458

455459
ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
456460

457461
return new Prompt(prompt.getInstructions(), requestOptions);
458462
}
459463

464+
static List<Tool> mergeServerTools(List<Tool> runtimeServerTools, List<Tool> defaultToolNames) {
465+
Assert.notNull(runtimeServerTools, "runtimeServerTools cannot be null");
466+
Assert.notNull(defaultToolNames, "defaultToolNames cannot be null");
467+
if (CollectionUtils.isEmpty(runtimeServerTools)) {
468+
return new ArrayList<>(defaultToolNames);
469+
}
470+
return new ArrayList<>(runtimeServerTools);
471+
}
472+
460473
private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHeaders,
461474
Map<String, String> defaultHttpHeaders) {
462475
var mergedHttpHeaders = new HashMap<>(defaultHttpHeaders);
@@ -531,15 +544,19 @@ else if (message.getMessageType() == MessageType.TOOL) {
531544
request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build();
532545
}
533546

547+
if (!CollectionUtils.isEmpty(requestOptions.getServerTools())) {
548+
request = ChatCompletionRequest.from(request).tools(requestOptions.getServerTools()).build();
549+
}
550+
534551
return request;
535552
}
536553

537-
private List<AnthropicApi.Tool> getFunctionTools(List<ToolDefinition> toolDefinitions) {
554+
private List<Tool> getFunctionTools(List<ToolDefinition> toolDefinitions) {
538555
return toolDefinitions.stream().map(toolDefinition -> {
539556
var name = toolDefinition.name();
540557
var description = toolDefinition.description();
541558
String inputSchema = toolDefinition.inputSchema();
542-
return new AnthropicApi.Tool(name, description, JsonParser.fromJson(inputSchema, new TypeReference<>() {
559+
return new Tool(name, description, JsonParser.fromJson(inputSchema, new TypeReference<>() {
543560
}));
544561
}).toList();
545562
}

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import org.springframework.ai.anthropic.api.AnthropicApi;
3434
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
35+
import org.springframework.ai.anthropic.api.tool.Tool;
3536
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3637
import org.springframework.ai.tool.ToolCallback;
3738
import org.springframework.lang.Nullable;
@@ -44,6 +45,7 @@
4445
* @author Thomas Vitale
4546
* @author Alexandros Pappas
4647
* @author Ilayaperumal Gopinathan
48+
* @author Jonghoon Park
4749
* @since 1.0.0
4850
*/
4951
@JsonInclude(Include.NON_NULL)
@@ -82,6 +84,8 @@ public class AnthropicChatOptions implements ToolCallingChatOptions {
8284
@JsonIgnore
8385
private Map<String, Object> toolContext = new HashMap<>();
8486

87+
@JsonIgnore
88+
private List<Tool> serverTools = new ArrayList<>();
8589

8690
/**
8791
* Optional HTTP headers to be added to the chat completion request.
@@ -110,6 +114,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
110114
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
111115
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
112116
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
117+
.serverTools(fromOptions.getServerTools() != null ? new ArrayList<>(fromOptions.getServerTools()) : null)
113118
.httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null)
114119
.build();
115120
}
@@ -250,6 +255,17 @@ public void setToolContext(Map<String, Object> toolContext) {
250255
this.toolContext = toolContext;
251256
}
252257

258+
@JsonIgnore
259+
public List<Tool> getServerTools() {
260+
return this.serverTools;
261+
}
262+
263+
public void setServerTools(List<Tool> serverTools) {
264+
Assert.notNull(serverTools, "serverTools cannot be null");
265+
Assert.noNullElements(serverTools, "serverTools cannot contain null elements");
266+
this.serverTools = serverTools;
267+
}
268+
253269
@JsonIgnore
254270
public Map<String, String> getHttpHeaders() {
255271
return this.httpHeaders;
@@ -282,14 +298,15 @@ public boolean equals(Object o) {
282298
&& Objects.equals(this.toolNames, that.toolNames)
283299
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
284300
&& Objects.equals(this.toolContext, that.toolContext)
301+
&& Objects.equals(this.serverTools, that.serverTools)
285302
&& Objects.equals(this.httpHeaders, that.httpHeaders);
286303
}
287304

288305
@Override
289306
public int hashCode() {
290307
return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP,
291308
this.topK, this.thinking, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled,
292-
this.toolContext, this.httpHeaders);
309+
this.toolContext, this.serverTools, this.httpHeaders);
293310
}
294311

295312
public static class Builder {
@@ -384,6 +401,16 @@ public Builder toolContext(Map<String, Object> toolContext) {
384401
return this;
385402
}
386403

404+
public Builder serverTools(List<Tool> serverTools) {
405+
if (this.options.serverTools == null) {
406+
this.options.serverTools = serverTools;
407+
}
408+
else {
409+
this.options.serverTools.addAll(serverTools);
410+
}
411+
return this;
412+
}
413+
387414
public Builder httpHeaders(Map<String, String> httpHeaders) {
388415
this.options.setHttpHeaders(httpHeaders);
389416
return this;

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java

Lines changed: 67 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,15 @@
3030
import com.fasterxml.jackson.annotation.JsonProperty;
3131
import com.fasterxml.jackson.annotation.JsonSubTypes;
3232
import com.fasterxml.jackson.annotation.JsonTypeInfo;
33+
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
34+
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
3335
import reactor.core.publisher.Flux;
3436
import reactor.core.publisher.Mono;
3537

3638
import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder;
39+
import org.springframework.ai.anthropic.api.tool.Tool;
40+
import org.springframework.ai.anthropic.util.ContentFieldDeserializer;
41+
import org.springframework.ai.anthropic.util.ContentFieldSerializer;
3742
import org.springframework.ai.model.ChatModelDescription;
3843
import org.springframework.ai.model.ModelOptionsUtils;
3944
import org.springframework.ai.observation.conventions.AiProvider;
@@ -632,7 +637,12 @@ public ChatCompletionRequestBuilder topK(Integer topK) {
632637
}
633638

634639
public ChatCompletionRequestBuilder tools(List<Tool> tools) {
635-
this.tools = tools;
640+
if (this.tools == null) {
641+
this.tools = tools;
642+
}
643+
else {
644+
this.tools.addAll(tools);
645+
}
636646
return this;
637647
}
638648

@@ -717,7 +727,11 @@ public record ContentBlock(
717727

718728
// tool_result response only
719729
@JsonProperty("tool_use_id") String toolUseId,
720-
@JsonProperty("content") String content,
730+
731+
@JsonSerialize(using = ContentFieldSerializer.class)
732+
@JsonDeserialize(using = ContentFieldDeserializer.class)
733+
@JsonProperty("content")
734+
Object content,
721735

722736
// Thinking only
723737
@JsonProperty("signature") String signature,
@@ -728,6 +742,15 @@ public record ContentBlock(
728742
) {
729743
// @formatter:on
730744

745+
@JsonInclude(Include.NON_NULL)
746+
@JsonIgnoreProperties(ignoreUnknown = true)
747+
public record WebSearchToolContentBlock(@JsonProperty("type") String type, @JsonProperty("title") String title,
748+
@JsonProperty("url") String url, @JsonProperty("encrypted_content") String EncryptedContent,
749+
@JsonProperty("page_age") String pageAge) {
750+
751+
}
752+
// @formatter:on
753+
731754
/**
732755
* Create content block
733756
* @param mediaType The media type of the content.
@@ -813,6 +836,18 @@ public enum Type {
813836
@JsonProperty("tool_result")
814837
TOOL_RESULT("tool_result"),
815838

839+
/**
840+
* Server Tool request
841+
*/
842+
@JsonProperty("server_tool_use")
843+
SERVER_TOOL_USE("server_tool_use"),
844+
845+
/**
846+
* Web search tool result
847+
*/
848+
@JsonProperty("web_search_tool_result")
849+
WEB_SEARCH_TOOL_RESULT("web_search_tool_result"),
850+
816851
/**
817852
* Text message.
818853
*/
@@ -926,22 +961,6 @@ public Source(String url) {
926961
/// CONTENT_BLOCK EVENTS
927962
///////////////////////////////////////
928963

929-
/**
930-
* Tool description.
931-
*
932-
* @param name The name of the tool.
933-
* @param description A description of the tool.
934-
* @param inputSchema The input schema of the tool.
935-
*/
936-
@JsonInclude(Include.NON_NULL)
937-
public record Tool(
938-
// @formatter:off
939-
@JsonProperty("name") String name,
940-
@JsonProperty("description") String description,
941-
@JsonProperty("input_schema") Map<String, Object> inputSchema) {
942-
// @formatter:on
943-
}
944-
945964
// CB START EVENT
946965

947966
/**
@@ -987,16 +1006,25 @@ public record ChatCompletionResponse(
9871006
public record Usage(
9881007
// @formatter:off
9891008
@JsonProperty("input_tokens") Integer inputTokens,
990-
@JsonProperty("output_tokens") Integer outputTokens) {
991-
// @formatter:off
1009+
@JsonProperty("output_tokens") Integer outputTokens,
1010+
@JsonProperty("server_tool_use") ServerToolUse serverToolUse) {
1011+
// @formatter:on
1012+
}
1013+
1014+
@JsonInclude(Include.NON_NULL)
1015+
@JsonIgnoreProperties(ignoreUnknown = true)
1016+
public record ServerToolUse(
1017+
// @formatter:off
1018+
@JsonProperty("web_search_requests") Integer webSearchRequests) {
1019+
// @formatter:on
9921020
}
9931021

994-
/// ECB STOP
1022+
/// ECB STOP
9951023

9961024
/**
9971025
* Special event used to aggregate multiple tool use events into a single event with
9981026
* list of aggregated ContentBlockToolUse.
999-
*/
1027+
*/
10001028
public static class ToolUseAggregationEvent implements StreamEvent {
10011029

10021030
private Integer index;
@@ -1015,17 +1043,17 @@ public EventType type() {
10151043
}
10161044

10171045
/**
1018-
* Get tool content blocks.
1019-
* @return The tool content blocks.
1020-
*/
1046+
* Get tool content blocks.
1047+
* @return The tool content blocks.
1048+
*/
10211049
public List<ContentBlockStartEvent.ContentBlockToolUse> getToolContentBlocks() {
10221050
return this.toolContentBlocks;
10231051
}
10241052

10251053
/**
1026-
* Check if the event is empty.
1027-
* @return True if the event is empty, false otherwise.
1028-
*/
1054+
* Check if the event is empty.
1055+
* @return True if the event is empty, false otherwise.
1056+
*/
10291057
public boolean isEmpty() {
10301058
return (this.index == null || this.id == null || this.name == null
10311059
|| !StringUtils.hasText(this.partialJson));
@@ -1054,7 +1082,8 @@ ToolUseAggregationEvent appendPartialJson(String partialJson) {
10541082
void squashIntoContentBlock() {
10551083
Map<String, Object> map = (StringUtils.hasText(this.partialJson))
10561084
? ModelOptionsUtils.jsonToMap(this.partialJson) : Map.of();
1057-
this.toolContentBlocks.add(new ContentBlockStartEvent.ContentBlockToolUse("tool_use", this.id, this.name, map));
1085+
this.toolContentBlocks
1086+
.add(new ContentBlockStartEvent.ContentBlockToolUse("tool_use", this.id, this.name, map));
10581087
this.index = null;
10591088
this.id = null;
10601089
this.name = null;
@@ -1063,28 +1092,29 @@ void squashIntoContentBlock() {
10631092

10641093
@Override
10651094
public String toString() {
1066-
return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name + ", partialJson="
1067-
+ this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]";
1095+
return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name
1096+
+ ", partialJson=" + this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]";
10681097
}
10691098

10701099
}
10711100

1072-
///////////////////////////////////////
1073-
/// MESSAGE EVENTS
1074-
///////////////////////////////////////
1101+
///////////////////////////////////////
1102+
/// MESSAGE EVENTS
1103+
///////////////////////////////////////
10751104

1076-
// MESSAGE START EVENT
1105+
// MESSAGE START EVENT
10771106

10781107
/**
10791108
* Content block start event.
1109+
*
10801110
* @param type The event type.
10811111
* @param index The index of the content block.
10821112
* @param contentBlock The content block body.
1083-
*/
1113+
*/
10841114
@JsonInclude(Include.NON_NULL)
10851115
@JsonIgnoreProperties(ignoreUnknown = true)
10861116
public record ContentBlockStartEvent(
1087-
// @formatter:off
1117+
// @formatter:off
10881118
@JsonProperty("type") EventType type,
10891119
@JsonProperty("index") Integer index,
10901120
@JsonProperty("content_block") ContentBlockBody contentBlock) implements StreamEvent {

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -174,7 +174,7 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) {
174174

175175
if (messageDeltaEvent.usage() != null) {
176176
var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(),
177-
messageDeltaEvent.usage().outputTokens());
177+
messageDeltaEvent.usage().outputTokens(), contentBlockReference.get().usage.serverToolUse());
178178
contentBlockReference.get().withUsage(totalUsage);
179179
}
180180
}

0 commit comments

Comments
 (0)