Skip to content

Commit 1e58877

Browse files
committed
feat(CompiledGraph.java): Refactor CompiledGraph to prepare for support of sub-graphs merge
- Removed deprecated methods `getEntryPoint` and `getFinishPoint` work on #73
1 parent 68ae2df commit 1e58877

File tree

1 file changed

+35
-18
lines changed

1 file changed

+35
-18
lines changed

core/src/main/java/org/bsc/langgraph4j/CompiledGraph.java

+35-18
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import java.util.*;
1717
import java.util.concurrent.CompletableFuture;
1818
import java.util.concurrent.CompletionException;
19+
import java.util.function.Function;
1920
import java.util.function.Supplier;
2021
import java.util.stream.Collectors;
2122
import java.util.stream.Stream;
@@ -56,12 +57,38 @@ protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfi
5657
this.compileConfig = compileConfig;
5758

5859
for (var n : stateGraph.nodes) {
59-
var factory = n.actionFactory();
60-
Objects.requireNonNull(factory, format("action factory for node id '%s' is null!", n.id()) );
61-
nodes.put(n.id(), factory.apply(compileConfig));
60+
61+
if( n instanceof SubGraphNode<State> subgraphNode ) {
62+
63+
var sgWorkflow = subgraphNode.subGraph();
64+
65+
// validate subgraph
66+
sgWorkflow.validateGraph();
67+
68+
var sgEdgeStart = sgWorkflow.edges.findEdgeBySourceId(START).orElseThrow();
69+
70+
if( sgEdgeStart.isParallel() ) {
71+
throw new UnsupportedOperationException("subgraph not support start with parallel branches yet!");
72+
}
73+
var edgesWithSubgraphTargetId = stateGraph.edges.findEdgesByTargetId(subgraphNode.id());
74+
75+
for( var edgeWithSubgraphTargetId : edgesWithSubgraphTargetId ) {
76+
77+
edgeWithSubgraphTargetId.withSourceAndTargetIdsUpdated( subgraphNode,
78+
Function.identity(),
79+
( id ) -> subgraphNode.formatId(sgEdgeStart.sourceId()) );
80+
}
81+
82+
83+
}
84+
else {
85+
var factory = n.actionFactory();
86+
Objects.requireNonNull(factory, format("action factory for node id '%s' is null!", n.id()));
87+
nodes.put(n.id(), factory.apply(compileConfig));
88+
}
6289
}
6390

64-
for( var e : stateGraph.edges ) {
91+
for( var e : stateGraph.edges.elements ) {
6592
var targets = e.targets();
6693
if (targets.size() == 1) {
6794
edges.put(e.sourceId(), targets.get(0));
@@ -72,9 +99,9 @@ protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfi
7299

73100
var parallelNodeEdges = parallelNodeStream.get()
74101
.map( target -> new Edge<State>(target.id()))
75-
.filter( ee -> stateGraph.edges.contains( ee ) )
76-
.map( ee -> stateGraph.edges.indexOf( ee ) )
77-
.map( index -> stateGraph.edges.get(index) )
102+
.filter( ee -> stateGraph.edges.elements.contains( ee ) )
103+
.map( ee -> stateGraph.edges.elements.indexOf( ee ) )
104+
.map( index -> stateGraph.edges.elements.get(index) )
78105
.toList();
79106

80107
var parallelNodeTargets = parallelNodeEdges.stream()
@@ -99,7 +126,7 @@ protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfi
99126
.map( target -> nodes.get(target.id()) )
100127
.toList();
101128

102-
var parallelNode = Node.parallel( e.sourceId(), actions, stateGraph.getChannels() );
129+
var parallelNode = new ParallelNode<>( e.sourceId(), actions, stateGraph.getChannels() );
103130

104131
nodes.put( parallelNode.id(), parallelNode.actionFactory().apply(compileConfig) );
105132

@@ -194,16 +221,6 @@ public RunnableConfig updateState( RunnableConfig config, Map<String,Object> val
194221
return updateState(config, values, null);
195222
}
196223

197-
@Deprecated( forRemoval = true )
198-
public EdgeValue<State> getEntryPoint() {
199-
return stateGraph.getEntryPoint();
200-
}
201-
202-
@Deprecated( forRemoval = true )
203-
public String getFinishPoint() {
204-
return stateGraph.getFinishPoint();
205-
}
206-
207224
/**
208225
* Sets the maximum number of iterations for the graph execution.
209226
*

0 commit comments

Comments
 (0)