Skip to content

Commit 38a21cc

Browse files
committed
test: add new test cases for subgraph merge
- Added new test cases to cover different scenarios in merging subgraphs. work on #73
1 parent 14b1e51 commit 38a21cc

File tree

1 file changed

+193
-76
lines changed

1 file changed

+193
-76
lines changed

core/src/test/java/org/bsc/langgraph4j/SubGraphTest.java

+193-76
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import java.util.*;
1313
import java.util.logging.LogManager;
1414

15-
import static java.lang.String.format;
1615
import static org.bsc.langgraph4j.StateGraph.END;
1716
import static org.bsc.langgraph4j.StateGraph.START;
1817
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
@@ -65,33 +64,21 @@ public void testMergeSubgraph01() throws Exception {
6564
.addEdge("A", "B")
6665
.addEdge("B", "C")
6766
.addEdge("C", END)
68-
//.compile(compileConfig)
6967
;
7068

71-
var processed = StateGraphNodesAndEdges.process( workflowParent );
72-
processed.nodes().elements.forEach( System.out::println );
73-
processed.edges().elements.forEach( System.out::println );
74-
75-
assertEquals( 5, processed.edges().elements.size() );
76-
assertEquals( 4, processed.nodes().elements.size() );
77-
7869
var B_B1 = SubGraphNode.formatId( "B", "B1");
7970
var B_B2 = SubGraphNode.formatId( "B", "B2");
8071

81-
List.of(
82-
"Node(A,action)",
83-
"Node(C,action)",
84-
format("Node(%s,action)", B_B1 ),
85-
format("Node(%s,action)", B_B2 )
86-
).forEach( n -> assertTrue( processed.nodes().elements.stream().anyMatch(n1 -> Objects.equals(n,n1.toString())) ));
87-
88-
List.of(
89-
"Edge[sourceId=__START__, targets=[EdgeValue[id=A, value=null]]]",
90-
"Edge[sourceId=C, targets=[EdgeValue[id=__END__, value=null]]]",
91-
format("Edge[sourceId=A, targets=[EdgeValue[id=%s, value=null]]]", B_B1),
92-
format("Edge[sourceId=%s, targets=[EdgeValue[id=C, value=null]]]", B_B2 ),
93-
format("Edge[sourceId=%s, targets=[EdgeValue[id=%s, value=null]]]", B_B1, B_B2)
94-
).forEach( e -> assertTrue( processed.edges().elements.stream().anyMatch( e1 -> Objects.equals(e,e1.toString()) ) ) );
72+
var app = workflowParent.compile();
73+
74+
assertIterableEquals( List.of(
75+
START,
76+
"A",
77+
B_B1,
78+
B_B2,
79+
"C",
80+
END
81+
), _execute( app, Map.of() ) );
9582

9683
}
9784

@@ -116,34 +103,11 @@ public void testMergeSubgraph02() throws Exception {
116103
.addEdge("A", "B")
117104
.addEdge("B", "C")
118105
.addEdge("C", END)
119-
//.compile(compileConfig)
120106
;
121107

122-
var processed = StateGraphNodesAndEdges.process( workflowParent );
123-
processed.nodes().elements.forEach( System.out::println );
124-
processed.edges().elements.forEach( System.out::println );
125-
126-
assertEquals( 4, processed.nodes().elements.size() );
127-
assertEquals( 5, processed.edges().elements.size() );
128-
129108
var B_B1 = SubGraphNode.formatId( "B", "B1");
130109
var B_B2 = SubGraphNode.formatId( "B", "B2");
131110

132-
List.of(
133-
"Node(A,action)",
134-
"Node(C,action)",
135-
format("Node(%s,action)", B_B1 ),
136-
format("Node(%s,action)", B_B2 )
137-
).forEach( n -> assertTrue( processed.nodes().elements.stream().anyMatch(n1 -> Objects.equals(n,n1.toString())) ));
138-
139-
List.of(
140-
format("Edge[sourceId=__START__, targets=[EdgeValue[id=null, value=EdgeCondition[ action, mapping={a=A, b=%s}]]]", B_B1),
141-
"Edge[sourceId=C, targets=[EdgeValue[id=__END__, value=null]]]",
142-
format("Edge[sourceId=A, targets=[EdgeValue[id=%s, value=null]]]", B_B1),
143-
format("Edge[sourceId=%s, targets=[EdgeValue[id=C, value=null]]]", B_B2 ),
144-
format("Edge[sourceId=%s, targets=[EdgeValue[id=%s, value=null]]]", B_B1, B_B2)
145-
).forEach( e -> assertTrue( processed.edges().elements.stream().anyMatch( e1 -> Objects.equals(e,e1.toString()) ) ) );
146-
147111
var app = workflowParent.compile();
148112

149113
assertIterableEquals( List.of(
@@ -160,6 +124,51 @@ public void testMergeSubgraph02() throws Exception {
160124
@Test
161125
public void testMergeSubgraph03() throws Exception {
162126

127+
var workflowChild = new MessagesStateGraph<String>()
128+
.addNode("B1", _makeNode("B1") )
129+
.addNode("B2", _makeNode( "B2" ) )
130+
.addNode("C", _makeNode( "subgraph(C)" ) )
131+
.addEdge(START, "B1")
132+
.addEdge("B1", "B2")
133+
.addConditionalEdges( "B2",
134+
edge_async(state -> "c"),
135+
Map.of( END, END, "c", "C") )
136+
.addEdge("C", END)
137+
;
138+
139+
var workflowParent = new MessagesStateGraph<String>()
140+
.addNode("A", _makeNode("A") )
141+
.addSubgraph("B", workflowChild )
142+
.addNode("C", _makeNode("C") )
143+
.addConditionalEdges(START,
144+
edge_async(state -> "a"),
145+
Map.of( "a", "A", "b", "B") )
146+
.addEdge("A", "B")
147+
.addEdge("B", "C")
148+
.addEdge("C", END)
149+
;
150+
151+
var B_B1 = SubGraphNode.formatId( "B", "B1");
152+
var B_B2 = SubGraphNode.formatId( "B", "B2");
153+
var B_C = SubGraphNode.formatId( "B", "C");
154+
155+
var app = workflowParent.compile();
156+
157+
assertIterableEquals( List.of(
158+
START,
159+
"A",
160+
B_B1,
161+
B_B2,
162+
B_C,
163+
"C",
164+
END
165+
), _execute( app, Map.of() ) );
166+
167+
}
168+
169+
@Test
170+
public void testMergeSubgraph03WithInterruption() throws Exception {
171+
163172
var workflowChild = new MessagesStateGraph<String>()
164173
.addNode("B1", _makeNode("B1") )
165174
.addNode("B2", _makeNode( "B2" ) )
@@ -185,6 +194,129 @@ public void testMergeSubgraph03() throws Exception {
185194
//.compile(compileConfig)
186195
;
187196

197+
var B_B1 = SubGraphNode.formatId( "B", "B1");
198+
var B_B2 = SubGraphNode.formatId( "B", "B2");
199+
var B_C = SubGraphNode.formatId( "B", "C");
200+
201+
var saver = new MemorySaver();
202+
203+
var withSaver = workflowParent.compile(
204+
CompileConfig.builder()
205+
.checkpointSaver(saver)
206+
.build());
207+
208+
assertIterableEquals( List.of(
209+
START,
210+
"A",
211+
B_B1,
212+
B_B2,
213+
B_C,
214+
"C",
215+
END
216+
), _execute( withSaver, Map.of()) );
217+
218+
// INTERRUPT AFTER B1
219+
var interruptAfterB1 = workflowParent.compile(
220+
CompileConfig.builder()
221+
.checkpointSaver(saver)
222+
.interruptAfter( B_B1 )
223+
.build());
224+
assertIterableEquals( List.of(
225+
START,
226+
"A",
227+
B_B1
228+
), _execute( interruptAfterB1, Map.of() ) );
229+
230+
// RESUME AFTER B1
231+
assertIterableEquals( List.of(
232+
B_B2,
233+
B_C,
234+
"C",
235+
END
236+
), _execute( interruptAfterB1, null ) );
237+
238+
// INTERRUPT AFTER B2
239+
var interruptAfterB2 = workflowParent.compile(
240+
CompileConfig.builder()
241+
.checkpointSaver(saver)
242+
.interruptAfter( B_B2 )
243+
.build());
244+
245+
assertIterableEquals( List.of(
246+
START,
247+
"A",
248+
B_B1,
249+
B_B2
250+
), _execute( interruptAfterB2, Map.of() ) );
251+
252+
// RESUME AFTER B2
253+
assertIterableEquals( List.of(
254+
B_C,
255+
"C",
256+
END
257+
), _execute( interruptAfterB2, null ) );
258+
259+
// INTERRUPT BEFORE C
260+
var interruptBeforeC = workflowParent.compile(
261+
CompileConfig.builder()
262+
.checkpointSaver(saver)
263+
.interruptBefore( "C" )
264+
.build());
265+
266+
assertIterableEquals( List.of(
267+
START,
268+
"A",
269+
B_B1,
270+
B_B2,
271+
B_C
272+
), _execute( interruptBeforeC, Map.of() ) );
273+
274+
// RESUME AFTER B2
275+
assertIterableEquals( List.of(
276+
"C",
277+
END
278+
), _execute( interruptBeforeC, null ) );
279+
280+
// INTERRUPT BEFORE SUBGRAPH B
281+
var exception = assertThrows(GraphStateException.class, () -> workflowParent.compile(
282+
CompileConfig.builder()
283+
.checkpointSaver(saver)
284+
.interruptBefore( "B" )
285+
.build()));
286+
System.out.println(exception.getMessage());
287+
assertEquals("node 'B' configured as interruption doesn't exist!", exception.getMessage());
288+
289+
}
290+
291+
@Test
292+
public void testMergeSubgraph04() throws Exception {
293+
294+
var workflowChild = new MessagesStateGraph<String>()
295+
.addNode("B1", _makeNode("B1") )
296+
.addNode("B2", _makeNode( "B2" ) )
297+
.addNode("C", _makeNode( "subgraph(C)" ) )
298+
.addEdge(START, "B1")
299+
.addEdge("B1", "B2")
300+
.addConditionalEdges( "B2",
301+
edge_async(state -> "c"),
302+
Map.of( END, END, "c", "C") )
303+
.addEdge("C", END)
304+
;
305+
306+
var workflowParent = new MessagesStateGraph<String>()
307+
.addNode("A", _makeNode("A") )
308+
.addSubgraph("B", workflowChild )
309+
.addNode("C", _makeNode("C") )
310+
.addConditionalEdges(START,
311+
edge_async(state -> "a"),
312+
Map.of( "a", "A", "b", "B") )
313+
.addEdge("A", "B")
314+
.addConditionalEdges("B",
315+
edge_async(state -> "c"),
316+
Map.of( "c", "C", "a", "A"/*END, END*/) )
317+
.addEdge("C", END)
318+
;
319+
188320
var processed = StateGraphNodesAndEdges.process( workflowParent );
189321
processed.nodes().elements.forEach( System.out::println );
190322
processed.edges().elements.forEach( System.out::println );
@@ -196,28 +328,6 @@ public void testMergeSubgraph03() throws Exception {
196328
var B_B2 = SubGraphNode.formatId( "B", "B2");
197329
var B_C = SubGraphNode.formatId( "B", "C");
198330

199-
List.of(
200-
"Node(A,action)",
201-
"Node(C,action)",
202-
format("Node(%s,action)", B_B1 ),
203-
format("Node(%s,action)", B_B2 ),
204-
format("Node(%s,action)", B_C)
205-
).forEach( n -> assertTrue( processed.nodes().elements.stream()
206-
.anyMatch(n1 -> Objects.equals(n, n1.toString())),
207-
format( "node \"%s\" doesn't have a match!", n ) ));
208-
209-
210-
List.of(
211-
format("Edge[sourceId=__START__, targets=[EdgeValue[id=null, value=EdgeCondition[ action, mapping={a=A, b=%s}]]]", B_B1),
212-
"Edge[sourceId=C, targets=[EdgeValue[id=__END__, value=null]]]",
213-
format("Edge[sourceId=A, targets=[EdgeValue[id=%s, value=null]]]", B_B1),
214-
format("Edge[sourceId=%s, targets=[EdgeValue[id=null, value=EdgeCondition[ action, mapping={c=%s, __END__=C}]]]", B_B2, B_C ),
215-
format("Edge[sourceId=%s, targets=[EdgeValue[id=C, value=null]]]", B_C ),
216-
format("Edge[sourceId=%s, targets=[EdgeValue[id=%s, value=null]]]", B_B1, B_B2)
217-
).forEach( e -> assertTrue( processed.edges().elements.stream()
218-
.anyMatch( e1 -> Objects.equals(e,e1.toString()) ) ,
219-
format( "edge \"%s\" doesn't have a match!", e ) ) );
220-
221331
var app = workflowParent.compile();
222332

223333
assertIterableEquals( List.of(
@@ -231,9 +341,8 @@ public void testMergeSubgraph03() throws Exception {
231341
), _execute( app, Map.of() ) );
232342

233343
}
234-
235344
@Test
236-
public void testMergeSubgraphWithInterruption() throws Exception {
345+
public void testMergeSubgraph04WithInterruption() throws Exception {
237346

238347
var workflowChild = new MessagesStateGraph<String>()
239348
.addNode("B1", _makeNode("B1") )
@@ -251,13 +360,16 @@ public void testMergeSubgraphWithInterruption() throws Exception {
251360
.addNode("A", _makeNode("A") )
252361
.addSubgraph("B", workflowChild )
253362
.addNode("C", _makeNode("C") )
363+
.addNode("C1", _makeNode("C1") )
254364
.addConditionalEdges(START,
255365
edge_async(state -> "a"),
256366
Map.of( "a", "A", "b", "B") )
257367
.addEdge("A", "B")
258-
.addEdge("B", "C")
368+
.addConditionalEdges("B",
369+
edge_async(state -> "c"),
370+
Map.of( "c", "C1", "a", "A" /*END, END*/) )
371+
.addEdge("C1", "C")
259372
.addEdge("C", END)
260-
//.compile(compileConfig)
261373
;
262374

263375
var B_B1 = SubGraphNode.formatId( "B", "B1");
@@ -268,15 +380,16 @@ public void testMergeSubgraphWithInterruption() throws Exception {
268380

269381
var withSaver = workflowParent.compile(
270382
CompileConfig.builder()
271-
.checkpointSaver(saver)
272-
.build());
383+
.checkpointSaver(saver)
384+
.build());
273385

274386
assertIterableEquals( List.of(
275387
START,
276388
"A",
277389
B_B1,
278390
B_B2,
279391
B_C,
392+
"C1",
280393
"C",
281394
END
282395
), _execute( withSaver, Map.of()) );
@@ -297,13 +410,14 @@ public void testMergeSubgraphWithInterruption() throws Exception {
297410
assertIterableEquals( List.of(
298411
B_B2,
299412
B_C,
413+
"C1",
300414
"C",
301415
END
302416
), _execute( interruptAfterB1, null ) );
303417

304418
// INTERRUPT AFTER B2
305419
var interruptAfterB2 = workflowParent.compile(
306-
CompileConfig.builder()
420+
CompileConfig.builder()
307421
.checkpointSaver(saver)
308422
.interruptAfter( B_B2 )
309423
.build());
@@ -318,6 +432,7 @@ public void testMergeSubgraphWithInterruption() throws Exception {
318432
// RESUME AFTER B2
319433
assertIterableEquals( List.of(
320434
B_C,
435+
"C1",
321436
"C",
322437
END
323438
), _execute( interruptAfterB2, null ) );
@@ -334,7 +449,8 @@ public void testMergeSubgraphWithInterruption() throws Exception {
334449
"A",
335450
B_B1,
336451
B_B2,
337-
B_C
452+
B_C,
453+
"C1"
338454
), _execute( interruptBeforeC, Map.of() ) );
339455

340456
// RESUME AFTER B2
@@ -354,6 +470,7 @@ public void testMergeSubgraphWithInterruption() throws Exception {
354470

355471
}
356472

473+
357474
@Test
358475
public void testCheckpointWithSubgraph() throws Exception {
359476

0 commit comments

Comments
 (0)