diff --git a/local-ai-example/pom.xml b/local-ai-example/pom.xml new file mode 100644 index 00000000..01aab70a --- /dev/null +++ b/local-ai-example/pom.xml @@ -0,0 +1,65 @@ + + 4.0.0 + + dev.langchain4j + langchain4j-examples + 0.30.0 + + + local-ai-example + jar + + local-ai-example + http://maven.apache.org + + + UTF-8 + + + + + junit + junit + 3.8.1 + test + + + dev.langchain4j + langchain4j-core + 0.30.0 + compile + + + + dev.langchain4j + langchain4j-local-ai + 0.29.1 + compile + + + dev.langchain4j + langchain4j + 0.30.0 + compile + + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2 + 0.30.0 + compile + + + com.github.docker-java + docker-java-api + 3.3.6 + compile + + + org.testcontainers + testcontainers + 1.19.7 + compile + + + diff --git a/local-ai-example/src/main/java/AbstractLocalAiInfrastructure.java b/local-ai-example/src/main/java/AbstractLocalAiInfrastructure.java new file mode 100644 index 00000000..655f1543 --- /dev/null +++ b/local-ai-example/src/main/java/AbstractLocalAiInfrastructure.java @@ -0,0 +1,99 @@ +import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.api.command.InspectContainerResponse; +import com.github.dockerjava.api.model.Image; +import org.testcontainers.DockerClientFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.utility.DockerImageName; +import org.testcontainers.utility.MountableFile; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class AbstractLocalAiInfrastructure { + + private static final String LOCAL_AI_IMAGE = "localai/localai:latest"; + + private static final String LOCAL_IMAGE_NAME = "tc-local-ai"; + + private static final String LOCAL_LOCAL_AI_IMAGE = String.format("%s:%s", LOCAL_IMAGE_NAME, DockerImageName.parse(LOCAL_AI_IMAGE).getVersionPart()); + + private static final List CMDS = Arrays.asList( + new String[]{"curl", "-o", "/build/models/ggml-gpt4all-j", "https://gpt4all.io/models/ggml-gpt4all-j.bin"}, + new String[]{"curl", "-Lo", "/build/models/ggml-model-q4_0", "https://huggingface.co/LangChain4j/localai-embeddings/resolve/main/ggml-model-q4_0"}); + + static final LocalAiContainer localAi; + + static { + localAi = new LocalAiContainer(new LocalAi(LOCAL_AI_IMAGE, LOCAL_LOCAL_AI_IMAGE).resolve()); + localAi.start(); + createImage(localAi, LOCAL_LOCAL_AI_IMAGE); + } + + static void createImage(GenericContainer container, String localImageName) { + DockerImageName dockerImageName = DockerImageName.parse(container.getDockerImageName()); + if (!dockerImageName.equals(DockerImageName.parse(localImageName))) { + DockerClient dockerClient = DockerClientFactory.instance().client(); + List images = dockerClient.listImagesCmd().withReferenceFilter(localImageName).exec(); + if (images.isEmpty()) { + DockerImageName imageModel = DockerImageName.parse(localImageName); + dockerClient.commitCmd(container.getContainerId()) + .withRepository(imageModel.getUnversionedPart()) + .withLabels(Collections.singletonMap("org.testcontainers.sessionId", "")) + .withTag(imageModel.getVersionPart()) + .exec(); + } + } + } + + static class LocalAiContainer extends GenericContainer { + + public LocalAiContainer(DockerImageName image) { + super(image); + withExposedPorts(8080); + withImagePullPolicy(dockerImageName -> !dockerImageName.getUnversionedPart().startsWith(LOCAL_IMAGE_NAME)); + } + @Override + protected void containerIsStarted(InspectContainerResponse containerInfo) { + if (!DockerImageName.parse(getDockerImageName()).equals(DockerImageName.parse(LOCAL_LOCAL_AI_IMAGE))) { + try { + for (String[] cmd : CMDS) { + execInContainer(cmd); + } + copyFileToContainer(MountableFile.forClasspathResource("ggml-model-q4_0.yaml"), "/build/models/ggml-model-q4_0.yaml"); + } catch (IOException | InterruptedException e) { + throw new RuntimeException("Error downloading the model", e); + } + } + } + + public String getBaseUrl() { + return "http://" + getHost() + ":" + getMappedPort(8080); + } + } + + static class LocalAi { + + private final String baseImage; + + private final String localImageName; + + LocalAi(String baseImage, String localImageName) { + this.baseImage = baseImage; + this.localImageName = localImageName; + } + + protected DockerImageName resolve() { + DockerImageName dockerImageName = DockerImageName.parse(this.baseImage); + DockerClient dockerClient = DockerClientFactory.instance().client(); + List images = dockerClient.listImagesCmd().withReferenceFilter(this.localImageName).exec(); + if (images.isEmpty()) { + return dockerImageName; + } + return DockerImageName.parse(this.localImageName); + } + + } + +} diff --git a/local-ai-example/src/main/java/LocalAiChatModelExamples.java b/local-ai-example/src/main/java/LocalAiChatModelExamples.java new file mode 100644 index 00000000..fdca7191 --- /dev/null +++ b/local-ai-example/src/main/java/LocalAiChatModelExamples.java @@ -0,0 +1,37 @@ +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.localai.LocalAiChatModel; +import dev.langchain4j.model.output.Response; + +import java.util.Collections; +import java.util.List; + +public class LocalAiChatModelExamples extends AbstractLocalAiInfrastructure { + static ChatLanguageModel model = LocalAiChatModel.builder() + .baseUrl(localAi.getBaseUrl()) + .modelName("ggml-gpt4all-j") + .maxTokens(3) + .logRequests(true) + .logResponses(true) + .build(); + + static class Simple_Prompt { + public static void main(String[] args) { + String answer = model.generate("better go home and weave a net than to stand by the pond longing for fish."); + + System.out.println(answer); + } + } + + static class Simple_Message_Prompt { + public static void main(String[] args) { + UserMessage userMessage = UserMessage.from("better go home and weave a net than to stand by the pond longing for fish."); + List messages = Collections.singletonList(userMessage); + Response response = model.generate(messages); + + System.out.println(response); + } + } +} diff --git a/local-ai-example/src/main/java/LocalAiEmbeddingModelExamples.java b/local-ai-example/src/main/java/LocalAiEmbeddingModelExamples.java new file mode 100644 index 00000000..6660fc72 --- /dev/null +++ b/local-ai-example/src/main/java/LocalAiEmbeddingModelExamples.java @@ -0,0 +1,37 @@ +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.localai.LocalAiEmbeddingModel; +import dev.langchain4j.model.output.Response; +import org.testcontainers.shaded.com.google.common.collect.Lists; + +import java.util.List; + +public class LocalAiEmbeddingModelExamples extends AbstractLocalAiInfrastructure { + + static EmbeddingModel embeddingModel = LocalAiEmbeddingModel.builder() + .baseUrl(localAi.getBaseUrl()) + .modelName("ggml-model-q4_0") + .logRequests(true) + .logResponses(true) + .build(); + + static class Simple_Embed { + public static void main(String[] args) { + Response response = embeddingModel.embed("better go home and weave a net than to stand by the pond longing for fish."); + + System.out.println(response.content()); + } + } + + static class List_Embed { + public static void main(String[] args) { + TextSegment textSegment1 = TextSegment.from("better go home and weave a net than "); + TextSegment textSegment2 = TextSegment.from("to stand by the pond longing for fish."); + Response> listResponse = embeddingModel.embedAll(Lists.newArrayList(textSegment1, textSegment2)); + + listResponse.content().stream().map(Embedding::dimension).forEach(System.out::println); + } + } + +} diff --git a/local-ai-example/src/main/java/LocalAiStreamingChatModelExamples.java b/local-ai-example/src/main/java/LocalAiStreamingChatModelExamples.java new file mode 100644 index 00000000..6ddabc4d --- /dev/null +++ b/local-ai-example/src/main/java/LocalAiStreamingChatModelExamples.java @@ -0,0 +1,40 @@ +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.localai.LocalAiStreamingChatModel; +import dev.langchain4j.model.output.Response; + +public class LocalAiStreamingChatModelExamples extends AbstractLocalAiInfrastructure { + + + static StreamingChatLanguageModel model = LocalAiStreamingChatModel.builder() + .baseUrl(localAi.getBaseUrl()) + .modelName("ggml-gpt4all-j") + .maxTokens(50) + .logRequests(true) + .logResponses(true) + .build(); + + static class Simple_Prompt { + public static void main(String[] args) { + + model.generate("Tell me a poem by Li Bai", new StreamingResponseHandler() { + + @Override + public void onNext(String token) { + System.out.println("onNext(): " + token); + } + + @Override + public void onComplete(Response response) { + System.out.println("onComplete(): " + response); + } + + @Override + public void onError(Throwable error) { + error.printStackTrace(); + } + }); + } + } +}