Skip to content

Commit d4a5dd2

Browse files
committed
feat(AsyncNodeGenerator): add 'resumedFromEmbed' state management in AsyncNodeGenerator
work on #31
1 parent d39eef9 commit d4a5dd2

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java

+17-6
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ public class AsyncNodeGenerator<Output extends NodeOutput<State>> implements Asy
372372
String nextNodeId;
373373
int iteration = 0;
374374
RunnableConfig config;
375+
boolean resumedFromEmbed = false;
375376

376377
protected AsyncNodeGenerator(Map<String,Object> inputs, RunnableConfig config ) throws Exception {
377378
final boolean isResumeRequest = (inputs == null);
@@ -431,6 +432,7 @@ private Optional<Data<Output>> getEmbedGenerator( Map<String,Object> partialStat
431432
}
432433
currentState = AgentState.updateState(currentState, (Map<String, Object>)data, stateGraph.getChannels());
433434
nextNodeId = nextNodeId(currentNodeId, currentState);
435+
resumedFromEmbed = true;
434436
})
435437
)
436438
;
@@ -449,12 +451,7 @@ private CompletableFuture<Data<Output>> evaluateAction(AsyncNodeAction<State> ac
449451
currentState = AgentState.updateState(currentState, partialState, stateGraph.getChannels());
450452
nextNodeId = nextNodeId(currentNodeId, currentState);
451453

452-
Optional<Checkpoint> cp = addCheckpoint(config, currentNodeId, currentState, nextNodeId);
453-
CompletableFuture<Output> future = completedFuture(( cp.isPresent() && config.streamMode() == StreamMode.SNAPSHOTS) ?
454-
buildStateSnapshot(cp.get()) :
455-
buildNodeOutput( currentNodeId ))
456-
;
457-
return Data.of( future );
454+
return Data.of( getNodeOutput() );
458455
}
459456
catch (Exception e) {
460457
throw new CompletionException(e);
@@ -463,6 +460,14 @@ private CompletableFuture<Data<Output>> evaluateAction(AsyncNodeAction<State> ac
463460
});
464461
}
465462

463+
private CompletableFuture<Output> getNodeOutput() throws Exception {
464+
Optional<Checkpoint> cp = addCheckpoint(config, currentNodeId, currentState, nextNodeId);
465+
return completedFuture(( cp.isPresent() && config.streamMode() == StreamMode.SNAPSHOTS) ?
466+
buildStateSnapshot(cp.get()) :
467+
buildNodeOutput( currentNodeId ))
468+
;
469+
}
470+
466471
@Override
467472
public Data<Output> next() {
468473
// GUARD: CHECK MAX ITERATION REACHED
@@ -475,6 +480,12 @@ public Data<Output> next() {
475480
if( nextNodeId == null && currentNodeId == null ) return Data.done();
476481

477482
try {
483+
// IS IT A RESUME FROM EMBED ?
484+
if(resumedFromEmbed) {
485+
final CompletableFuture<Output> future = getNodeOutput();
486+
resumedFromEmbed = false;
487+
return Data.of( future );
488+
}
478489

479490
if( START.equals(currentNodeId) ) {
480491
nextNodeId = getEntryPoint( currentState );

0 commit comments

Comments
 (0)