Skip to content

Commit 80dff74

Browse files
committed
test(subgraph): add more subgraph tests
work on #73
1 parent dd24149 commit 80dff74

File tree

1 file changed

+153
-27
lines changed

1 file changed

+153
-27
lines changed

core/src/test/java/org/bsc/langgraph4j/SubGraphTest.java

+153-27
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import java.util.*;
1313
import java.util.logging.LogManager;
1414

15+
import static java.lang.String.format;
1516
import static org.bsc.langgraph4j.StateGraph.END;
1617
import static org.bsc.langgraph4j.StateGraph.START;
1718
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
@@ -58,9 +59,30 @@ public void testMergeSubgraph01() throws Exception {
5859
;
5960

6061
var processed = StateGraphNodesAndEdges.process( workflowParent );
62+
processed.nodes().elements.forEach( System.out::println );
63+
processed.edges().elements.forEach( System.out::println );
6164

6265
assertEquals( 5, processed.edges().elements.size() );
6366
assertEquals( 4, processed.nodes().elements.size() );
67+
68+
var B_B1 = SubGraphNode.formatId( "B", "B1");
69+
var B_B2 = SubGraphNode.formatId( "B", "B2");
70+
71+
List.of(
72+
"Node(A,action)",
73+
"Node(C,action)",
74+
format("Node(%s,action)", B_B1 ),
75+
format("Node(%s,action)", B_B2 )
76+
).forEach( n -> assertTrue( processed.nodes().elements.stream().anyMatch(n1 -> Objects.equals(n,n1.toString())) ));
77+
78+
List.of(
79+
"Edge[sourceId=__START__, targets=[EdgeValue[id=A, value=null]]]",
80+
"Edge[sourceId=C, targets=[EdgeValue[id=__END__, value=null]]]",
81+
format("Edge[sourceId=A, targets=[EdgeValue[id=%s, value=null]]]", B_B1),
82+
format("Edge[sourceId=%s, targets=[EdgeValue[id=C, value=null]]]", B_B2 ),
83+
format("Edge[sourceId=%s, targets=[EdgeValue[id=%s, value=null]]]", B_B1, B_B2)
84+
).forEach( e -> assertTrue( processed.edges().elements.stream().anyMatch( e1 -> Objects.equals(e,e1.toString()) ) ) );
85+
6486
}
6587

6688
@Test
@@ -88,52 +110,150 @@ public void testMergeSubgraph02() throws Exception {
88110
;
89111

90112
var processed = StateGraphNodesAndEdges.process( workflowParent );
113+
processed.nodes().elements.forEach( System.out::println );
114+
processed.edges().elements.forEach( System.out::println );
91115

92-
assertEquals( 5, processed.edges().elements.size() );
93-
var startEdge = processed.edges().edgeBySourceId(START);
94-
assertTrue( startEdge.isPresent() );
95-
assertTrue(startEdge.get().target().value().mappings().containsValue("B@B1"),
96-
"conditional edges 'START' doesn't contain 'B@B1'");
97116
assertEquals( 4, processed.nodes().elements.size() );
117+
assertEquals( 5, processed.edges().elements.size() );
98118

119+
var B_B1 = SubGraphNode.formatId( "B", "B1");
120+
var B_B2 = SubGraphNode.formatId( "B", "B2");
121+
122+
List.of(
123+
"Node(A,action)",
124+
"Node(C,action)",
125+
format("Node(%s,action)", B_B1 ),
126+
format("Node(%s,action)", B_B2 )
127+
).forEach( n -> assertTrue( processed.nodes().elements.stream().anyMatch(n1 -> Objects.equals(n,n1.toString())) ));
128+
129+
List.of(
130+
format("Edge[sourceId=__START__, targets=[EdgeValue[id=null, value=EdgeCondition[ action, mapping={a=A, b=%s}]]]", B_B1),
131+
"Edge[sourceId=C, targets=[EdgeValue[id=__END__, value=null]]]",
132+
format("Edge[sourceId=A, targets=[EdgeValue[id=%s, value=null]]]", B_B1),
133+
format("Edge[sourceId=%s, targets=[EdgeValue[id=C, value=null]]]", B_B2 ),
134+
format("Edge[sourceId=%s, targets=[EdgeValue[id=%s, value=null]]]", B_B1, B_B2)
135+
).forEach( e -> assertTrue( processed.edges().elements.stream().anyMatch( e1 -> Objects.equals(e,e1.toString()) ) ) );
99136

100137
var app = workflowParent.compile();
101-
for( var output : app.stream( Map.of()) ) {
102138

103-
System.out.println( output );
104-
};
139+
var output = app.stream(Map.of())
140+
.stream()
141+
.peek(System.out::println)
142+
.map(NodeOutput::node)
143+
.toList();
144+
145+
assertIterableEquals( List.of(
146+
START,
147+
"A",
148+
B_B1,
149+
B_B2,
150+
"C",
151+
END
152+
), output );
153+
154+
}
155+
156+
@Test
157+
public void testMergeSubgraph03() throws Exception {
158+
159+
var workflowChild = new MessagesStateGraph<String>()
160+
.addNode("B1", makeNode("B1") )
161+
.addNode("B2", makeNode( "B2" ) )
162+
.addNode("C", makeNode( "subgraph(C)" ) )
163+
.addEdge(START, "B1")
164+
.addEdge("B1", "B2")
165+
.addConditionalEdges( "B2",
166+
edge_async(state -> "c"),
167+
Map.of( END, END, "c", "C") )
168+
.addEdge("C", END)
169+
;
170+
171+
var workflowParent = new MessagesStateGraph<String>()
172+
.addNode("A", makeNode("A") )
173+
.addSubgraph("B", workflowChild )
174+
.addNode("C", makeNode("C") )
175+
.addConditionalEdges(START,
176+
edge_async(state -> "a"),
177+
Map.of( "a", "A", "b", "B") )
178+
.addEdge("A", "B")
179+
.addEdge("B", "C")
180+
.addEdge("C", END)
181+
//.compile(compileConfig)
182+
;
183+
184+
var processed = StateGraphNodesAndEdges.process( workflowParent );
185+
processed.nodes().elements.forEach( System.out::println );
186+
processed.edges().elements.forEach( System.out::println );
187+
188+
assertEquals( 5, processed.nodes().elements.size() );
189+
assertEquals( 6, processed.edges().elements.size() );
190+
191+
var B_B1 = SubGraphNode.formatId( "B", "B1");
192+
var B_B2 = SubGraphNode.formatId( "B", "B2");
193+
var B_C = SubGraphNode.formatId( "B", "C");
194+
195+
List.of(
196+
"Node(A,action)",
197+
"Node(C,action)",
198+
format("Node(%s,action)", B_B1 ),
199+
format("Node(%s,action)", B_B2 ),
200+
format("Node(%s,action)", B_C)
201+
).forEach( n -> assertTrue( processed.nodes().elements.stream()
202+
.anyMatch(n1 -> Objects.equals(n, n1.toString())),
203+
format( "node \"%s\" doesn't have a match!", n ) ));
204+
205+
206+
List.of(
207+
format("Edge[sourceId=__START__, targets=[EdgeValue[id=null, value=EdgeCondition[ action, mapping={a=A, b=%s}]]]", B_B1),
208+
"Edge[sourceId=C, targets=[EdgeValue[id=__END__, value=null]]]",
209+
format("Edge[sourceId=A, targets=[EdgeValue[id=%s, value=null]]]", B_B1),
210+
format("Edge[sourceId=%s, targets=[EdgeValue[id=null, value=EdgeCondition[ action, mapping={c=%s, __END__=C}]]]", B_B2, B_C ),
211+
format("Edge[sourceId=%s, targets=[EdgeValue[id=C, value=null]]]", B_C ),
212+
format("Edge[sourceId=%s, targets=[EdgeValue[id=%s, value=null]]]", B_B1, B_B2)
213+
).forEach( e -> assertTrue( processed.edges().elements.stream()
214+
.anyMatch( e1 -> Objects.equals(e,e1.toString()) ) ,
215+
format( "edge \"%s\" doesn't have a match!", e ) ) );
105216

217+
var app = workflowParent.compile();
218+
219+
var output = app.stream(Map.of())
220+
.stream()
221+
.peek(System.out::println)
222+
.map(NodeOutput::node)
223+
.toList();
224+
225+
assertIterableEquals( List.of(
226+
START,
227+
"A",
228+
B_B1,
229+
B_B2,
230+
B_C,
231+
"C",
232+
END
233+
), output );
106234

107235
}
108236

109237
@Test
110238
public void testCheckpointWithSubgraph() throws Exception {
111239

112-
var childStep1 = makeNode("child:step1");
113-
var childStep2 = makeNode("child:step2");
114-
var childStep3 = makeNode("child:step3");
115-
116240
var compileConfig = CompileConfig.builder().checkpointSaver(new MemorySaver()).build();
117241

118242
var workflowChild = new MessagesStateGraph<String>()
119-
.addNode("child:step_1", childStep1 )
120-
.addNode("child:step_2", childStep2 )
121-
.addNode("child:step_3", childStep3 )
122-
.addEdge(START, "child:step_1")
123-
.addEdge("child:step_1", "child:step_2")
124-
.addEdge("child:step_2", "child:step_3")
125-
.addEdge("child:step_3", END)
243+
.addNode("step_1", makeNode("child:step1") )
244+
.addNode("step_2", makeNode("child:step2") )
245+
.addNode("step_3", makeNode("child:step3") )
246+
.addEdge(START, "step_1")
247+
.addEdge("step_1", "step_2")
248+
.addEdge("step_2", "step_3")
249+
.addEdge("step_3", END)
126250
//.compile(compileConfig)
127251
;
128252

129-
var step1 = makeNode( "step1");
130-
var step2 = makeNode("step2");
131-
var step3 = makeNode("step3");
132-
133253
var workflowParent = new MessagesStateGraph<String>()
134-
.addNode("step_1", step1)
135-
.addNode("step_2", step2)
136-
.addNode("step_3", step3)
254+
.addNode("step_1", makeNode( "step1"))
255+
.addNode("step_2", makeNode("step2"))
256+
.addNode("step_3", makeNode("step3"))
137257
.addSubgraph("subgraph", workflowChild)
138258
.addEdge(START, "step_1")
139259
.addEdge("step_1", "step_2")
@@ -150,7 +270,13 @@ public void testCheckpointWithSubgraph() throws Exception {
150270
.map(NodeOutput::state);
151271

152272
assertTrue(result.isPresent());
153-
assertIterableEquals(List.of("step1", "step2", "child:step1", "child:step2", "child:step3", "step3"), result.get().messages());
273+
assertIterableEquals(List.of(
274+
"step1",
275+
"step2",
276+
"child:step1",
277+
"child:step2",
278+
"child:step3",
279+
"step3"), result.get().messages());
154280

155281
}
156282

0 commit comments

Comments
 (0)