Skip to content

Commit 1497053

Browse files
committed
feat(iamge_to_diagram): add sub-graph for error review
1 parent 5ee97bb commit 1497053

File tree

3 files changed

+353
-236
lines changed

3 files changed

+353
-236
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package dev.langchain4j.image_to_diagram;
2+
3+
import dev.langchain4j.data.message.SystemMessage;
4+
import dev.langchain4j.model.input.Prompt;
5+
import dev.langchain4j.model.openai.OpenAiChatModel;
6+
import lombok.Getter;
7+
import lombok.extern.slf4j.Slf4j;
8+
import lombok.var;
9+
import org.bsc.async.AsyncGenerator;
10+
import org.bsc.langgraph4j.GraphState;
11+
import org.bsc.langgraph4j.NodeOutput;
12+
13+
import java.util.Map;
14+
import java.util.concurrent.CompletableFuture;
15+
16+
import static java.util.Optional.ofNullable;
17+
import static org.bsc.langgraph4j.GraphState.END;
18+
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
19+
import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf;
20+
21+
@Slf4j( topic="DiagramCorrectionProcess" )
22+
public class DiagramCorrectionProcess implements ImageToDiagram {
23+
24+
@Getter(lazy = true)
25+
private final OpenAiChatModel LLM = newLLM();
26+
27+
private OpenAiChatModel newLLM( ) {
28+
var openApiKey = ofNullable( System.getProperty("OPENAI_API_KEY") )
29+
.orElseThrow( () -> new IllegalArgumentException("no OPENAI_API_KEY provided!") );
30+
31+
return OpenAiChatModel.builder()
32+
.apiKey( openApiKey )
33+
.modelName( "gpt-3.5-turbo" )
34+
.logRequests(true)
35+
.logResponses(true)
36+
.maxRetries(2)
37+
.temperature(0.0)
38+
.maxTokens(2000)
39+
.build();
40+
41+
}
42+
43+
CompletableFuture<Map<String,Object>> reviewResult(State state) {
44+
CompletableFuture<Map<String,Object>> future = new CompletableFuture<>();
45+
try {
46+
47+
var diagramCode = state.diagramCode().last()
48+
.orElseThrow(() -> new IllegalArgumentException("no diagram code provided!"));
49+
50+
var error = state.evaluationError()
51+
.orElseThrow(() -> new IllegalArgumentException("no evaluation error provided!"));
52+
53+
log.trace("evaluation error: {}", error);
54+
55+
Prompt systemPrompt = loadPromptTemplate( "review_diagram.txt" )
56+
.apply( mapOf( "evaluationError", error, "diagramCode", diagramCode));
57+
var response = getLLM().generate( new SystemMessage(systemPrompt.text()) );
58+
59+
var result = response.content().text();
60+
61+
log.trace("review result: {}", result);
62+
63+
future.complete(mapOf("diagramCode", result ) );
64+
65+
} catch (Exception e) {
66+
future.completeExceptionally(e);
67+
}
68+
69+
return future;
70+
}
71+
72+
private CompletableFuture<Map<String,Object>> evaluateResult(State state) {
73+
74+
var diagramCode = state.diagramCode().last()
75+
.orElseThrow(() -> new IllegalArgumentException("no diagram code provided!"));
76+
77+
return PlantUMLAction.validate( diagramCode )
78+
.thenApply( v -> mapOf( "evaluationResult", (Object) EvaluationResult.OK ) )
79+
.exceptionally( e -> {
80+
if( e.getCause() instanceof PlantUMLAction.Error ) {
81+
return mapOf("evaluationResult", EvaluationResult.ERROR,
82+
"evaluationError", e.getCause().getMessage(),
83+
"evaluationErrorType", ((PlantUMLAction.Error)e.getCause()).getType());
84+
}
85+
throw new RuntimeException(e);
86+
});
87+
88+
}
89+
90+
private String routeEvaluationResult( State state ) {
91+
var evaluationResult = state.evaluationResult()
92+
.orElseThrow(() -> new IllegalArgumentException("no evaluationResult provided!"));
93+
94+
if( evaluationResult == EvaluationResult.ERROR ) {
95+
if( state.isExecutionError() ) {
96+
log.warn("evaluation execution error: [{}]", state.evaluationError().orElse("unknown") );
97+
return EvaluationResult.UNKNOWN.name();
98+
}
99+
if (state.lastTwoDiagramsAreEqual()) {
100+
log.warn("correction failed! ");
101+
return EvaluationResult.UNKNOWN.name();
102+
}
103+
}
104+
105+
return evaluationResult.name();
106+
};
107+
108+
@Override
109+
public AsyncGenerator<NodeOutput<State>> execute(Map<String, Object> inputs) throws Exception {
110+
111+
var workflow = new GraphState<>(State::new);
112+
113+
workflow.addNode( "evaluate_result", this::evaluateResult);
114+
workflow.addNode( "agent_review", this::reviewResult );
115+
workflow.addEdge( "agent_review", "evaluate_result" );
116+
workflow.addConditionalEdges(
117+
"evaluate_result",
118+
edge_async(this::routeEvaluationResult),
119+
mapOf( "OK", END,
120+
"ERROR", "agent_review",
121+
"UNKNOWN", END )
122+
);
123+
workflow.setEntryPoint("evaluate_result");
124+
125+
var app = workflow.compile();
126+
127+
return app.stream( inputs );
128+
129+
}
130+
}

0 commit comments

Comments
 (0)