diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiService.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiService.java index 567c877..216a59a 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiService.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiService.java @@ -9,6 +9,7 @@ import dev.langchain4j.rag.RetrievalAugmentor; import dev.langchain4j.rag.content.retriever.ContentRetriever; import dev.langchain4j.service.AiServices; +import dev.langchain4j.service.tool.ToolProvider; import org.springframework.stereotype.Service; import java.lang.annotation.Retention; @@ -103,4 +104,10 @@ * this attribute specifies the names of beans containing methods annotated with {@link Tool} that should be used by this AI Service. */ String[] tools() default {}; + + /** + * When the {@link #wiringMode()} is set to {@link AiServiceWiringMode#EXPLICIT}, + * this attribute specifies the name of a {@link ToolProvider} bean that should be used by this AI Service. + */ + String toolProvider() default ""; } diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java index ac47877..362983c 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java @@ -12,6 +12,7 @@ import dev.langchain4j.service.AiServices; import dev.langchain4j.service.tool.DefaultToolExecutor; import dev.langchain4j.service.tool.ToolExecutor; +import dev.langchain4j.service.tool.ToolProvider; import org.springframework.beans.factory.FactoryBean; import java.lang.reflect.Method; @@ -36,6 +37,7 @@ class AiServiceFactory implements FactoryBean { private RetrievalAugmentor retrievalAugmentor; private ModerationModel moderationModel; private List tools; + private ToolProvider toolProvider; public AiServiceFactory(Class aiServiceClass) { this.aiServiceClass = aiServiceClass; @@ -73,6 +75,10 @@ public void setTools(List tools) { this.tools = tools; } + public void setToolProvider(ToolProvider toolProvider) { + this.toolProvider = toolProvider; + } + @Override public Object getObject() { @@ -113,6 +119,9 @@ public Object getObject() { } } } + if (toolProvider != null) { + builder = builder.toolProvider(toolProvider); + } return builder.build(); } diff --git a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java index b835023..e84a47d 100644 --- a/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java +++ b/langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java @@ -12,6 +12,7 @@ import dev.langchain4j.rag.content.retriever.ContentRetriever; import dev.langchain4j.service.IllegalConfigurationException; import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent; +import dev.langchain4j.service.tool.ToolProvider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.MutablePropertyValues; @@ -58,6 +59,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { String[] contentRetrievers = beanFactory.getBeanNamesForType(ContentRetriever.class); String[] retrievalAugmentors = beanFactory.getBeanNamesForType(RetrievalAugmentor.class); String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class); + String[] toolProviders = beanFactory.getBeanNamesForType(ToolProvider.class); Set toolBeanNames = new HashSet<>(); List toolSpecifications = new ArrayList<>(); @@ -165,6 +167,16 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() { propertyValues ); + addBeanReference( + ToolProvider.class, + aiServiceAnnotation, + aiServiceAnnotation.toolProvider(), + toolProviders, + "toolProvider", + "toolProvider", + propertyValues + ); + if (aiServiceAnnotation.wiringMode() == EXPLICIT) { propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools()))); } else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) { diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/issue3074/AiServiceWithToolProvider.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/issue3074/AiServiceWithToolProvider.java new file mode 100644 index 0000000..f873fb1 --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/issue3074/AiServiceWithToolProvider.java @@ -0,0 +1,8 @@ +package dev.langchain4j.service.spring.mode.automatic.issue3074; + +import dev.langchain4j.service.spring.AiService; + +@AiService +public interface AiServiceWithToolProvider { + String chat(String userMessage); +} diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/issue3074/TestAutowireAiServiceToolProviderApplication.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/issue3074/TestAutowireAiServiceToolProviderApplication.java new file mode 100644 index 0000000..bd1eda8 --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/issue3074/TestAutowireAiServiceToolProviderApplication.java @@ -0,0 +1,19 @@ +package dev.langchain4j.service.spring.mode.automatic.issue3074; + +import dev.langchain4j.service.tool.ToolProvider; +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Bean; + + +@SpringBootApplication +public class TestAutowireAiServiceToolProviderApplication { + @Bean + public ToolProvider toolProvider() { + return new TestMcpToolProvider(); + } + + public static void main(String[] args) { + SpringApplication.run(TestAutowireAiServiceToolProviderApplication.class, args); + } +} diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/issue3074/TestAutowrieToolProvider.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/issue3074/TestAutowrieToolProvider.java new file mode 100644 index 0000000..70626d1 --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/issue3074/TestAutowrieToolProvider.java @@ -0,0 +1,22 @@ +package dev.langchain4j.service.spring.mode.automatic.issue3074; + +import dev.langchain4j.service.spring.AiServicesAutoConfig; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +public class TestAutowrieToolProvider { + + ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(AiServicesAutoConfig.class)); + + @Test + void should_fail_to_create_AI_service_when_conflicting_chat_models_are_found() { + contextRunner + .withUserConfiguration(TestAutowireAiServiceToolProviderApplication.class) + .run(context -> { + Assertions.assertDoesNotThrow(() -> context.getBean(TestMcpToolProvider.class)); + }); + } +} diff --git a/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/issue3074/TestMcpToolProvider.java b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/issue3074/TestMcpToolProvider.java new file mode 100644 index 0000000..7146a8b --- /dev/null +++ b/langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/issue3074/TestMcpToolProvider.java @@ -0,0 +1,12 @@ +package dev.langchain4j.service.spring.mode.automatic.issue3074; + +import dev.langchain4j.service.tool.ToolProvider; +import dev.langchain4j.service.tool.ToolProviderRequest; +import dev.langchain4j.service.tool.ToolProviderResult; + +public class TestMcpToolProvider implements ToolProvider { + @Override + public ToolProviderResult provideTools(ToolProviderRequest toolProviderRequest) { + return null; + } +}