1
1
package org .bsc .langgraph4j ;
2
2
3
+ import lombok .NonNull ;
3
4
import org .bsc .langgraph4j .state .AgentState ;
4
5
5
- import java .util .Objects ;
6
+ import java .util .*;
7
+ import java .util .function .Function ;
8
+ import java .util .stream .Collectors ;
9
+
10
+ import static java .lang .String .format ;
11
+ import static org .bsc .langgraph4j .StateGraph .START ;
6
12
7
13
/**
8
14
* Represents an edge in a graph with a source ID and a target value.
9
15
*
10
16
* @param <State> the type of the state associated with the edge
11
17
* @param sourceId The ID of the source node.
12
- * @param target The target value associated with the edge.
18
+ * @param targets The targets value associated with the edge.
13
19
*/
14
- record Edge <State extends AgentState >(String sourceId , EdgeValue <State > target ) {
20
+ record Edge <State extends AgentState >(String sourceId , List <EdgeValue <State >> targets ) {
21
+
22
+ public Edge (String sourceId , EdgeValue <State > target ) {
23
+ this (sourceId , List .of (target ));
24
+ }
25
+
26
+ public Edge (String id ) {
27
+ this (id , List .of ());
28
+ }
29
+
30
+ public boolean isParallel () {
31
+ return targets .size () > 1 ;
32
+ }
33
+
34
+ public EdgeValue <State > target () {
35
+ if ( isParallel () ) {
36
+ throw new IllegalStateException ( format ("Edge '%s' is parallel" , sourceId ));
37
+ }
38
+ return targets .get (0 );
39
+ }
40
+
41
+ public void validate ( @ NonNull Collection <Node <State >> nodes ) throws GraphStateException {
42
+ if ( !Objects .equals (sourceId (),START ) && !nodes .contains (new Node <State >(sourceId ()))) {
43
+ throw StateGraph .Errors .missingNodeReferencedByEdge .exception (sourceId ());
44
+ }
45
+
46
+ if ( isParallel () ) { // check for duplicates targets
47
+ Set <String > duplicates = targets .stream ()
48
+ .collect (Collectors .groupingBy (EdgeValue ::id , Collectors .counting ())) // Group by element and count occurrences
49
+ .entrySet ()
50
+ .stream ()
51
+ .filter (entry -> entry .getValue () > 1 ) // Filter elements with more than one occurrence
52
+ .map (Map .Entry ::getKey )
53
+ .collect (Collectors .toSet ());
54
+ if ( !duplicates .isEmpty () ) {
55
+ throw StateGraph .Errors .duplicateEdgeTargetError .exception (sourceId (), duplicates );
56
+ }
57
+ }
58
+
59
+ for ( EdgeValue <State > target : targets ) {
60
+ validate (target , nodes );
61
+ }
62
+
63
+ }
64
+
65
+ private void validate ( EdgeValue <State > target , Collection <Node <State >> nodes ) throws GraphStateException {
66
+ if (target .id () != null ) {
67
+ if (!Objects .equals (target .id (), StateGraph .END ) && !nodes .contains (new Node <State >(target .id ()))) {
68
+ throw StateGraph .Errors .missingNodeReferencedByEdge .exception (target .id ());
69
+ }
70
+ } else if (target .value () != null ) {
71
+ for (String nodeId : target .value ().mappings ().values ()) {
72
+ if (!Objects .equals (nodeId , StateGraph .END ) && !nodes .contains (new Node <State >(nodeId ))) {
73
+ throw StateGraph .Errors .missingNodeInEdgeMapping .exception (sourceId (), nodeId );
74
+ }
75
+ }
76
+ } else {
77
+ throw StateGraph .Errors .invalidEdgeTarget .exception (sourceId ());
78
+ }
79
+
80
+ }
15
81
16
82
/**
17
83
* Checks if this edge is equal to another object.
@@ -38,3 +104,4 @@ public int hashCode() {
38
104
}
39
105
40
106
}
107
+
0 commit comments