Skip to content

Commit 05c293f

Browse files
committedAug 26, 2024
feat(BaseCheckpointSaver): add support for ThreadId
work on #20
1 parent 1407b41 commit 05c293f

File tree

2 files changed

+53
-20
lines changed

2 files changed

+53
-20
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
package org.bsc.langgraph4j.checkpoint;
22

3-
import java.io.Externalizable;
3+
import org.bsc.langgraph4j.RunnableConfig;
4+
45
import java.util.Collection;
56
import java.util.Optional;
67

78
public interface BaseCheckpointSaver {
89

9-
10-
Collection<Checkpoint> list();
11-
Optional<Checkpoint> getLast();
12-
void put( Checkpoint checkpoint ) throws Exception;
10+
Collection<Checkpoint> list( RunnableConfig config );
11+
Optional<Checkpoint> get( RunnableConfig config );
12+
RunnableConfig put( RunnableConfig config, Checkpoint checkpoint ) throws Exception;
1313
}
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,72 @@
11
package org.bsc.langgraph4j.checkpoint;
22

3+
import org.bsc.langgraph4j.RunnableConfig;
34
import org.bsc.langgraph4j.serializer.CheckpointSerializer;
45

56
import java.util.*;
7+
import java.util.stream.IntStream;
68

9+
import static java.lang.String.format;
710
import static java.util.Collections.unmodifiableCollection;
8-
import static java.util.Collections.unmodifiableSet;
11+
import static java.util.Optional.ofNullable;
912

1013
public class MemorySaver implements BaseCheckpointSaver {
14+
private final Map<String, LinkedList<Checkpoint>> _checkpointsByThread = new HashMap<>();
15+
private final LinkedList<Checkpoint> _defaultCheckpoints = new LinkedList<>();
1116

12-
private final Stack<Checkpoint> checkpoints = new Stack<>();
17+
public MemorySaver() {
18+
}
1319

20+
private LinkedList<Checkpoint> getCheckpoints( RunnableConfig config ) {
21+
return config.threadId()
22+
.map( threadId -> _checkpointsByThread.computeIfAbsent(threadId, k -> new LinkedList<>()) )
23+
.orElse( _defaultCheckpoints );
24+
}
1425

1526
@Override
16-
public Collection<Checkpoint> list() {
27+
public Collection<Checkpoint> list( RunnableConfig config ) {
28+
final LinkedList<Checkpoint> checkpoints = getCheckpoints(config);
1729
return unmodifiableCollection(checkpoints); // immutable checkpoints;
1830
}
1931

20-
@Override
21-
public Optional<Checkpoint> getLast() {
22-
if( checkpoints.isEmpty() ) {
23-
return Optional.empty();
24-
}
25-
return Optional.ofNullable( checkpoints.peek() );
32+
private Optional<Checkpoint> getLast( RunnableConfig config ) {
33+
final LinkedList<Checkpoint> checkpoints = getCheckpoints(config);
34+
return (checkpoints.isEmpty() ) ? Optional.empty() : ofNullable(checkpoints.peek());
2635
}
2736

28-
public Optional<Checkpoint> get(String id) {
29-
return checkpoints.stream()
30-
.filter( checkpoint -> checkpoint.getId().equals(id) )
31-
.findFirst();
37+
@Override
38+
public Optional<Checkpoint> get(RunnableConfig config) {
39+
final LinkedList<Checkpoint> checkpoints = getCheckpoints(config);
40+
if( config.checkPointId().isPresent() ) {
41+
return config.checkPointId()
42+
.flatMap( id -> checkpoints.stream()
43+
.filter( checkpoint -> checkpoint.getId().equals(id) )
44+
.findFirst());
45+
}
46+
return getLast(config);
3247
}
3348

3449
@Override
35-
public void put(Checkpoint checkpoint) throws Exception {
36-
checkpoints.add( CheckpointSerializer.INSTANCE.cloneObject(checkpoint) );
50+
public RunnableConfig put(RunnableConfig config, Checkpoint checkpoint) throws Exception {
51+
final LinkedList<Checkpoint> checkpoints = getCheckpoints(config);
52+
53+
final Checkpoint clonedCheckpoint = CheckpointSerializer.INSTANCE.cloneObject(checkpoint);
54+
55+
if( config.checkPointId().isPresent() ) { // Replace Checkpoint
56+
String checkPointId = config.checkPointId().get();
57+
int index = IntStream.range(0, checkpoints.size())
58+
.filter(i -> checkpoints.get(i).getId().equals(checkPointId))
59+
.findFirst()
60+
.orElseThrow( () -> (new NoSuchElementException( format("Checkpoint with id %s not found!", checkPointId))) );
61+
checkpoints.set( index, clonedCheckpoint);
62+
return config;
63+
}
64+
65+
checkpoints.push( clonedCheckpoint ); // Add Checkpoint
66+
67+
return RunnableConfig.builder(config)
68+
.checkPointId( clonedCheckpoint.getId() )
69+
.build();
3770
}
3871

3972
}

0 commit comments

Comments
 (0)