Skip to content

Commit 7ab1205

Browse files
committed
feat: complete nodes and edges
work on #6
1 parent 383476f commit 7ab1205

File tree

10 files changed

+433
-193
lines changed

10 files changed

+433
-193
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,13 @@
11
package dev.langchain4j.adaptiverag;
22

3-
import dev.langchain4j.data.embedding.Embedding;
43
import dev.langchain4j.data.segment.TextSegment;
5-
import dev.langchain4j.model.chat.ChatLanguageModel;
6-
import dev.langchain4j.model.input.Prompt;
7-
import dev.langchain4j.model.input.structured.StructuredPrompt;
8-
import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
9-
import dev.langchain4j.model.openai.OpenAiChatModel;
10-
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
11-
import dev.langchain4j.service.AiServices;
12-
import dev.langchain4j.service.UserMessage;
13-
import dev.langchain4j.service.V;
14-
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
154
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
16-
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
17-
import lombok.Value;
185
import lombok.var;
196
import org.bsc.langgraph4j.state.AgentState;
20-
import org.bsc.langgraph4j.state.AppendableValue;
21-
import org.bsc.langgraph4j.utils.CollectionsUtils;
227

23-
import java.time.Duration;
24-
import java.util.ArrayList;
258
import java.util.List;
269
import java.util.Map;
10+
import java.util.Objects;
2711
import java.util.Optional;
2812
import java.util.stream.Collectors;
2913

@@ -46,60 +30,43 @@ public State(Map<String, Object> initData) {
4630
super(initData);
4731
}
4832

49-
public Optional<String> question() {
50-
return value("question");
33+
public String question() {
34+
Optional<String> result = value("question");
35+
return result.orElseThrow( () -> new IllegalStateException( "question is not set!" ) );
5136
}
52-
public Optional<String> generation() {
53-
return value("generation");
37+
public String generation() {
38+
Optional<String> result = value("generation");
39+
return result.orElseThrow( () -> new IllegalStateException( "generation is not set!" ) );
40+
5441
}
5542
public List<String> documents() {
56-
return (List<String>) value("documents").orElse(emptyList());
43+
Optional<List<String>> result = value("documents");
44+
return result.orElse(emptyList());
5745
}
5846

5947
}
6048

6149
private final String openApiKey;
6250
private final String tavilyApiKey;
63-
private final ChromaEmbeddingStore chroma = new ChromaEmbeddingStore(
64-
"http://localhost:8000",
65-
"rag-chroma",
66-
Duration.ofMinutes(2) );
67-
private final OpenAiEmbeddingModel embeddingModel;
51+
private final ChromaStore chroma;
6852

6953
public AdaptiveRag( String openApiKey, String tavilyApiKey ) {
7054
this.openApiKey = openApiKey;
7155
this.tavilyApiKey = tavilyApiKey;
72-
73-
this.embeddingModel = OpenAiEmbeddingModel.builder()
74-
.apiKey(openApiKey)
75-
.build();
76-
77-
}
78-
79-
private EmbeddingSearchResult<TextSegment> retrieverSearch( String question ) {
80-
81-
Embedding queryEmbedding = embeddingModel.embed(question).content();
82-
83-
EmbeddingSearchRequest query = EmbeddingSearchRequest.builder()
84-
.queryEmbedding( queryEmbedding )
85-
.maxResults( 1 )
86-
.minScore( 0.0 )
87-
.build();
88-
return chroma.search( query );
56+
this.chroma = ChromaStore.of(openApiKey);
8957

9058
}
9159

9260
/**
93-
* Retrieve documents
94-
* @param state
95-
* @return
61+
* Node: Retrieve documents
62+
* @param state The current graph state
63+
* @return New key added to state, documents, that contains retrieved documents
9664
*/
9765
public Map<String,Object> retrieve( State state ) {
9866

99-
String question = state.question()
100-
.orElseThrow( () -> new IllegalStateException( "question is null!" ) );
67+
String question = state.question();
10168

102-
EmbeddingSearchResult<TextSegment> relevant = retrieverSearch( question );
69+
EmbeddingSearchResult<TextSegment> relevant = this.chroma.search( question );
10370

10471
List<String> documents = relevant.matches().stream()
10572
.map( m -> m.embedded().text() )
@@ -108,60 +75,37 @@ public Map<String,Object> retrieve( State state ) {
10875
return mapOf( "documents", documents , "question", question );
10976
}
11077

111-
public interface RagService {
112-
113-
@UserMessage("You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\n" +
114-
"Question: {{question}} \n" +
115-
"Context: {{context}} \n" +
116-
"Answer:")
117-
String invoke(@V("question") String question, @V("context") List<String> context );
118-
}
119-
12078
/**
121-
* Generate answer
79+
* Node: Generate answer
12280
*
123-
* @param state
124-
* @return
81+
* @param state The current graph state
82+
* @return New key added to state, generation, that contains LLM generation
12583
*/
12684
public Map<String,Object> generate( State state ) {
127-
String question = state.question()
128-
.orElseThrow( () -> new IllegalStateException( "question is null!" ) );
85+
String question = state.question();
12986
List<String> documents = state.documents();
13087

131-
ChatLanguageModel chatLanguageModel = OpenAiChatModel.builder()
132-
.apiKey( openApiKey )
133-
.modelName( "gpt-3.5-turbo" )
134-
.timeout(Duration.ofMinutes(2))
135-
.logRequests(true)
136-
.logResponses(true)
137-
.maxRetries(2)
138-
.temperature(0.0)
139-
.maxTokens(2000)
140-
.build();
141-
142-
RagService service = AiServices.create(RagService.class, chatLanguageModel);
143-
144-
String generation = service.invoke( question, documents ); // service
88+
String generation = Generation.of(openApiKey).apply(question, documents); // service
14589

14690
return mapOf("generation", generation);
14791
}
14892

14993
/**
150-
* Determines whether the retrieved documents are relevant to the question.
151-
* @param state
152-
* @return
94+
* Node: Determines whether the retrieved documents are relevant to the question.
95+
* @param state The current graph state
96+
* @return Updates documents key with only filtered relevant documents
15397
*/
15498
public Map<String,Object> gradeDocuments( State state ) {
15599

156-
String question = state.question()
157-
.orElseThrow( () -> new IllegalStateException( "question is null!" ) );
100+
String question = state.question();
101+
158102
List<String> documents = state.documents();
159103

160104
final RetrievalGrader grader = RetrievalGrader.of( openApiKey );
161105

162106
List<String> filteredDocs = documents.stream()
163107
.filter( d -> {
164-
var score = grader.apply( new RetrievalGrader.Arguments(question, d ));
108+
var score = grader.apply( RetrievalGrader.Arguments.of(question, d ));
165109
return score.binaryScore.equals("yes");
166110
})
167111
.collect(Collectors.toList());
@@ -170,28 +114,25 @@ public Map<String,Object> gradeDocuments( State state ) {
170114
}
171115

172116
/**
173-
* Transform the query to produce a better question.
174-
* @param state
175-
* @return
117+
* Node: Transform the query to produce a better question.
118+
* @param state The current graph state
119+
* @return Updates question key with a re-phrased question
176120
*/
177121
public Map<String,Object> transformQuery( State state ) {
178-
String question = state.question()
179-
.orElseThrow( () -> new IllegalStateException( "question is null!" ) );
180-
List<String> documents = state.documents();
122+
String question = state.question();
181123

182124
String betterQuestion = QuestionRewriter.of( openApiKey ).apply( question );
183125

184126
return mapOf( "question", betterQuestion );
185127
}
186128

187129
/**
188-
* Web search based on the re-phrased question.
189-
* @param state
190-
* @return
130+
* Node: Web search based on the re-phrased question.
131+
* @param state The current graph state
132+
* @return Updates documents key with appended web results
191133
*/
192134
public Map<String,Object> webSearch( State state ) {
193-
String question = state.question()
194-
.orElseThrow( () -> new IllegalStateException( "question is null!" ) );
135+
String question = state.question();
195136

196137
var result = WebSearchTool.of( tavilyApiKey ).apply(question);
197138

@@ -201,4 +142,58 @@ public Map<String,Object> webSearch( State state ) {
201142

202143
return mapOf( "documents", listOf( webResult ) );
203144
}
145+
146+
/**
147+
* Edge: Route question to web search or RAG.
148+
* @param state The current graph state
149+
* @return Next node to call
150+
*/
151+
public String routeQuestion( State state ) {
152+
String question = state.question();
153+
154+
var source = QuestionRouter.of( openApiKey ).apply( question );
155+
156+
return source.name();
157+
}
158+
159+
/**
160+
* Edge: Determines whether to generate an answer, or re-generate a question.
161+
* @param state The current graph state
162+
* @return Binary decision for next node to call
163+
*/
164+
public String decideToGenerate( State state ) {
165+
List<String> documents = state.documents();
166+
167+
if(documents.isEmpty()) {
168+
return "transform_query";
169+
}
170+
return "generate";
171+
}
172+
173+
/**
174+
* Edge: Determines whether the generation is grounded in the document and answers question.
175+
* @param state The current graph state
176+
* @return Decision for next node to call
177+
*/
178+
public String gradeGeneration_v_DocumentsAndQuestion( State state ) {
179+
String question = state.question();
180+
List<String> documents = state.documents();
181+
String generation = state.generation();
182+
183+
HallucinationGrader.Score score = HallucinationGrader.of( openApiKey )
184+
.apply( HallucinationGrader.Arguments.of(documents, generation));
185+
186+
if(Objects.equals(score.binaryScore, "yes")) {
187+
188+
AnswerGrader.Score score2 = AnswerGrader.of( openApiKey )
189+
.apply( AnswerGrader.Arguments.of(question, generation) );
190+
if( Objects.equals( score2.binaryScore, "yes") ) {
191+
return "useful";
192+
}
193+
194+
return "not useful";
195+
}
196+
197+
return "not supported";
198+
}
204199
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package dev.langchain4j.adaptiverag;
2+
3+
import dev.langchain4j.model.chat.ChatLanguageModel;
4+
import dev.langchain4j.model.input.Prompt;
5+
import dev.langchain4j.model.input.structured.StructuredPrompt;
6+
import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
7+
import dev.langchain4j.model.openai.OpenAiChatModel;
8+
import dev.langchain4j.model.output.structured.Description;
9+
import dev.langchain4j.service.AiServices;
10+
import dev.langchain4j.service.SystemMessage;
11+
import lombok.Value;
12+
13+
import java.time.Duration;
14+
import java.util.function.Function;
15+
16+
@Value(staticConstructor="of")
17+
public class AnswerGrader implements Function<AnswerGrader.Arguments,AnswerGrader.Score> {
18+
/**
19+
* Binary score to assess answer addresses question.
20+
*/
21+
public static class Score {
22+
23+
@Description("Answer addresses the question, 'yes' or 'no'")
24+
public String binaryScore;
25+
}
26+
27+
@StructuredPrompt("User question: \\n\\n {question} \\n\\n LLM generation: {generation}")
28+
@Value(staticConstructor="of")
29+
public static class Arguments {
30+
String question;
31+
String generation;
32+
}
33+
34+
interface Service {
35+
36+
@SystemMessage("You are a grader assessing whether an answer addresses / resolves a question \\n \n" +
37+
" Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question.")
38+
Score invoke(String userMessage);
39+
}
40+
41+
String openApiKey;
42+
43+
@Override
44+
public Score apply(Arguments args) {
45+
ChatLanguageModel chatLanguageModel = OpenAiChatModel.builder()
46+
.apiKey( openApiKey )
47+
.modelName( "gpt-3.5-turbo-0125" )
48+
.timeout(Duration.ofMinutes(2))
49+
.logRequests(true)
50+
.logResponses(true)
51+
.maxRetries(2)
52+
.temperature(0.0)
53+
.maxTokens(2000)
54+
.build();
55+
56+
57+
Service service = AiServices.create(Service.class, chatLanguageModel);
58+
59+
Prompt prompt = StructuredPromptProcessor.toPrompt(args);
60+
61+
return service.invoke(prompt.text());
62+
}
63+
64+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package dev.langchain4j.adaptiverag;
2+
3+
import dev.langchain4j.data.embedding.Embedding;
4+
import dev.langchain4j.data.segment.TextSegment;
5+
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
6+
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
7+
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
8+
import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore;
9+
import lombok.Value;
10+
11+
import java.time.Duration;
12+
13+
public final class ChromaStore {
14+
public static ChromaStore of(String openApiKey) {
15+
return new ChromaStore(openApiKey);
16+
}
17+
18+
private final ChromaEmbeddingStore chroma = new ChromaEmbeddingStore(
19+
"http://localhost:8000",
20+
"rag-chroma",
21+
Duration.ofMinutes(2) );
22+
private final OpenAiEmbeddingModel embeddingModel;
23+
24+
private ChromaStore( String openApiKey ) {
25+
this.embeddingModel = OpenAiEmbeddingModel.builder()
26+
.apiKey(openApiKey)
27+
.build();
28+
}
29+
30+
public EmbeddingSearchResult<TextSegment> search(String query) {
31+
32+
Embedding queryEmbedding = embeddingModel.embed(query).content();
33+
34+
EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
35+
.queryEmbedding( queryEmbedding )
36+
.maxResults( 1 )
37+
.minScore( 0.0 )
38+
.build();
39+
return chroma.search( searchRequest );
40+
41+
}
42+
}

0 commit comments

Comments
 (0)