Skip to content

Commit 9692ebd

Browse files
committed
Refactor toolcalling support for ZhipuAI
- Update ZhipuAI Chat Model to use ToolCalling Manager and ToolExecutionEligibilityPredicate - Update ZhipuAI ChatOptions to implement ToolCallingChatOptions - Update Autoconfiguration for ZhipuAI model to use ToolCallingAutoconfiguration - Update tests Signed-off-by: Ilayaperumal Gopinathan <[email protected]>
1 parent 49e1dd1 commit 9692ebd

File tree

12 files changed

+327
-237
lines changed

12 files changed

+327
-237
lines changed

auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/pom.xml

+6
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535

3636
<!-- Spring AI auto configurations -->
3737

38+
<dependency>
39+
<groupId>org.springframework.ai</groupId>
40+
<artifactId>spring-ai-autoconfigure-model-tool</artifactId>
41+
<version>${project.parent.version}</version>
42+
</dependency>
43+
3844
<dependency>
3945
<groupId>org.springframework.ai</groupId>
4046
<artifactId>spring-ai-autoconfigure-retry</artifactId>

auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiChatAutoConfiguration.java

+14-13
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,16 @@
2626
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
2727
import org.springframework.ai.model.function.FunctionCallback;
2828
import org.springframework.ai.model.function.FunctionCallbackResolver;
29+
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
30+
import org.springframework.ai.model.tool.ToolCallingManager;
31+
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
32+
import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration;
2933
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
3034
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
3135
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
3236
import org.springframework.beans.factory.ObjectProvider;
3337
import org.springframework.boot.autoconfigure.AutoConfiguration;
38+
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
3439
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
3540
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
3641
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@@ -50,28 +55,32 @@
5055
* @author Geng Rong
5156
* @author Ilayaperumal Gopinathan
5257
*/
53-
@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class })
58+
@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class,
59+
ToolCallingAutoConfiguration.class })
5460
@ConditionalOnClass(ZhiPuAiApi.class)
5561
@ConditionalOnProperty(name = SpringAIModelProperties.CHAT_MODEL, havingValue = SpringAIModels.ZHIPUAI,
5662
matchIfMissing = true)
5763
@EnableConfigurationProperties({ ZhiPuAiConnectionProperties.class, ZhiPuAiChatProperties.class })
64+
@ImportAutoConfiguration(classes = { SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class,
65+
ToolCallingAutoConfiguration.class })
5866
public class ZhiPuAiChatAutoConfiguration {
5967

6068
@Bean
6169
@ConditionalOnMissingBean
6270
public ZhiPuAiChatModel zhiPuAiChatModel(ZhiPuAiConnectionProperties commonProperties,
6371
ZhiPuAiChatProperties chatProperties, ObjectProvider<RestClient.Builder> restClientBuilderProvider,
64-
List<FunctionCallback> toolFunctionCallbacks, FunctionCallbackResolver functionCallbackResolver,
6572
RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler,
6673
ObjectProvider<ObservationRegistry> observationRegistry,
67-
ObjectProvider<ChatModelObservationConvention> observationConvention) {
74+
ObjectProvider<ChatModelObservationConvention> observationConvention, ToolCallingManager toolCallingManager,
75+
ObjectProvider<ToolExecutionEligibilityPredicate> toolExecutionEligibilityPredicate) {
6876

6977
var zhiPuAiApi = zhiPuAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(),
7078
chatProperties.getApiKey(), commonProperties.getApiKey(),
7179
restClientBuilderProvider.getIfAvailable(RestClient::builder), responseErrorHandler);
7280

73-
var chatModel = new ZhiPuAiChatModel(zhiPuAiApi, chatProperties.getOptions(), functionCallbackResolver,
74-
toolFunctionCallbacks, retryTemplate, observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP));
81+
var chatModel = new ZhiPuAiChatModel(zhiPuAiApi, chatProperties.getOptions(), toolCallingManager, retryTemplate,
82+
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
83+
toolExecutionEligibilityPredicate.getIfUnique(DefaultToolExecutionEligibilityPredicate::new));
7584

7685
observationConvention.ifAvailable(chatModel::setObservationConvention);
7786

@@ -90,12 +99,4 @@ private ZhiPuAiApi zhiPuAiApi(String baseUrl, String commonBaseUrl, String apiKe
9099
return new ZhiPuAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder, responseErrorHandler);
91100
}
92101

93-
@Bean
94-
@ConditionalOnMissingBean
95-
public FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) {
96-
DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver();
97-
manager.setApplicationContext(context);
98-
return manager;
99-
}
100-
101102
}

auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/test/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiAutoConfigurationIT.java

+5-3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import org.springframework.ai.chat.messages.UserMessage;
2929
import org.springframework.ai.chat.model.ChatResponse;
30+
import org.springframework.ai.chat.prompt.ChatOptions;
3031
import org.springframework.ai.chat.prompt.Prompt;
3132
import org.springframework.ai.embedding.EmbeddingResponse;
3233
import org.springframework.ai.image.ImagePrompt;
@@ -58,8 +59,8 @@ public class ZhiPuAiAutoConfigurationIT {
5859
void generate() {
5960
this.contextRunner.withConfiguration(AutoConfigurations.of(ZhiPuAiChatAutoConfiguration.class)).run(context -> {
6061
ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class);
61-
String response = chatModel.call("Hello");
62-
assertThat(response).isNotEmpty();
62+
ChatResponse response = chatModel.call(new Prompt("Hello", ChatOptions.builder().build()));
63+
assertThat(response.getResult().getOutput().getText()).isNotEmpty();
6364
logger.info("Response: " + response);
6465
});
6566
}
@@ -68,7 +69,8 @@ void generate() {
6869
void generateStreaming() {
6970
this.contextRunner.withConfiguration(AutoConfigurations.of(ZhiPuAiChatAutoConfiguration.class)).run(context -> {
7071
ZhiPuAiChatModel chatModel = context.getBean(ZhiPuAiChatModel.class);
71-
Flux<ChatResponse> responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello")));
72+
Flux<ChatResponse> responseFlux = chatModel
73+
.stream(new Prompt(new UserMessage("Hello"), ChatOptions.builder().build()));
7274
String response = responseFlux.collectList()
7375
.block()
7476
.stream()

auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/test/java/org/springframework/ai/model/zhipuai/autoconfigure/tool/FunctionCallbackInPromptIT.java

+3-4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.ai.model.function.FunctionCallback;
3434
import org.springframework.ai.model.zhipuai.autoconfigure.ZhiPuAiChatAutoConfiguration;
3535
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
36+
import org.springframework.ai.tool.function.FunctionToolCallback;
3637
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
3738
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
3839
import org.springframework.boot.autoconfigure.AutoConfigurations;
@@ -64,8 +65,7 @@ void functionCallTest() {
6465
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
6566

6667
var promptOptions = ZhiPuAiChatOptions.builder()
67-
.functionCallbacks(List.of(FunctionCallback.builder()
68-
.function("CurrentWeatherService", new MockWeatherService())
68+
.toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService())
6969
.description("Get the weather in location")
7070
.inputType(MockWeatherService.Request.class)
7171
// .responseConverter(response -> "" + response.temp() +
@@ -92,8 +92,7 @@ void streamingFunctionCallTest() {
9292
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
9393

9494
var promptOptions = ZhiPuAiChatOptions.builder()
95-
.functionCallbacks(List.of(FunctionCallback.builder()
96-
.function("CurrentWeatherService", new MockWeatherService())
95+
.toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService())
9796
.description("Get the weather in location")
9897
.inputType(MockWeatherService.Request.class)
9998
.build()))

auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/test/java/org/springframework/ai/model/zhipuai/autoconfigure/tool/FunctionCallbackWithPlainFunctionBeanIT.java

+9-8
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.springframework.ai.chat.model.Generation;
3333
import org.springframework.ai.chat.prompt.Prompt;
3434
import org.springframework.ai.model.function.FunctionCallingOptions;
35+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3536
import org.springframework.ai.model.zhipuai.autoconfigure.ZhiPuAiChatAutoConfiguration;
3637
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
3738
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
@@ -69,16 +70,16 @@ void functionCallTest() {
6970
UserMessage userMessage = new UserMessage(
7071
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
7172

72-
ChatResponse response = chatModel.call(
73-
new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().function("weatherFunction").build()));
73+
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage),
74+
ZhiPuAiChatOptions.builder().toolNames("weatherFunction").build()));
7475

7576
logger.info("Response: {}", response);
7677

7778
assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15");
7879

7980
// Test weatherFunctionTwo
8081
response = chatModel.call(new Prompt(List.of(userMessage),
81-
ZhiPuAiChatOptions.builder().function("weatherFunctionTwo").build()));
82+
ZhiPuAiChatOptions.builder().toolNames("weatherFunctionTwo").build()));
8283

8384
logger.info("Response: {}", response);
8485

@@ -97,8 +98,8 @@ void functionCallWithPortableFunctionCallingOptions() {
9798
UserMessage userMessage = new UserMessage(
9899
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
99100

100-
FunctionCallingOptions functionOptions = FunctionCallingOptions.builder()
101-
.function("weatherFunction")
101+
ToolCallingChatOptions functionOptions = ToolCallingChatOptions.builder()
102+
.toolNames("weatherFunction")
102103
.build();
103104

104105
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions));
@@ -117,8 +118,8 @@ void streamFunctionCallTest() {
117118
UserMessage userMessage = new UserMessage(
118119
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
119120

120-
Flux<ChatResponse> response = chatModel.stream(
121-
new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().function("weatherFunction").build()));
121+
Flux<ChatResponse> response = chatModel.stream(new Prompt(List.of(userMessage),
122+
ZhiPuAiChatOptions.builder().toolNames("weatherFunction").build()));
122123

123124
String content = response.collectList()
124125
.block()
@@ -136,7 +137,7 @@ void streamFunctionCallTest() {
136137

137138
// Test weatherFunctionTwo
138139
response = chatModel.stream(new Prompt(List.of(userMessage),
139-
ZhiPuAiChatOptions.builder().function("weatherFunctionTwo").build()));
140+
ZhiPuAiChatOptions.builder().toolNames("weatherFunctionTwo").build()));
140141

141142
content = response.collectList()
142143
.block()

auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/test/java/org/springframework/ai/model/zhipuai/autoconfigure/tool/ZhipuAiFunctionCallbackIT.java

+7-6
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
import org.springframework.ai.model.function.FunctionCallback;
3434
import org.springframework.ai.model.zhipuai.autoconfigure.ZhiPuAiChatAutoConfiguration;
3535
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
36+
import org.springframework.ai.tool.ToolCallback;
37+
import org.springframework.ai.tool.function.FunctionToolCallback;
3638
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
3739
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
3840
import org.springframework.boot.autoconfigure.AutoConfigurations;
@@ -67,7 +69,7 @@ void functionCallTest() {
6769
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
6870

6971
ChatResponse response = chatModel
70-
.call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().function("WeatherInfo").build()));
72+
.call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().toolNames("WeatherInfo").build()));
7173

7274
logger.info("Response: {}", response);
7375

@@ -85,8 +87,8 @@ void streamFunctionCallTest() {
8587
UserMessage userMessage = new UserMessage(
8688
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
8789

88-
Flux<ChatResponse> response = chatModel
89-
.stream(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().function("WeatherInfo").build()));
90+
Flux<ChatResponse> response = chatModel.stream(
91+
new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().toolNames("WeatherInfo").build()));
9092

9193
String content = response.collectList()
9294
.block()
@@ -109,10 +111,9 @@ void streamFunctionCallTest() {
109111
static class Config {
110112

111113
@Bean
112-
public FunctionCallback weatherFunctionInfo() {
114+
public ToolCallback weatherFunctionInfo() {
113115

114-
return FunctionCallback.builder()
115-
.function("WeatherInfo", new MockWeatherService())
116+
return FunctionToolCallback.builder("WeatherInfo", new MockWeatherService())
116117
.description("Get the weather in location")
117118
.inputType(MockWeatherService.Request.class)
118119
// .responseConverter(response -> "" + response.temp() + response.unit())

0 commit comments

Comments
 (0)