16
16
import java .util .*;
17
17
import java .util .concurrent .CompletableFuture ;
18
18
import java .util .concurrent .CompletionException ;
19
+ import java .util .function .Function ;
19
20
import java .util .function .Supplier ;
20
21
import java .util .stream .Collectors ;
21
22
import java .util .stream .Stream ;
@@ -56,12 +57,38 @@ protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfi
56
57
this .compileConfig = compileConfig ;
57
58
58
59
for (var n : stateGraph .nodes ) {
59
- var factory = n .actionFactory ();
60
- Objects .requireNonNull (factory , format ("action factory for node id '%s' is null!" , n .id ()) );
61
- nodes .put (n .id (), factory .apply (compileConfig ));
60
+
61
+ if ( n instanceof SubGraphNode <State > subgraphNode ) {
62
+
63
+ var sgWorkflow = subgraphNode .subGraph ();
64
+
65
+ // validate subgraph
66
+ sgWorkflow .validateGraph ();
67
+
68
+ var sgEdgeStart = sgWorkflow .edges .findEdgeBySourceId (START ).orElseThrow ();
69
+
70
+ if ( sgEdgeStart .isParallel () ) {
71
+ throw new UnsupportedOperationException ("subgraph not support start with parallel branches yet!" );
72
+ }
73
+ var edgesWithSubgraphTargetId = stateGraph .edges .findEdgesByTargetId (subgraphNode .id ());
74
+
75
+ for ( var edgeWithSubgraphTargetId : edgesWithSubgraphTargetId ) {
76
+
77
+ edgeWithSubgraphTargetId .withSourceAndTargetIdsUpdated ( subgraphNode ,
78
+ Function .identity (),
79
+ ( id ) -> subgraphNode .formatId (sgEdgeStart .sourceId ()) );
80
+ }
81
+
82
+
83
+ }
84
+ else {
85
+ var factory = n .actionFactory ();
86
+ Objects .requireNonNull (factory , format ("action factory for node id '%s' is null!" , n .id ()));
87
+ nodes .put (n .id (), factory .apply (compileConfig ));
88
+ }
62
89
}
63
90
64
- for ( var e : stateGraph .edges ) {
91
+ for ( var e : stateGraph .edges . elements ) {
65
92
var targets = e .targets ();
66
93
if (targets .size () == 1 ) {
67
94
edges .put (e .sourceId (), targets .get (0 ));
@@ -72,9 +99,9 @@ protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfi
72
99
73
100
var parallelNodeEdges = parallelNodeStream .get ()
74
101
.map ( target -> new Edge <State >(target .id ()))
75
- .filter ( ee -> stateGraph .edges .contains ( ee ) )
76
- .map ( ee -> stateGraph .edges .indexOf ( ee ) )
77
- .map ( index -> stateGraph .edges .get (index ) )
102
+ .filter ( ee -> stateGraph .edges .elements . contains ( ee ) )
103
+ .map ( ee -> stateGraph .edges .elements . indexOf ( ee ) )
104
+ .map ( index -> stateGraph .edges .elements . get (index ) )
78
105
.toList ();
79
106
80
107
var parallelNodeTargets = parallelNodeEdges .stream ()
@@ -99,7 +126,7 @@ protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfi
99
126
.map ( target -> nodes .get (target .id ()) )
100
127
.toList ();
101
128
102
- var parallelNode = Node . parallel ( e .sourceId (), actions , stateGraph .getChannels () );
129
+ var parallelNode = new ParallelNode <> ( e .sourceId (), actions , stateGraph .getChannels () );
103
130
104
131
nodes .put ( parallelNode .id (), parallelNode .actionFactory ().apply (compileConfig ) );
105
132
@@ -194,16 +221,6 @@ public RunnableConfig updateState( RunnableConfig config, Map<String,Object> val
194
221
return updateState (config , values , null );
195
222
}
196
223
197
- @ Deprecated ( forRemoval = true )
198
- public EdgeValue <State > getEntryPoint () {
199
- return stateGraph .getEntryPoint ();
200
- }
201
-
202
- @ Deprecated ( forRemoval = true )
203
- public String getFinishPoint () {
204
- return stateGraph .getFinishPoint ();
205
- }
206
-
207
224
/**
208
225
* Sets the maximum number of iterations for the graph execution.
209
226
*
0 commit comments