Skip to content

Commit 199ae8d

Browse files
committed
feat: refine Serialization implementation
- add StateSerializer abstract class that owns a StateFactory - refactor tests, samples and how-tos accordly work on #29
1 parent 72d0e33 commit 199ae8d

File tree

17 files changed

+370
-342
lines changed

17 files changed

+370
-342
lines changed

agent-executor/src/main/java/org/bsc/langgraph4j/agentexecutor/AgentExecutor.java

+8-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.bsc.langgraph4j.*;
1313
import org.bsc.langgraph4j.langchain4j.serializer.std.ToolExecutionResultMessageSerializer;
1414
import org.bsc.langgraph4j.serializer.Serializer;
15+
import org.bsc.langgraph4j.serializer.StateSerializer;
1516
import org.bsc.langgraph4j.serializer.std.ObjectStreamStateSerializer;
1617
import org.bsc.langgraph4j.state.AgentState;
1718
import org.bsc.langgraph4j.state.AppenderChannel;
@@ -32,7 +33,7 @@ public class AgentExecutor {
3233
public class GraphBuilder {
3334
private ChatLanguageModel chatLanguageModel;
3435
private List<Object> objectsWithTools;
35-
private Serializer<Map<String,Object>> stateSerializer;
36+
private StateSerializer<State> stateSerializer;
3637

3738
public GraphBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
3839
this.chatLanguageModel = chatLanguageModel;
@@ -43,7 +44,7 @@ public GraphBuilder objectsWithTools(List<Object> objectsWithTools) {
4344
return this;
4445
}
4546

46-
public GraphBuilder stateSerializer( Serializer<Map<String,Object>> stateSerializer) {
47+
public GraphBuilder stateSerializer( StateSerializer<State> stateSerializer) {
4748
this.stateSerializer = stateSerializer;
4849
return this;
4950
}
@@ -62,16 +63,18 @@ public StateGraph<State> build() throws GraphStateException {
6263
.build();
6364

6465
if( stateSerializer == null ) {
65-
var stateSerializer = new ObjectStreamStateSerializer();
66-
stateSerializer.mapper()
66+
var serializer = new ObjectStreamStateSerializer<>(State::new);
67+
serializer.mapper()
6768
.register(IntermediateStep.class, new IntermediateStepSerializer())
6869
.register(AgentAction.class, new AgentActionSerializer())
6970
.register(AgentFinish.class, new AgentFinishSerializer())
7071
.register(AgentOutcome.class, new AgentOutcomeSerializer())
7172
.register(ToolExecutionResultMessage.class, new ToolExecutionResultMessageSerializer());
73+
74+
stateSerializer = serializer;
7275
}
7376

74-
return new StateGraph<>(State.SCHEMA,State::new, stateSerializer)
77+
return new StateGraph<>(State.SCHEMA, stateSerializer)
7578
.addEdge(START,"agent")
7679
.addNode( "agent", node_async( state ->
7780
callAgent(agentRunnable, state))

agent-executor/src/main/java/org/bsc/langgraph4j/agentexecutor/serializer/json/JSONStateSerializer.java

+10-8
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
import com.fasterxml.jackson.databind.ObjectMapper;
1010
import com.fasterxml.jackson.databind.module.SimpleModule;
1111
import dev.langchain4j.agent.tool.ToolExecutionRequest;
12+
import lombok.NonNull;
1213
import org.bsc.langgraph4j.agentexecutor.*;
1314
import org.bsc.langgraph4j.serializer.plain_text.PlainTextStateSerializer;
15+
import org.bsc.langgraph4j.state.AgentState;
16+
import org.bsc.langgraph4j.state.AgentStateFactory;
1417

1518
import java.io.*;
1619
import java.util.*;
@@ -140,16 +143,16 @@ public AgentExecutor.State deserialize(JsonParser parser, DeserializationContext
140143
}
141144
}
142145

143-
public class JSONStateSerializer extends PlainTextStateSerializer {
146+
public class JSONStateSerializer extends PlainTextStateSerializer<AgentExecutor.State> {
144147

145148
final ObjectMapper objectMapper;
146149

147150
public static JSONStateSerializer of( ObjectMapper objectMapper ) {
148151
return new JSONStateSerializer(objectMapper);
149152
}
150153

151-
private JSONStateSerializer(ObjectMapper objectMapper) {
152-
Objects.requireNonNull(objectMapper, "objectMapper cannot be null");
154+
private JSONStateSerializer( @NonNull ObjectMapper objectMapper) {
155+
super( AgentExecutor.State::new );
153156
this.objectMapper = objectMapper;
154157

155158
var module = new SimpleModule();
@@ -169,16 +172,15 @@ public String mimeType() {
169172
}
170173

171174
@Override
172-
public void write(Map<String,Object> object, ObjectOutput out) throws IOException {
173-
var state = new AgentExecutor.State( object );
174-
var json = objectMapper.writeValueAsString(state);
175+
public void write(AgentExecutor.State object, ObjectOutput out) throws IOException {
176+
var json = objectMapper.writeValueAsString(object);
175177
out.writeUTF(json);
176178
}
177179

178180
@Override
179-
public Map<String,Object> read(ObjectInput in) throws IOException, ClassNotFoundException {
181+
public AgentExecutor.State read(ObjectInput in) throws IOException, ClassNotFoundException {
180182
var json = in.readUTF();
181-
return objectMapper.readValue(json, AgentExecutor.State.class).data();
183+
return objectMapper.readValue(json, AgentExecutor.State.class);
182184
}
183185

184186
}

agent-executor/src/test/java/org/bsc/langgraph4j/agentexecutor/SerializableTest.java

+16-12
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ public void jsonSerializeTest() throws Exception {
3838
var state = serializer.read(data);
3939

4040
assertNotNull(state);
41-
assertEquals("perform test twice", state.get("input") );
42-
assertNotNull(state.get("intermediate_steps") );
43-
assertInstanceOf( List.class, state.get("intermediate_steps") );
44-
var intermediateSteps = (List<IntermediateStep>)state.get("intermediate_steps");
41+
assertTrue(state.input().isPresent());
42+
assertEquals("perform test twice", state.input().get() );
43+
assertNotNull(state.intermediateSteps());
44+
assertInstanceOf( List.class, state.intermediateSteps() );
45+
var intermediateSteps = state.intermediateSteps();
4546
assertTrue(intermediateSteps.isEmpty());
46-
assertInstanceOf( AgentOutcome.class, state.get("agent_outcome") );
47-
var agentOutcome = (AgentOutcome)state.get("agent_outcome");
47+
assertTrue( state.agentOutcome().isPresent());
48+
assertInstanceOf( AgentOutcome.class, state.agentOutcome().get() );
49+
var agentOutcome = state.agentOutcome().get();
4850
assertNotNull(agentOutcome);
4951
var action = agentOutcome.action();
5052
assertNotNull(action);
@@ -89,16 +91,18 @@ public void jsonSerializeTest2() throws Exception {
8991
var state = serializer.read(data);
9092

9193
assertNotNull(state);
92-
assertEquals("perform test another time", state.get("input") );
93-
assertNotNull(state.get("intermediate_steps") );
94-
assertInstanceOf( List.class, state.get("intermediate_steps") );
95-
var intermediateSteps = (List<IntermediateStep>)state.get("intermediate_steps");
94+
assertTrue(state.input().isPresent());
95+
assertEquals("perform test another time", state.input().get() );
96+
assertNotNull(state.intermediateSteps() );
97+
assertInstanceOf( List.class, state.intermediateSteps() );
98+
var intermediateSteps =state.intermediateSteps();
9699
assertEquals(1,intermediateSteps.size());
97100
var intermediateStep = intermediateSteps.get(0);
98101
assertNotNull(intermediateStep);
99102
assertEquals("test tool executed: perform test once", intermediateStep.observation() );
100-
assertInstanceOf( AgentOutcome.class, state.get("agent_outcome") );
101-
var agentOutcome = (AgentOutcome)state.get("agent_outcome");
103+
assertTrue(state.agentOutcome().isPresent());
104+
assertInstanceOf( AgentOutcome.class, state.agentOutcome().get() );
105+
var agentOutcome = state.agentOutcome().get();
102106
assertNotNull(agentOutcome);
103107
var action = agentOutcome.action();
104108
assertNotNull(action);

core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java

+1-4
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,7 @@ Map<String,Object> getInitialState(Map<String,Object> inputs, RunnableConfig con
245245
}
246246

247247
State cloneState( Map<String,Object> data ) throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException {
248-
249-
Map<String,Object> newData = stateGraph.getStateSerializer().cloneObject(data);
250-
251-
return stateGraph.getStateFactory().apply(newData);
248+
return stateGraph.getStateSerializer().cloneObject(data);
252249
}
253250

254251

core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java

+14-15
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33

44
import lombok.Getter;
5+
import lombok.NonNull;
56
import org.bsc.langgraph4j.action.AsyncEdgeAction;
67
import org.bsc.langgraph4j.action.AsyncNodeAction;
78
import org.bsc.langgraph4j.serializer.Serializer;
9+
import org.bsc.langgraph4j.serializer.StateSerializer;
810
import org.bsc.langgraph4j.serializer.std.ObjectStreamStateSerializer;
911
import org.bsc.langgraph4j.state.AgentState;
1012
import org.bsc.langgraph4j.state.AgentStateFactory;
@@ -93,32 +95,26 @@ GraphRunnerException exception(String... args) {
9395
private final Map<String, Channel<?>> channels;
9496

9597
@Getter
96-
private final AgentStateFactory<State> stateFactory;
97-
98-
@Getter
99-
private final Serializer<Map<String,Object>> stateSerializer;
98+
private final StateSerializer<State> stateSerializer;
10099

101100
/**
102101
*
103102
* @param channels the state's schema of the graph
104-
* @param stateFactory the factory to create agent states
105103
* @param stateSerializer the serializer to serialize the state
106104
*/
107105
public StateGraph(Map<String, Channel<?>> channels,
108-
AgentStateFactory<State> stateFactory,
109-
Serializer<Map<String,Object>> stateSerializer) {
106+
StateSerializer<State> stateSerializer) {
110107
this.channels = channels;
111-
this.stateFactory = stateFactory;
112-
this.stateSerializer = ( stateSerializer == null ) ? new ObjectStreamStateSerializer() : stateSerializer;
108+
this.stateSerializer = stateSerializer;
113109
}
114110

115111
/**
116-
* Constructs a new StateGraph with the specified state factory.
112+
* Constructs a new StateGraph with the specified serializer.
117113
*
118-
* @param stateFactory the factory to create agent states
114+
* @param stateSerializer the serializer to serialize the state
119115
*/
120-
public StateGraph(AgentStateFactory<State> stateFactory, Serializer<Map<String,Object>> stateSerializer) {
121-
this( mapOf(), stateFactory, stateSerializer );
116+
public StateGraph(@NonNull StateSerializer<State> stateSerializer) {
117+
this( mapOf(), stateSerializer );
122118

123119
}
124120

@@ -128,7 +124,7 @@ public StateGraph(AgentStateFactory<State> stateFactory, Serializer<Map<String,O
128124
* @param stateFactory the factory to create agent states
129125
*/
130126
public StateGraph(AgentStateFactory<State> stateFactory) {
131-
this( mapOf(), stateFactory, null );
127+
this( mapOf(), stateFactory);
132128

133129
}
134130

@@ -138,9 +134,12 @@ public StateGraph(AgentStateFactory<State> stateFactory) {
138134
* @param stateFactory the factory to create agent states
139135
*/
140136
public StateGraph(Map<String, Channel<?>> channels, AgentStateFactory<State> stateFactory) {
141-
this( channels, stateFactory, null );
137+
this( channels, new ObjectStreamStateSerializer<>(stateFactory) );
142138
}
143139

140+
public final AgentStateFactory<State> getStateFactory() {
141+
return stateSerializer.stateFactory();
142+
}
144143

145144
public Map<String, Channel<?>> getChannels() {
146145
return unmodifiableMap(channels);

core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/Serializer.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ default String mimeType() {
1414

1515
default byte[] writeObject(T object) throws IOException {
1616
Objects.requireNonNull( object, "object cannot be null" );
17-
try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) {
18-
ObjectOutputStream oas = new ObjectOutputStream(baos);
17+
try( ByteArrayOutputStream stream = new ByteArrayOutputStream() ) {
18+
ObjectOutputStream oas = new ObjectOutputStream(stream);
1919
write(object, oas);
2020
oas.flush();
21-
return baos.toByteArray();
21+
return stream.toByteArray();
2222
}
2323
}
2424

@@ -27,8 +27,8 @@ default T readObject(byte[] bytes) throws IOException, ClassNotFoundException {
2727
if( bytes.length == 0 ) {
2828
throw new IllegalArgumentException("bytes cannot be empty");
2929
}
30-
try( ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) {
31-
ObjectInputStream ois = new ObjectInputStream(bais);
30+
try( ByteArrayInputStream stream = new ByteArrayInputStream( bytes ) ) {
31+
ObjectInputStream ois = new ObjectInputStream(stream);
3232
return read(ois);
3333
}
3434
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package org.bsc.langgraph4j.serializer;
2+
3+
import lombok.NonNull;
4+
import org.bsc.langgraph4j.state.AgentState;
5+
import org.bsc.langgraph4j.state.AgentStateFactory;
6+
7+
import java.io.IOException;
8+
import java.util.Map;
9+
import java.util.Objects;
10+
11+
public abstract class StateSerializer<State extends AgentState> implements Serializer<State> {
12+
13+
private final AgentStateFactory<State> stateFactory;
14+
15+
protected StateSerializer( @NonNull AgentStateFactory<State> stateFactory) {
16+
this.stateFactory = stateFactory;
17+
}
18+
19+
public final AgentStateFactory<State> stateFactory() {
20+
return stateFactory;
21+
}
22+
23+
public final State stateOf( @NonNull Map<String,Object> data) {
24+
return stateFactory.apply(data);
25+
}
26+
27+
public final State cloneObject( @NonNull Map<String,Object> data) throws IOException, ClassNotFoundException {
28+
return cloneObject( stateFactory().apply(data) );
29+
}
30+
31+
}

core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/plain_text/PlainTextStateSerializer.java

+12-5
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
11
package org.bsc.langgraph4j.serializer.plain_text;
22

3-
import org.bsc.langgraph4j.serializer.Serializer;
3+
import lombok.NonNull;
4+
import org.bsc.langgraph4j.serializer.StateSerializer;
5+
import org.bsc.langgraph4j.state.AgentState;
6+
import org.bsc.langgraph4j.state.AgentStateFactory;
47

58
import java.io.*;
6-
import java.util.Map;
79

8-
public abstract class PlainTextStateSerializer implements Serializer<Map<String,Object>> {
10+
public abstract class PlainTextStateSerializer<State extends AgentState> extends StateSerializer<State> {
11+
12+
protected PlainTextStateSerializer(@NonNull AgentStateFactory<State> stateFactory) {
13+
super(stateFactory);
14+
}
15+
916
@Override
1017
public String mimeType() {
1118
return "plain/text";
1219
}
1320

14-
public Map<String,Object> read( String data ) throws IOException, ClassNotFoundException {
21+
public State read( String data ) throws IOException, ClassNotFoundException {
1522
ByteArrayOutputStream bytesStream = new ByteArrayOutputStream();
1623

1724
try(ObjectOutputStream out = new ObjectOutputStream( bytesStream )) {
@@ -25,7 +32,7 @@ public Map<String,Object> read( String data ) throws IOException, ClassNotFoundE
2532

2633
}
2734

28-
public Map<String,Object> read( Reader reader ) throws IOException, ClassNotFoundException {
35+
public State read( Reader reader ) throws IOException, ClassNotFoundException {
2936
StringBuilder sb = new StringBuilder();
3037
try (BufferedReader bufferedReader = new BufferedReader(reader)) {
3138
String line;

core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/std/ObjectStreamStateSerializer.java

+10-7
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22

33
import lombok.extern.slf4j.Slf4j;
44
import org.bsc.langgraph4j.serializer.Serializer;
5+
import org.bsc.langgraph4j.serializer.StateSerializer;
6+
import org.bsc.langgraph4j.state.AgentState;
7+
import org.bsc.langgraph4j.state.AgentStateFactory;
58

69
import java.io.IOException;
710
import java.io.ObjectInput;
811
import java.io.ObjectOutput;
912
import java.util.*;
1013

1114
@Slf4j
12-
public class ObjectStreamStateSerializer implements Serializer<Map<String,Object>> {
15+
public class ObjectStreamStateSerializer<State extends AgentState> extends StateSerializer<State> {
1316

1417
static class ListSerializer implements Serializer<List<Object>> {
1518

@@ -92,8 +95,8 @@ public Map<String, Object> read(ObjectInput in) throws IOException, ClassNotFoun
9295
private final SerializerMapper mapper = new SerializerMapper();
9396
private final MapSerializer mapSerializer = new MapSerializer();
9497

95-
public ObjectStreamStateSerializer() {
96-
super();
98+
public ObjectStreamStateSerializer( AgentStateFactory<State> stateFactory ) {
99+
super(stateFactory);
97100
mapper.register( Collection.class, new ListSerializer() );
98101
mapper.register( Map.class, new MapSerializer() );
99102
}
@@ -103,12 +106,12 @@ public SerializerMapper mapper() {
103106
}
104107

105108
@Override
106-
public void write(Map<String, Object> object, ObjectOutput out) throws IOException {
107-
mapSerializer.write(object, mapper.objectOutputWithMapper(out));
109+
public void write(State object, ObjectOutput out) throws IOException {
110+
mapSerializer.write(object.data(), mapper.objectOutputWithMapper(out));
108111
}
109112

110113
@Override
111-
public final Map<String, Object> read(ObjectInput in) throws IOException, ClassNotFoundException {
112-
return Collections.unmodifiableMap(mapSerializer.read( mapper.objectOutputWithMapper(in) ));
114+
public final State read(ObjectInput in) throws IOException, ClassNotFoundException {
115+
return stateFactory().apply(mapSerializer.read( mapper.objectOutputWithMapper(in) ));
113116
}
114117
}

0 commit comments

Comments
 (0)