1
1
package dev .langchain4j .adaptiverag ;
2
2
3
- import dev .langchain4j .data .embedding .Embedding ;
4
3
import dev .langchain4j .data .segment .TextSegment ;
5
- import dev .langchain4j .model .chat .ChatLanguageModel ;
6
- import dev .langchain4j .model .input .Prompt ;
7
- import dev .langchain4j .model .input .structured .StructuredPrompt ;
8
- import dev .langchain4j .model .input .structured .StructuredPromptProcessor ;
9
- import dev .langchain4j .model .openai .OpenAiChatModel ;
10
- import dev .langchain4j .model .openai .OpenAiEmbeddingModel ;
11
- import dev .langchain4j .service .AiServices ;
12
- import dev .langchain4j .service .UserMessage ;
13
- import dev .langchain4j .service .V ;
14
- import dev .langchain4j .store .embedding .EmbeddingSearchRequest ;
15
4
import dev .langchain4j .store .embedding .EmbeddingSearchResult ;
16
- import dev .langchain4j .store .embedding .chroma .ChromaEmbeddingStore ;
17
- import lombok .Value ;
18
5
import lombok .var ;
19
6
import org .bsc .langgraph4j .state .AgentState ;
20
- import org .bsc .langgraph4j .state .AppendableValue ;
21
- import org .bsc .langgraph4j .utils .CollectionsUtils ;
22
7
23
- import java .time .Duration ;
24
- import java .util .ArrayList ;
25
8
import java .util .List ;
26
9
import java .util .Map ;
10
+ import java .util .Objects ;
27
11
import java .util .Optional ;
28
12
import java .util .stream .Collectors ;
29
13
@@ -46,60 +30,43 @@ public State(Map<String, Object> initData) {
46
30
super (initData );
47
31
}
48
32
49
- public Optional <String > question () {
50
- return value ("question" );
33
+ public String question () {
34
+ Optional <String > result = value ("question" );
35
+ return result .orElseThrow ( () -> new IllegalStateException ( "question is not set!" ) );
51
36
}
52
- public Optional <String > generation () {
53
- return value ("generation" );
37
+ public String generation () {
38
+ Optional <String > result = value ("generation" );
39
+ return result .orElseThrow ( () -> new IllegalStateException ( "generation is not set!" ) );
40
+
54
41
}
55
42
public List <String > documents () {
56
- return (List <String >) value ("documents" ).orElse (emptyList ());
43
+ Optional <List <String >> result = value ("documents" );
44
+ return result .orElse (emptyList ());
57
45
}
58
46
59
47
}
60
48
61
49
private final String openApiKey ;
62
50
private final String tavilyApiKey ;
63
- private final ChromaEmbeddingStore chroma = new ChromaEmbeddingStore (
64
- "http://localhost:8000" ,
65
- "rag-chroma" ,
66
- Duration .ofMinutes (2 ) );
67
- private final OpenAiEmbeddingModel embeddingModel ;
51
+ private final ChromaStore chroma ;
68
52
69
53
public AdaptiveRag ( String openApiKey , String tavilyApiKey ) {
70
54
this .openApiKey = openApiKey ;
71
55
this .tavilyApiKey = tavilyApiKey ;
72
-
73
- this .embeddingModel = OpenAiEmbeddingModel .builder ()
74
- .apiKey (openApiKey )
75
- .build ();
76
-
77
- }
78
-
79
- private EmbeddingSearchResult <TextSegment > retrieverSearch ( String question ) {
80
-
81
- Embedding queryEmbedding = embeddingModel .embed (question ).content ();
82
-
83
- EmbeddingSearchRequest query = EmbeddingSearchRequest .builder ()
84
- .queryEmbedding ( queryEmbedding )
85
- .maxResults ( 1 )
86
- .minScore ( 0.0 )
87
- .build ();
88
- return chroma .search ( query );
56
+ this .chroma = ChromaStore .of (openApiKey );
89
57
90
58
}
91
59
92
60
/**
93
- * Retrieve documents
94
- * @param state
95
- * @return
61
+ * Node: Retrieve documents
62
+ * @param state The current graph state
63
+ * @return New key added to state, documents, that contains retrieved documents
96
64
*/
97
65
public Map <String ,Object > retrieve ( State state ) {
98
66
99
- String question = state .question ()
100
- .orElseThrow ( () -> new IllegalStateException ( "question is null!" ) );
67
+ String question = state .question ();
101
68
102
- EmbeddingSearchResult <TextSegment > relevant = retrieverSearch ( question );
69
+ EmbeddingSearchResult <TextSegment > relevant = this . chroma . search ( question );
103
70
104
71
List <String > documents = relevant .matches ().stream ()
105
72
.map ( m -> m .embedded ().text () )
@@ -108,60 +75,37 @@ public Map<String,Object> retrieve( State state ) {
108
75
return mapOf ( "documents" , documents , "question" , question );
109
76
}
110
77
111
- public interface RagService {
112
-
113
- @ UserMessage ("You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\n " +
114
- "Question: {{question}} \n " +
115
- "Context: {{context}} \n " +
116
- "Answer:" )
117
- String invoke (@ V ("question" ) String question , @ V ("context" ) List <String > context );
118
- }
119
-
120
78
/**
121
- * Generate answer
79
+ * Node: Generate answer
122
80
*
123
- * @param state
124
- * @return
81
+ * @param state The current graph state
82
+ * @return New key added to state, generation, that contains LLM generation
125
83
*/
126
84
public Map <String ,Object > generate ( State state ) {
127
- String question = state .question ()
128
- .orElseThrow ( () -> new IllegalStateException ( "question is null!" ) );
85
+ String question = state .question ();
129
86
List <String > documents = state .documents ();
130
87
131
- ChatLanguageModel chatLanguageModel = OpenAiChatModel .builder ()
132
- .apiKey ( openApiKey )
133
- .modelName ( "gpt-3.5-turbo" )
134
- .timeout (Duration .ofMinutes (2 ))
135
- .logRequests (true )
136
- .logResponses (true )
137
- .maxRetries (2 )
138
- .temperature (0.0 )
139
- .maxTokens (2000 )
140
- .build ();
141
-
142
- RagService service = AiServices .create (RagService .class , chatLanguageModel );
143
-
144
- String generation = service .invoke ( question , documents ); // service
88
+ String generation = Generation .of (openApiKey ).apply (question , documents ); // service
145
89
146
90
return mapOf ("generation" , generation );
147
91
}
148
92
149
93
/**
150
- * Determines whether the retrieved documents are relevant to the question.
151
- * @param state
152
- * @return
94
+ * Node: Determines whether the retrieved documents are relevant to the question.
95
+ * @param state The current graph state
96
+ * @return Updates documents key with only filtered relevant documents
153
97
*/
154
98
public Map <String ,Object > gradeDocuments ( State state ) {
155
99
156
- String question = state .question ()
157
- . orElseThrow ( () -> new IllegalStateException ( "question is null!" ) );
100
+ String question = state .question ();
101
+
158
102
List <String > documents = state .documents ();
159
103
160
104
final RetrievalGrader grader = RetrievalGrader .of ( openApiKey );
161
105
162
106
List <String > filteredDocs = documents .stream ()
163
107
.filter ( d -> {
164
- var score = grader .apply ( new RetrievalGrader .Arguments (question , d ));
108
+ var score = grader .apply ( RetrievalGrader .Arguments . of (question , d ));
165
109
return score .binaryScore .equals ("yes" );
166
110
})
167
111
.collect (Collectors .toList ());
@@ -170,28 +114,25 @@ public Map<String,Object> gradeDocuments( State state ) {
170
114
}
171
115
172
116
/**
173
- * Transform the query to produce a better question.
174
- * @param state
175
- * @return
117
+ * Node: Transform the query to produce a better question.
118
+ * @param state The current graph state
119
+ * @return Updates question key with a re-phrased question
176
120
*/
177
121
public Map <String ,Object > transformQuery ( State state ) {
178
- String question = state .question ()
179
- .orElseThrow ( () -> new IllegalStateException ( "question is null!" ) );
180
- List <String > documents = state .documents ();
122
+ String question = state .question ();
181
123
182
124
String betterQuestion = QuestionRewriter .of ( openApiKey ).apply ( question );
183
125
184
126
return mapOf ( "question" , betterQuestion );
185
127
}
186
128
187
129
/**
188
- * Web search based on the re-phrased question.
189
- * @param state
190
- * @return
130
+ * Node: Web search based on the re-phrased question.
131
+ * @param state The current graph state
132
+ * @return Updates documents key with appended web results
191
133
*/
192
134
public Map <String ,Object > webSearch ( State state ) {
193
- String question = state .question ()
194
- .orElseThrow ( () -> new IllegalStateException ( "question is null!" ) );
135
+ String question = state .question ();
195
136
196
137
var result = WebSearchTool .of ( tavilyApiKey ).apply (question );
197
138
@@ -201,4 +142,58 @@ public Map<String,Object> webSearch( State state ) {
201
142
202
143
return mapOf ( "documents" , listOf ( webResult ) );
203
144
}
145
+
146
+ /**
147
+ * Edge: Route question to web search or RAG.
148
+ * @param state The current graph state
149
+ * @return Next node to call
150
+ */
151
+ public String routeQuestion ( State state ) {
152
+ String question = state .question ();
153
+
154
+ var source = QuestionRouter .of ( openApiKey ).apply ( question );
155
+
156
+ return source .name ();
157
+ }
158
+
159
+ /**
160
+ * Edge: Determines whether to generate an answer, or re-generate a question.
161
+ * @param state The current graph state
162
+ * @return Binary decision for next node to call
163
+ */
164
+ public String decideToGenerate ( State state ) {
165
+ List <String > documents = state .documents ();
166
+
167
+ if (documents .isEmpty ()) {
168
+ return "transform_query" ;
169
+ }
170
+ return "generate" ;
171
+ }
172
+
173
+ /**
174
+ * Edge: Determines whether the generation is grounded in the document and answers question.
175
+ * @param state The current graph state
176
+ * @return Decision for next node to call
177
+ */
178
+ public String gradeGeneration_v_DocumentsAndQuestion ( State state ) {
179
+ String question = state .question ();
180
+ List <String > documents = state .documents ();
181
+ String generation = state .generation ();
182
+
183
+ HallucinationGrader .Score score = HallucinationGrader .of ( openApiKey )
184
+ .apply ( HallucinationGrader .Arguments .of (documents , generation ));
185
+
186
+ if (Objects .equals (score .binaryScore , "yes" )) {
187
+
188
+ AnswerGrader .Score score2 = AnswerGrader .of ( openApiKey )
189
+ .apply ( AnswerGrader .Arguments .of (question , generation ) );
190
+ if ( Objects .equals ( score2 .binaryScore , "yes" ) ) {
191
+ return "useful" ;
192
+ }
193
+
194
+ return "not useful" ;
195
+ }
196
+
197
+ return "not supported" ;
198
+ }
204
199
}
0 commit comments