Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 5d8036d

Browse files
committedOct 28, 2024
feat: add file system checkpoint saver
work on #35
1 parent 3cb700a commit 5d8036d

File tree

4 files changed

+520
-3
lines changed

4 files changed

+520
-3
lines changed
 

‎core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/Checkpoint.java

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import org.bsc.langgraph4j.state.AgentState;
55
import org.bsc.langgraph4j.state.Channel;
66

7-
import java.io.Externalizable;
87
import java.util.*;
98

109
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package org.bsc.langgraph4j.checkpoint;
2+
3+
import lombok.NonNull;
4+
import org.bsc.langgraph4j.RunnableConfig;
5+
import org.bsc.langgraph4j.serializer.Serializer;
6+
import org.bsc.langgraph4j.serializer.StateSerializer;
7+
import org.bsc.langgraph4j.serializer.std.NullableObjectSerializer;
8+
import org.bsc.langgraph4j.state.AgentState;
9+
10+
import java.io.*;
11+
import java.nio.file.Files;
12+
import java.nio.file.Path;
13+
import java.nio.file.Paths;
14+
import java.util.*;
15+
16+
import static java.lang.String.format;
17+
18+
public class FileSystemSaver extends MemorySaver {
19+
20+
private final StateSerializer<AgentState> stateSerializer;
21+
private final Path targetFolder;
22+
private final Serializer<Checkpoint> serializer = new NullableObjectSerializer<Checkpoint>() {
23+
24+
@Override
25+
public void write(Checkpoint object, ObjectOutput out) throws IOException {
26+
out.writeUTF( object.getId() );
27+
writeNullableUTF(object.getNodeId(), out);
28+
writeNullableUTF(object.getNextNodeId(), out);
29+
AgentState state = stateSerializer.stateFactory().apply(object.getState());
30+
stateSerializer.write( state, out);
31+
}
32+
33+
@Override
34+
public Checkpoint read(ObjectInput in) throws IOException, ClassNotFoundException {
35+
return Checkpoint.builder()
36+
.id( in.readUTF() )
37+
.state( stateSerializer.read(in) )
38+
.nextNodeId( readNullableUTF(in).orElse(null) )
39+
.nodeId( readNullableUTF(in).orElse(null) )
40+
.build();
41+
}
42+
};
43+
44+
@SuppressWarnings("unchecked")
45+
public FileSystemSaver( @NonNull Path targetFolder, @NonNull StateSerializer<? extends AgentState> stateSerializer) {
46+
File targetFolderAsFile = targetFolder.toFile();
47+
48+
if( targetFolderAsFile.exists() ) {
49+
if (targetFolderAsFile.isFile()) {
50+
throw new IllegalArgumentException( format("targetFolder '%s' must be a folder", targetFolder) ); // TODO: format"targetFolder must be a directory");
51+
}
52+
}
53+
else {
54+
if( !targetFolderAsFile.mkdirs() ) {
55+
throw new IllegalArgumentException( format("targetFolder '%s' cannot be created", targetFolder) ); // TODO: format"targetFolder cannot be created");
56+
}
57+
}
58+
59+
this.targetFolder = targetFolder;
60+
this.stateSerializer = (StateSerializer<AgentState>) stateSerializer;
61+
}
62+
63+
private File getFile(RunnableConfig config) {
64+
return config.threadId()
65+
.map( threadId -> Paths.get( targetFolder.toString(), format( "thread-%s.saver", threadId) ) )
66+
.orElseGet( () -> Paths.get( targetFolder.toString(), "thread-default.saver" ) )
67+
.toFile();
68+
69+
}
70+
private void serialize(@NonNull LinkedList<Checkpoint> checkpoints, @NonNull File outFile ) throws IOException {
71+
72+
try (ObjectOutputStream oos = new ObjectOutputStream(Files.newOutputStream(outFile.toPath())) ) {
73+
74+
oos.writeInt( checkpoints.size() );
75+
for(Checkpoint checkpoint : checkpoints) {
76+
serializer.write(checkpoint, oos);
77+
}
78+
}
79+
}
80+
81+
private void deserialize(@NonNull File file, @NonNull LinkedList<Checkpoint> result) throws IOException, ClassNotFoundException {
82+
83+
try (ObjectInputStream ois = new ObjectInputStream(Files.newInputStream(file.toPath())) ) {
84+
int size = ois.readInt();
85+
for( int i = 0; i < size; i++ ) {
86+
result.add( serializer.read(ois) );
87+
}
88+
}
89+
}
90+
@Override
91+
protected LinkedList<Checkpoint> getCheckpoints(RunnableConfig config) {
92+
LinkedList<Checkpoint> result = super.getCheckpoints(config);
93+
94+
File targetFile = getFile(config);
95+
if( targetFile.exists() && result.isEmpty() ) {
96+
try {
97+
deserialize( targetFile, result );
98+
}
99+
catch (IOException | ClassNotFoundException e) {
100+
throw new RuntimeException(e);
101+
}
102+
}
103+
return result;
104+
}
105+
106+
@Override
107+
public Collection<Checkpoint> list(RunnableConfig config) {
108+
return super.list(config);
109+
}
110+
111+
@Override
112+
public Optional<Checkpoint> get(RunnableConfig config) {
113+
return super.get(config);
114+
}
115+
116+
@Override
117+
public RunnableConfig put(RunnableConfig config, Checkpoint checkpoint) throws Exception {
118+
RunnableConfig result = super.put(config, checkpoint);
119+
120+
File targetFile = getFile(config);
121+
serialize( super.getCheckpoints(config), targetFile );
122+
return result;
123+
}
124+
}

‎core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/MemorySaver.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public class MemorySaver implements BaseCheckpointSaver {
2121
public MemorySaver() {
2222
}
2323

24-
private LinkedList<Checkpoint> getCheckpoints( RunnableConfig config ) {
24+
protected LinkedList<Checkpoint> getCheckpoints( RunnableConfig config ) {
2525
return config.threadId()
2626
.map( threadId -> _checkpointsByThread.computeIfAbsent(threadId, k -> new LinkedList<>()) )
2727
.orElse( _defaultCheckpoints );
@@ -38,7 +38,7 @@ public Collection<Checkpoint> list( RunnableConfig config ) {
3838
}
3939
}
4040

41-
private Optional<Checkpoint> getLast( LinkedList<Checkpoint> checkpoints, RunnableConfig config ) {
41+
protected Optional<Checkpoint> getLast( LinkedList<Checkpoint> checkpoints, RunnableConfig config ) {
4242
return (checkpoints.isEmpty() ) ? Optional.empty() : ofNullable(checkpoints.peek());
4343
}
4444

There was a problem loading the remainder of the diff.

0 commit comments

Comments
 (0)
Failed to load comments.