1
1
package ai .azure .openai .rag .workshop .backend ;
2
2
3
+ import dev .langchain4j .data .embedding .Embedding ;
4
+ import dev .langchain4j .data .message .AiMessage ;
5
+ import dev .langchain4j .data .message .ChatMessage ;
6
+ import dev .langchain4j .data .message .SystemMessage ;
7
+ import dev .langchain4j .data .message .UserMessage ;
8
+ import dev .langchain4j .data .segment .TextSegment ;
3
9
import dev .langchain4j .model .chat .ChatLanguageModel ;
10
+ import dev .langchain4j .model .embedding .AllMiniLmL6V2EmbeddingModel ;
11
+ import dev .langchain4j .model .embedding .EmbeddingModel ;
4
12
import dev .langchain4j .model .openai .OpenAiChatModel ;
5
13
import dev .langchain4j .model .openai .OpenAiModelName ;
14
+ import dev .langchain4j .store .embedding .EmbeddingMatch ;
15
+ import dev .langchain4j .store .embedding .EmbeddingStore ;
16
+ import dev .langchain4j .store .embedding .qdrant .QdrantEmbeddingStore ;
6
17
import jakarta .ws .rs .Consumes ;
7
18
import jakarta .ws .rs .POST ;
8
19
import jakarta .ws .rs .Path ;
9
20
import jakarta .ws .rs .Produces ;
10
- import jakarta . ws . rs . core .Response ;
21
+ import dev . langchain4j . model . output .Response ;
11
22
import static java .time .Duration .ofSeconds ;
23
+ import org .slf4j .Logger ;
24
+ import org .slf4j .LoggerFactory ;
25
+
26
+ import java .util .ArrayList ;
27
+ import java .util .List ;
12
28
13
29
@ Path ("/chat" )
14
30
public class ChatResource {
15
31
32
+ private static final Logger log = LoggerFactory .getLogger (ChatResource .class );
33
+
16
34
@ POST
17
35
@ Consumes ({"application/json" })
18
36
@ Produces ({"application/json" })
19
- public Response chat (ChatRequest chatRequest ) {
37
+ public String chat (ChatRequest chatRequest ) {
38
+
39
+ String question = chatRequest .messages .get (0 ).content ;
40
+
41
+ log .info ("### Embed the question (convert the question into vectors that represent the meaning) using embeddedQuestion model" );
42
+ EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel ();
43
+ Embedding embeddedQuestion = embeddingModel .embed (question ).content ();
44
+ log .debug ("Vector length: {}" , embeddedQuestion .vector ().length );
45
+
46
+ log .info ("### Find relevant embeddings from Qdrant based on the question" );
47
+ EmbeddingStore <TextSegment > qdrantEmbeddingStore = QdrantEmbeddingStore .builder ()
48
+ .collectionName ("rag-workshop-collection" )
49
+ .host ("localhost" )
50
+ .port (6334 )
51
+ .build ();
52
+
53
+ List <EmbeddingMatch <TextSegment >> relevant = qdrantEmbeddingStore .findRelevant (embeddedQuestion , 3 );
54
+
55
+ log .info ("### Builds chat history using the relevant embeddings" );
56
+ List <ChatMessage > chatMessages = new ArrayList <>();
57
+ for (int i = 0 ; i < relevant .size (); i ++) {
58
+ EmbeddingMatch <TextSegment > textSegmentEmbeddingMatch = relevant .get (i );
59
+ chatMessages .add (SystemMessage .from (textSegmentEmbeddingMatch .embedded ().text ()));
60
+ log .debug ("Relevant segment {}: {}" , i , textSegmentEmbeddingMatch .embedded ().text ());
61
+ }
62
+
63
+ log .info ("### Invoke the LLM" );
64
+ chatMessages .add (UserMessage .from (question ));
20
65
21
66
ChatLanguageModel model = OpenAiChatModel .builder ()
22
67
.apiKey (System .getenv ("OPENAI_API_KEY" ))
@@ -27,9 +72,8 @@ public Response chat(ChatRequest chatRequest) {
27
72
.logResponses (true )
28
73
.build ();
29
74
30
- String response = model .generate (chatRequest .messages .get (0 ).content );
31
-
32
- return Response .ok ().entity (response ).build ();
75
+ Response <AiMessage > response = model .generate (chatMessages );
33
76
77
+ return response .content ().text ();
34
78
}
35
79
}
0 commit comments