|
16 | 16 | import java.util.*;
|
17 | 17 | import java.util.concurrent.CompletableFuture;
|
18 | 18 | import java.util.concurrent.CompletionException;
|
| 19 | +import java.util.function.Supplier; |
19 | 20 | import java.util.stream.Collectors;
|
| 21 | +import java.util.stream.Stream; |
20 | 22 |
|
21 | 23 | import static java.lang.String.format;
|
22 | 24 | import static java.util.concurrent.CompletableFuture.completedFuture;
|
@@ -59,9 +61,55 @@ protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfi
|
59 | 61 | nodes.put(n.id(), factory.apply(compileConfig));
|
60 | 62 | }
|
61 | 63 |
|
62 |
| - stateGraph.edges.forEach(e -> |
63 |
| - edges.put(e.sourceId(), e.target()) |
64 |
| - ); |
| 64 | + for( var e : stateGraph.edges ) { |
| 65 | + var targets = e.targets(); |
| 66 | + if (targets.size() == 1) { |
| 67 | + edges.put(e.sourceId(), targets.get(0)); |
| 68 | + } |
| 69 | + else { |
| 70 | + Supplier<Stream<EdgeValue<State>>> parallelNodeStream = () -> |
| 71 | + targets.stream().filter( target -> nodes.containsKey(target.id()) ); |
| 72 | + |
| 73 | + var parallelNodeEdges = parallelNodeStream.get() |
| 74 | + .map( target -> new Edge<State>(target.id())) |
| 75 | + .filter( ee -> stateGraph.edges.contains( ee ) ) |
| 76 | + .map( ee -> stateGraph.edges.indexOf( ee ) ) |
| 77 | + .map( index -> stateGraph.edges.get(index) ) |
| 78 | + .toList(); |
| 79 | + |
| 80 | + var parallelNodeTargets = parallelNodeEdges.stream() |
| 81 | + .map( ee -> ee.target().id() ) |
| 82 | + .collect(Collectors.toSet()); |
| 83 | + |
| 84 | + if( parallelNodeTargets.size() > 1 ) { |
| 85 | + |
| 86 | + var conditionalEdges = parallelNodeEdges.stream() |
| 87 | + .filter( ee -> ee.target().value() != null ) |
| 88 | + .toList(); |
| 89 | + if(!conditionalEdges.isEmpty()) { |
| 90 | + throw StateGraph.Errors.unsupportedConditionalEdgeOnParallelNode.exception( |
| 91 | + e.sourceId(), |
| 92 | + conditionalEdges.stream().map(Edge::sourceId).toList() ); |
| 93 | + } |
| 94 | + throw StateGraph.Errors.illegalMultipleTargetsOnParallelNode.exception(e.sourceId(), parallelNodeTargets ); |
| 95 | + } |
| 96 | + |
| 97 | + var actions = parallelNodeStream.get() |
| 98 | + .map( target -> nodes.remove(target.id()) ) |
| 99 | + .toList(); |
| 100 | + |
| 101 | + var parallelNode = Node.parallel( e.sourceId(), actions, stateGraph.getChannels() ); |
| 102 | + |
| 103 | + nodes.put( parallelNode.id(), parallelNode.actionFactory().apply(compileConfig) ); |
| 104 | + |
| 105 | + edges.put( e.sourceId(), new EdgeValue<>( parallelNode.id(), null ) ); |
| 106 | + |
| 107 | + edges.put( parallelNode.id(), new EdgeValue<>( parallelNodeTargets.iterator().next(), null )); |
| 108 | + |
| 109 | + } |
| 110 | + |
| 111 | + |
| 112 | + } |
65 | 113 | }
|
66 | 114 |
|
67 | 115 | public Collection<StateSnapshot<State>> getStateHistory( RunnableConfig config ) {
|
@@ -145,12 +193,12 @@ public RunnableConfig updateState( RunnableConfig config, Map<String,Object> val
|
145 | 193 | return updateState(config, values, null);
|
146 | 194 | }
|
147 | 195 |
|
148 |
| - @Deprecated |
| 196 | + @Deprecated( forRemoval = true ) |
149 | 197 | public EdgeValue<State> getEntryPoint() {
|
150 | 198 | return stateGraph.getEntryPoint();
|
151 | 199 | }
|
152 | 200 |
|
153 |
| - @Deprecated |
| 201 | + @Deprecated( forRemoval = true ) |
154 | 202 | public String getFinishPoint() {
|
155 | 203 | return stateGraph.getFinishPoint();
|
156 | 204 | }
|
@@ -203,7 +251,8 @@ private String nextNodeId(String nodeId, Map<String,Object> state) throws Except
|
203 | 251 | }
|
204 | 252 |
|
205 | 253 | private String getEntryPoint( Map<String,Object> state ) throws Exception {
|
206 |
| - return nextNodeId(stateGraph.getEntryPoint(), state, "entryPoint"); |
| 254 | + var entryPoint = this.edges.get(START); |
| 255 | + return nextNodeId(entryPoint, state, "entryPoint"); |
207 | 256 | }
|
208 | 257 |
|
209 | 258 | private boolean shouldInterruptBefore(@NonNull String nodeId, String previousNodeId ) {
|
@@ -557,6 +606,5 @@ public Data<Output> next() {
|
557 | 606 | }
|
558 | 607 | }
|
559 | 608 |
|
560 |
| - |
561 |
| - |
562 | 609 | }
|
| 610 | + |
0 commit comments