|
| 1 | +package org.bsc.langgraph4j.checkpoint; |
| 2 | + |
| 3 | +import lombok.NonNull; |
| 4 | +import org.bsc.langgraph4j.RunnableConfig; |
| 5 | +import org.bsc.langgraph4j.serializer.Serializer; |
| 6 | +import org.bsc.langgraph4j.serializer.StateSerializer; |
| 7 | +import org.bsc.langgraph4j.serializer.std.NullableObjectSerializer; |
| 8 | +import org.bsc.langgraph4j.state.AgentState; |
| 9 | + |
| 10 | +import java.io.*; |
| 11 | +import java.nio.file.Files; |
| 12 | +import java.nio.file.Path; |
| 13 | +import java.nio.file.Paths; |
| 14 | +import java.util.*; |
| 15 | + |
| 16 | +import static java.lang.String.format; |
| 17 | + |
| 18 | +public class FileSystemSaver extends MemorySaver { |
| 19 | + |
| 20 | + private final StateSerializer<AgentState> stateSerializer; |
| 21 | + private final Path targetFolder; |
| 22 | + private final Serializer<Checkpoint> serializer = new NullableObjectSerializer<Checkpoint>() { |
| 23 | + |
| 24 | + @Override |
| 25 | + public void write(Checkpoint object, ObjectOutput out) throws IOException { |
| 26 | + out.writeUTF( object.getId() ); |
| 27 | + writeNullableUTF(object.getNodeId(), out); |
| 28 | + writeNullableUTF(object.getNextNodeId(), out); |
| 29 | + AgentState state = stateSerializer.stateFactory().apply(object.getState()); |
| 30 | + stateSerializer.write( state, out); |
| 31 | + } |
| 32 | + |
| 33 | + @Override |
| 34 | + public Checkpoint read(ObjectInput in) throws IOException, ClassNotFoundException { |
| 35 | + return Checkpoint.builder() |
| 36 | + .id( in.readUTF() ) |
| 37 | + .state( stateSerializer.read(in) ) |
| 38 | + .nextNodeId( readNullableUTF(in).orElse(null) ) |
| 39 | + .nodeId( readNullableUTF(in).orElse(null) ) |
| 40 | + .build(); |
| 41 | + } |
| 42 | + }; |
| 43 | + |
| 44 | + @SuppressWarnings("unchecked") |
| 45 | + public FileSystemSaver( @NonNull Path targetFolder, @NonNull StateSerializer<? extends AgentState> stateSerializer) { |
| 46 | + File targetFolderAsFile = targetFolder.toFile(); |
| 47 | + |
| 48 | + if( targetFolderAsFile.exists() ) { |
| 49 | + if (targetFolderAsFile.isFile()) { |
| 50 | + throw new IllegalArgumentException( format("targetFolder '%s' must be a folder", targetFolder) ); // TODO: format"targetFolder must be a directory"); |
| 51 | + } |
| 52 | + } |
| 53 | + else { |
| 54 | + if( !targetFolderAsFile.mkdirs() ) { |
| 55 | + throw new IllegalArgumentException( format("targetFolder '%s' cannot be created", targetFolder) ); // TODO: format"targetFolder cannot be created"); |
| 56 | + } |
| 57 | + } |
| 58 | + |
| 59 | + this.targetFolder = targetFolder; |
| 60 | + this.stateSerializer = (StateSerializer<AgentState>) stateSerializer; |
| 61 | + } |
| 62 | + |
| 63 | + private File getFile(RunnableConfig config) { |
| 64 | + return config.threadId() |
| 65 | + .map( threadId -> Paths.get( targetFolder.toString(), format( "thread-%s.saver", threadId) ) ) |
| 66 | + .orElseGet( () -> Paths.get( targetFolder.toString(), "thread-default.saver" ) ) |
| 67 | + .toFile(); |
| 68 | + |
| 69 | + } |
| 70 | + private void serialize(@NonNull LinkedList<Checkpoint> checkpoints, @NonNull File outFile ) throws IOException { |
| 71 | + |
| 72 | + try (ObjectOutputStream oos = new ObjectOutputStream(Files.newOutputStream(outFile.toPath())) ) { |
| 73 | + |
| 74 | + oos.writeInt( checkpoints.size() ); |
| 75 | + for(Checkpoint checkpoint : checkpoints) { |
| 76 | + serializer.write(checkpoint, oos); |
| 77 | + } |
| 78 | + } |
| 79 | + } |
| 80 | + |
| 81 | + private void deserialize(@NonNull File file, @NonNull LinkedList<Checkpoint> result) throws IOException, ClassNotFoundException { |
| 82 | + |
| 83 | + try (ObjectInputStream ois = new ObjectInputStream(Files.newInputStream(file.toPath())) ) { |
| 84 | + int size = ois.readInt(); |
| 85 | + for( int i = 0; i < size; i++ ) { |
| 86 | + result.add( serializer.read(ois) ); |
| 87 | + } |
| 88 | + } |
| 89 | + } |
| 90 | + @Override |
| 91 | + protected LinkedList<Checkpoint> getCheckpoints(RunnableConfig config) { |
| 92 | + LinkedList<Checkpoint> result = super.getCheckpoints(config); |
| 93 | + |
| 94 | + File targetFile = getFile(config); |
| 95 | + if( targetFile.exists() && result.isEmpty() ) { |
| 96 | + try { |
| 97 | + deserialize( targetFile, result ); |
| 98 | + } |
| 99 | + catch (IOException | ClassNotFoundException e) { |
| 100 | + throw new RuntimeException(e); |
| 101 | + } |
| 102 | + } |
| 103 | + return result; |
| 104 | + } |
| 105 | + |
| 106 | + @Override |
| 107 | + public Collection<Checkpoint> list(RunnableConfig config) { |
| 108 | + return super.list(config); |
| 109 | + } |
| 110 | + |
| 111 | + @Override |
| 112 | + public Optional<Checkpoint> get(RunnableConfig config) { |
| 113 | + return super.get(config); |
| 114 | + } |
| 115 | + |
| 116 | + @Override |
| 117 | + public RunnableConfig put(RunnableConfig config, Checkpoint checkpoint) throws Exception { |
| 118 | + RunnableConfig result = super.put(config, checkpoint); |
| 119 | + |
| 120 | + File targetFile = getFile(config); |
| 121 | + serialize( super.getCheckpoints(config), targetFile ); |
| 122 | + return result; |
| 123 | + } |
| 124 | +} |
0 commit comments