Skip to content

Commit d9bf1a3

Browse files
committed
fix: Agent extend conversation with assistant's reply
1 parent 826ccfa commit d9bf1a3

File tree

4 files changed

+39
-47
lines changed

4 files changed

+39
-47
lines changed

agents/src/main/java/dev/langchain4j/Agent.java

+18-31
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import dev.langchain4j.data.message.SystemMessage;
77
import dev.langchain4j.data.message.UserMessage;
88
import dev.langchain4j.data.message.ToolExecutionResultMessage;
9+
910
import dev.langchain4j.model.chat.ChatLanguageModel;
1011
import dev.langchain4j.model.input.PromptTemplate;
1112
import dev.langchain4j.model.output.Response;
@@ -16,53 +17,39 @@
1617
import java.util.List;
1718
import java.util.Map;
1819

20+
import static java.lang.String.format;
21+
1922
@Builder
2023
public class Agent {
2124

2225
private final ChatLanguageModel chatLanguageModel;
2326
@Singular private final List<ToolSpecification> tools;
2427

2528

26-
public Response<AiMessage> execute( Map<String,Object> inputs ) {
27-
var messages = new ArrayList<ChatMessage>();
28-
var promptTemplate = PromptTemplate.from( "USER: {{input}}" ).apply(inputs);
29-
30-
messages.add(new SystemMessage("You are a helpful assistant"));
31-
32-
messages.add( new UserMessage(promptTemplate.text()) );
33-
34-
return chatLanguageModel.generate( messages, tools );
35-
}
36-
37-
private PromptTemplate getToolResponseTemplate( ) {
38-
var TEMPLATE_TOOL_RESPONSE = """
39-
TOOL RESPONSE:
40-
---------------------
41-
{{observation}}
42-
--------------------
43-
""";
44-
return PromptTemplate.from(TEMPLATE_TOOL_RESPONSE);
45-
}
46-
4729
public Response<AiMessage> execute( String input, List<AgentExecutor.IntermediateStep> intermediateSteps ) {
48-
var agentScratchpadTemplate = getToolResponseTemplate();
49-
var userMessageTemplate = PromptTemplate.from( "USER'S INPUT: {{input}}" )
30+
var userMessageTemplate = PromptTemplate.from( "{{input}}" )
5031
.apply( Map.of( "input", input));
5132

5233
var messages = new ArrayList<ChatMessage>();
5334

5435
messages.add(new SystemMessage("You are a helpful assistant"));
36+
messages.add(new UserMessage(userMessageTemplate.text()));
5537

56-
if( intermediateSteps.isEmpty()) {
57-
messages.add(new UserMessage(userMessageTemplate.text()));
58-
}
38+
if (!intermediateSteps.isEmpty()) {
5939

60-
for( AgentExecutor.IntermediateStep step: intermediateSteps ) {
61-
var agentScratchpad = agentScratchpadTemplate
62-
.apply( Map.of("observation", step.observation()) );
63-
messages.add(new UserMessage(agentScratchpad.text()));
64-
}
40+
var toolRequests = intermediateSteps.stream()
41+
.map(AgentExecutor.IntermediateStep::action)
42+
.map(AgentExecutor.AgentAction::toolExecutionRequest)
43+
.toList();
44+
45+
messages.add(new AiMessage(toolRequests)); // reply with tool requests
6546

47+
for (AgentExecutor.IntermediateStep step : intermediateSteps) {
48+
var toolRequest = step.action().toolExecutionRequest();
49+
50+
messages.add(new ToolExecutionResultMessage(toolRequest.id(), toolRequest.name(), step.observation()));
51+
}
52+
}
6653
return chatLanguageModel.generate( messages, tools );
6754
}
6855
}

agents/src/main/java/dev/langchain4j/AgentExecutor.java

+2-13
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import dev.langchain4j.agent.tool.ToolExecutionRequest;
44
import dev.langchain4j.agent.tool.ToolSpecification;
5+
import dev.langchain4j.model.chat.ChatLanguageModel;
56
import dev.langchain4j.model.openai.OpenAiChatModel;
67
import dev.langchain4j.model.output.FinishReason;
78
import org.bsc.langgraph4j.GraphState;
@@ -117,21 +118,9 @@ String shouldContinue(State state) {
117118
return "continue";
118119
}
119120

120-
public AsyncIterator<GraphState.Runnable.NodeOutput<State>> execute(Map<String, Object> inputs, List<Object> objectsWithTools) throws Exception {
121+
public AsyncIterator<GraphState.Runnable.NodeOutput<State>> execute(ChatLanguageModel chatLanguageModel, Map<String, Object> inputs, List<Object> objectsWithTools) throws Exception {
121122

122123

123-
var openApiKey = DotEnvConfig.valueOf("OPENAI_API_KEY")
124-
.orElseThrow( () -> new IllegalArgumentException("no APIKEY provided!"));
125-
126-
var chatLanguageModel = OpenAiChatModel.builder()
127-
.apiKey( openApiKey )
128-
.modelName( "gpt-3.5-turbo-0613" )
129-
.logResponses(true)
130-
.maxRetries(2)
131-
.temperature(0.0)
132-
.maxTokens(2000)
133-
.build();
134-
135124
var toolInfoList = ToolInfo.fromList( objectsWithTools );
136125

137126
final List<ToolSpecification> toolSpecifications = toolInfoList.stream()

agents/src/test/java/dev/langchain4j/AgentExecutorTest.java

+17-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import dev.langchain4j.agent.tool.P;
44
import dev.langchain4j.agent.tool.Tool;
5+
import dev.langchain4j.model.openai.OpenAiChatModel;
56

67
import java.util.List;
78
import java.util.Map;
@@ -15,11 +16,25 @@ public static void main( String[] args) {
1516

1617
DotEnvConfig.load();
1718

18-
var agentExecutor = new AgentExecutor();
19+
var openApiKey = DotEnvConfig.valueOf("OPENAI_API_KEY")
20+
.orElseThrow( () -> new IllegalArgumentException("no APIKEY provided!"));
21+
22+
var chatLanguageModel = OpenAiChatModel.builder()
23+
.apiKey( openApiKey )
24+
.modelName( "gpt-3.5-turbo-0125" )
25+
.logResponses(true)
26+
.maxRetries(2)
27+
.temperature(0.0)
28+
.maxTokens(2000)
29+
.build();
1930

2031
try {
32+
var agentExecutor = new AgentExecutor();
33+
2134
var iterator = agentExecutor.execute(
22-
Map.of( "input", "what is the result of test with message: 'MY FIRST TEST'?"),
35+
chatLanguageModel,
36+
//Map.of( "input", "what is the result of test with messages: 'MY FIRST TEST' and the result of test with message: 'MY SECOND TEST'"),
37+
Map.of( "input", "what is the result of test with messages: 'MY FIRST TEST'"),
2338
List.of(new TestTool()) );
2439

2540
AgentExecutor.State output = null;

agents/src/test/java/dev/langchain4j/AgentTest.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import org.junit.jupiter.api.BeforeAll;
77
import dev.langchain4j.model.openai.OpenAiChatModel;
88

9+
import java.util.List;
910
import java.util.Map;
1011
import java.util.Optional;
1112

@@ -40,7 +41,7 @@ public static void main( String[] args) throws Exception {
4041
.build();
4142

4243
var msg = "hello world";
43-
var response = agent.execute( Map.of("input", format("this is an AI test with message: '%s'", msg) ));
44+
var response = agent.execute( format("this is an AI test with message: '%s'", msg), List.of() );
4445

4546
assertNotNull(response);
4647
assertEquals(response.finishReason(), FinishReason.TOOL_EXECUTION );

0 commit comments

Comments
 (0)