Skip to content

feat: support ToolProvider automatically wired into the AI Service if… #129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.tool.ToolProvider;
import org.springframework.stereotype.Service;

import java.lang.annotation.Retention;
Expand Down Expand Up @@ -103,4 +104,10 @@
* this attribute specifies the names of beans containing methods annotated with {@link Tool} that should be used by this AI Service.
*/
String[] tools() default {};

/**
* When the {@link #wiringMode()} is set to {@link AiServiceWiringMode#EXPLICIT},
* this attribute specifies the name of a {@link ToolProvider} bean that should be used by this AI Service.
*/
String toolProvider() default "";
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.tool.DefaultToolExecutor;
import dev.langchain4j.service.tool.ToolExecutor;
import dev.langchain4j.service.tool.ToolProvider;
import org.springframework.beans.factory.FactoryBean;

import java.lang.reflect.Method;
Expand All @@ -36,6 +37,7 @@ class AiServiceFactory implements FactoryBean<Object> {
private RetrievalAugmentor retrievalAugmentor;
private ModerationModel moderationModel;
private List<Object> tools;
private ToolProvider toolProvider;

public AiServiceFactory(Class<Object> aiServiceClass) {
this.aiServiceClass = aiServiceClass;
Expand Down Expand Up @@ -73,6 +75,10 @@ public void setTools(List<Object> tools) {
this.tools = tools;
}

public void setToolProvider(ToolProvider toolProvider) {
this.toolProvider = toolProvider;
}

@Override
public Object getObject() {

Expand Down Expand Up @@ -113,6 +119,9 @@ public Object getObject() {
}
}
}
if (toolProvider != null) {
builder = builder.toolProvider(toolProvider);
}

return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.IllegalConfigurationException;
import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent;
import dev.langchain4j.service.tool.ToolProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.MutablePropertyValues;
Expand Down Expand Up @@ -58,6 +59,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
String[] contentRetrievers = beanFactory.getBeanNamesForType(ContentRetriever.class);
String[] retrievalAugmentors = beanFactory.getBeanNamesForType(RetrievalAugmentor.class);
String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class);
String[] toolProviders = beanFactory.getBeanNamesForType(ToolProvider.class);

Set<String> toolBeanNames = new HashSet<>();
List<ToolSpecification> toolSpecifications = new ArrayList<>();
Expand Down Expand Up @@ -165,6 +167,16 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
propertyValues
);

addBeanReference(
ToolProvider.class,
aiServiceAnnotation,
aiServiceAnnotation.toolProvider(),
toolProviders,
"toolProvider",
"toolProvider",
propertyValues
);

if (aiServiceAnnotation.wiringMode() == EXPLICIT) {
propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools())));
} else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package dev.langchain4j.service.spring.mode.automatic.issue3074;

import dev.langchain4j.service.spring.AiService;

@AiService
public interface AiServiceWithToolProvider {
String chat(String userMessage);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package dev.langchain4j.service.spring.mode.automatic.issue3074;

import dev.langchain4j.service.tool.ToolProvider;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;


@SpringBootApplication
public class TestAutowireAiServiceToolProviderApplication {
@Bean
public ToolProvider toolProvider() {
return new TestMcpToolProvider();
}

public static void main(String[] args) {
SpringApplication.run(TestAutowireAiServiceToolProviderApplication.class, args);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package dev.langchain4j.service.spring.mode.automatic.issue3074;

import dev.langchain4j.service.spring.AiServicesAutoConfig;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

public class TestAutowrieToolProvider {

ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(AiServicesAutoConfig.class));

@Test
void should_fail_to_create_AI_service_when_conflicting_chat_models_are_found() {
contextRunner
.withUserConfiguration(TestAutowireAiServiceToolProviderApplication.class)
.run(context -> {
Assertions.assertDoesNotThrow(() -> context.getBean(TestMcpToolProvider.class));
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package dev.langchain4j.service.spring.mode.automatic.issue3074;

import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.service.tool.ToolProviderRequest;
import dev.langchain4j.service.tool.ToolProviderResult;

public class TestMcpToolProvider implements ToolProvider {
@Override
public ToolProviderResult provideTools(ToolProviderRequest toolProviderRequest) {
return null;
}
}