Skip to content

Commit 89ac2d3

Browse files
committed
refactor(Edge): refactor edge representation to support multiple targets
work on #72
1 parent d50f56c commit 89ac2d3

File tree

1 file changed

+70
-3
lines changed
  • core/src/main/java/org/bsc/langgraph4j

1 file changed

+70
-3
lines changed

core/src/main/java/org/bsc/langgraph4j/Edge.java

+70-3
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,83 @@
11
package org.bsc.langgraph4j;
22

3+
import lombok.NonNull;
34
import org.bsc.langgraph4j.state.AgentState;
45

5-
import java.util.Objects;
6+
import java.util.*;
7+
import java.util.function.Function;
8+
import java.util.stream.Collectors;
9+
10+
import static java.lang.String.format;
11+
import static org.bsc.langgraph4j.StateGraph.START;
612

713
/**
814
* Represents an edge in a graph with a source ID and a target value.
915
*
1016
* @param <State> the type of the state associated with the edge
1117
* @param sourceId The ID of the source node.
12-
* @param target The target value associated with the edge.
18+
* @param targets The targets value associated with the edge.
1319
*/
14-
record Edge<State extends AgentState>(String sourceId, EdgeValue<State> target) {
20+
record Edge<State extends AgentState>(String sourceId, List<EdgeValue<State>> targets) {
21+
22+
public Edge(String sourceId, EdgeValue<State> target) {
23+
this(sourceId, List.of(target));
24+
}
25+
26+
public Edge(String id) {
27+
this(id, List.of());
28+
}
29+
30+
public boolean isParallel() {
31+
return targets.size() > 1;
32+
}
33+
34+
public EdgeValue<State> target() {
35+
if( isParallel() ) {
36+
throw new IllegalStateException( format("Edge '%s' is parallel", sourceId));
37+
}
38+
return targets.get(0);
39+
}
40+
41+
public void validate( @NonNull Collection<Node<State>> nodes) throws GraphStateException {
42+
if ( !Objects.equals(sourceId(),START) && !nodes.contains(new Node<State>(sourceId()))) {
43+
throw StateGraph.Errors.missingNodeReferencedByEdge.exception(sourceId());
44+
}
45+
46+
if( isParallel() ) { // check for duplicates targets
47+
Set<String> duplicates = targets.stream()
48+
.collect(Collectors.groupingBy(EdgeValue::id, Collectors.counting())) // Group by element and count occurrences
49+
.entrySet()
50+
.stream()
51+
.filter(entry -> entry.getValue() > 1) // Filter elements with more than one occurrence
52+
.map(Map.Entry::getKey)
53+
.collect(Collectors.toSet());
54+
if( !duplicates.isEmpty() ) {
55+
throw StateGraph.Errors.duplicateEdgeTargetError.exception(sourceId(), duplicates);
56+
}
57+
}
58+
59+
for( EdgeValue<State> target : targets ) {
60+
validate(target, nodes);
61+
}
62+
63+
}
64+
65+
private void validate( EdgeValue<State> target, Collection<Node<State>> nodes ) throws GraphStateException {
66+
if (target.id() != null) {
67+
if (!Objects.equals(target.id(), StateGraph.END) && !nodes.contains(new Node<State>(target.id()))) {
68+
throw StateGraph.Errors.missingNodeReferencedByEdge.exception(target.id());
69+
}
70+
} else if (target.value() != null) {
71+
for (String nodeId : target.value().mappings().values()) {
72+
if (!Objects.equals(nodeId, StateGraph.END) && !nodes.contains(new Node<State>(nodeId))) {
73+
throw StateGraph.Errors.missingNodeInEdgeMapping.exception(sourceId(), nodeId);
74+
}
75+
}
76+
} else {
77+
throw StateGraph.Errors.invalidEdgeTarget.exception(sourceId());
78+
}
79+
80+
}
1581

1682
/**
1783
* Checks if this edge is equal to another object.
@@ -38,3 +104,4 @@ public int hashCode() {
38104
}
39105

40106
}
107+

0 commit comments

Comments
 (0)