6
6
import jakarta .servlet .http .HttpServlet ;
7
7
import jakarta .servlet .http .HttpServletRequest ;
8
8
import jakarta .servlet .http .HttpServletResponse ;
9
- import org .bsc .async .AsyncGenerator ;
10
9
import org .bsc .langgraph4j .state .AgentState ;
11
10
import org .eclipse .jetty .ee10 .servlet .ServletContextHandler ;
12
11
import org .eclipse .jetty .ee10 .servlet .ServletHolder ;
13
12
import org .eclipse .jetty .server .Handler ;
14
13
import org .eclipse .jetty .server .Server ;
15
14
import org .eclipse .jetty .server .ServerConnector ;
16
15
import org .eclipse .jetty .server .handler .ResourceHandler ;
16
+ import org .eclipse .jetty .util .resource .Resource ;
17
17
import org .eclipse .jetty .util .resource .ResourceFactory ;
18
18
19
19
import java .io .IOException ;
20
20
import java .io .PrintWriter ;
21
- import java .nio .file .Path ;
22
- import java .nio .file .Paths ;
23
- import java .util .Arrays ;
24
- import java .util .List ;
21
+ import java .util .HashMap ;
25
22
import java .util .Map ;
26
23
import java .util .Objects ;
27
24
import java .util .concurrent .CompletableFuture ;
28
25
import java .util .concurrent .TimeUnit ;
29
26
27
+
30
28
public interface LangGraphStreamingServer {
31
29
32
30
CompletableFuture <Void > start () throws Exception ;
33
31
34
- public static <State extends AgentState > LangGraphStreamingServer of (CompiledGraph <State > compiledGraph ) throws Exception {
35
-
36
- Server server = new Server ();
37
- ServerConnector connector = new ServerConnector (server );
38
- connector .setPort (8080 );
39
- server .addConnector (connector );
40
-
41
- ResourceHandler resourceHandler = new ResourceHandler ();
42
- Path publicResourcesPath = Paths .get ( "jetty" , "src" , "main" , "webapp" );
43
- resourceHandler .setBaseResource (ResourceFactory .of (resourceHandler ).newResource (publicResourcesPath ));
44
- resourceHandler .setDirAllowed (true );
45
-
46
- ServletContextHandler context = new ServletContextHandler (ServletContextHandler .SESSIONS );
47
- context .setContextPath ("/" );
48
- // Add the streaming servlet
49
- context .addServlet (new ServletHolder (new StreamingServlet <State >(compiledGraph )), "/stream" );
50
- context .addServlet (new ServletHolder (new GraphServlet <State >(compiledGraph )), "/graph" );
51
-
52
- Handler .Sequence handlerList = new Handler .Sequence (resourceHandler , context );
53
-
54
- server .setHandler (handlerList );
55
-
56
- return new LangGraphStreamingServer () {
57
- @ Override
58
- public CompletableFuture <Void > start () throws Exception {
59
- return CompletableFuture .runAsync ( () -> {
60
- try {
61
- server .start ();
62
- } catch ( Exception e ) {
63
- throw new RuntimeException (e );
64
- }
65
- }, Runnable ::run );
66
- }
67
- };
32
+ static Builder builder () {
33
+ return new Builder ();
68
34
}
69
35
70
- class StreamingServlet <State extends AgentState > extends HttpServlet {
71
- final CompiledGraph <State > compiledGraph ;
72
- final ObjectMapper objectMapper = new ObjectMapper ();
36
+ class Builder {
37
+ private int port = 8080 ;
38
+ private Map <String ,ArgumentMetadata > inputArgs = new HashMap <>();
39
+
40
+ public Builder port (int port ) {
41
+ this .port = port ;
42
+ return this ;
43
+ }
73
44
74
- public StreamingServlet ( CompiledGraph <State > compiledGraph ) {
75
- Objects .requireNonNull (compiledGraph , "compiledGraph cannot be null" );
76
- this .compiledGraph = compiledGraph ;
45
+ public Builder addInputStringArg (String name , boolean required ) {
46
+ inputArgs .put (name , new ArgumentMetadata ("string" , required ) );
47
+ return this ;
48
+ }
49
+ public Builder addInputStringArg (String name ) {
50
+ inputArgs .put (name , new ArgumentMetadata ("string" , true ) );
51
+ return this ;
77
52
}
78
-
79
- @ Override
80
- protected void doPost (HttpServletRequest request , HttpServletResponse response ) throws ServletException , IOException {
81
- response .setHeader ("Accept" , "application/json" );
82
- response .setContentType ("text/plain" );
83
- response .setCharacterEncoding ("UTF-8" );
84
53
85
- final PrintWriter writer = response .getWriter ();
54
+ public <State extends AgentState > LangGraphStreamingServer build (CompiledGraph <State > compiledGraph ) {
55
+ Server server = new Server ();
56
+ ServerConnector connector = new ServerConnector (server );
57
+ connector .setPort (port );
58
+ server .addConnector (connector );
59
+
60
+ ResourceHandler resourceHandler = new ResourceHandler ();
61
+
62
+ // Path publicResourcesPath = Paths.get("jetty", "src", "main", "webapp");
63
+ // Resource baseResource = ResourceFactory.of(resourceHandler).newResource(publicResourcesPath));
64
+ Resource baseResource = ResourceFactory .of (resourceHandler ).newClassLoaderResource ("webapp" );
65
+ resourceHandler .setBaseResource (baseResource );
66
+
67
+ resourceHandler .setDirAllowed (true );
68
+
69
+ ServletContextHandler context = new ServletContextHandler (ServletContextHandler .SESSIONS );
70
+ context .setContextPath ("/" );
71
+ // Add the streaming servlet
72
+ context .addServlet (new ServletHolder (new GraphExecutionServlet <State >(compiledGraph )), "/stream" );
73
+ context .addServlet (new ServletHolder (new GraphInitServlet <State >(compiledGraph , inputArgs )), "/init" );
86
74
87
- Map < String , Object > dataMap = objectMapper . readValue ( request . getInputStream (), new TypeReference < Map < String , Object >>() {} );
75
+ Handler . Sequence handlerList = new Handler . Sequence ( resourceHandler , context );
88
76
89
- // Start asynchronous processing
90
- request .startAsync ();
77
+ server .setHandler (handlerList );
91
78
92
- try {
93
- compiledGraph .stream (dataMap )
94
- .forEachAsync ( s -> {
79
+ return new LangGraphStreamingServer () {
80
+ @ Override
81
+ public CompletableFuture <Void > start () throws Exception {
82
+ return CompletableFuture .runAsync (() -> {
83
+ try {
84
+ server .start ();
85
+ } catch (Exception e ) {
86
+ throw new RuntimeException (e );
87
+ }
88
+ }, Runnable ::run );
89
+
90
+ }
91
+ };
92
+
93
+ }
94
+ }
95
+ }
96
+
97
+
98
+ class GraphExecutionServlet <State extends AgentState > extends HttpServlet {
99
+ final CompiledGraph <State > compiledGraph ;
100
+ final ObjectMapper objectMapper = new ObjectMapper ();
101
+
102
+ public GraphExecutionServlet (CompiledGraph <State > compiledGraph ) {
103
+ Objects .requireNonNull (compiledGraph , "compiledGraph cannot be null" );
104
+ this .compiledGraph = compiledGraph ;
105
+ }
106
+
107
+ @ Override
108
+ protected void doPost (HttpServletRequest request , HttpServletResponse response ) throws ServletException , IOException {
109
+ response .setHeader ("Accept" , "application/json" );
110
+ response .setContentType ("text/plain" );
111
+ response .setCharacterEncoding ("UTF-8" );
112
+
113
+ final PrintWriter writer = response .getWriter ();
114
+
115
+ Map <String , Object > dataMap = objectMapper .readValue (request .getInputStream (), new TypeReference <Map <String , Object >>() {
116
+ });
117
+
118
+ // Start asynchronous processing
119
+ request .startAsync ();
120
+
121
+ try {
122
+ compiledGraph .stream (dataMap )
123
+ .forEachAsync (s -> {
95
124
writer .println (s .node ());
96
125
writer .flush ();
97
126
@@ -100,41 +129,46 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
100
129
} catch (InterruptedException e ) {
101
130
throw new RuntimeException (e );
102
131
}
103
- }).thenAccept ( v -> {
104
- writer .close ();
105
- });
132
+ }).thenAccept (v -> {
133
+ writer .close ();
134
+ });
106
135
107
- } catch (Exception e ) {
108
- throw new RuntimeException (e );
109
- }
136
+ } catch (Exception e ) {
137
+ throw new RuntimeException (e );
110
138
}
111
139
}
140
+ }
112
141
113
- /**
114
- * return the graph representation in mermaid format
115
- */
116
- class GraphServlet < State extends AgentState > extends HttpServlet {
142
+ record ArgumentMetadata (
143
+ String type ,
144
+ boolean required
145
+ ) {}
117
146
118
- final CompiledGraph <State > compiledGraph ;
147
+ /**
148
+ * return the graph representation in mermaid format
149
+ */
150
+ class GraphInitServlet <State extends AgentState > extends HttpServlet {
119
151
120
- public GraphServlet ( CompiledGraph <State > compiledGraph ) {
121
- Objects .requireNonNull (compiledGraph , "compiledGraph cannot be null" );
122
- this .compiledGraph = compiledGraph ;
123
- }
152
+ final CompiledGraph <State > compiledGraph ;
153
+ final Map <String , ArgumentMetadata > inputArgs ;
124
154
125
- @ Override
126
- protected void doGet (HttpServletRequest request , HttpServletResponse response ) throws ServletException , IOException {
127
- response .setContentType ("text/plain" );
128
- response .setCharacterEncoding ("UTF-8" );
155
+ public GraphInitServlet (CompiledGraph <State > compiledGraph , Map <String , ArgumentMetadata > inputArgs ) {
156
+ Objects .requireNonNull (compiledGraph , "compiledGraph cannot be null" );
157
+ this .compiledGraph = compiledGraph ;
158
+ this .inputArgs = inputArgs ;
159
+ }
129
160
130
- GraphRepresentation result = compiledGraph .getGraph (GraphRepresentation .Type .MERMAID );
161
+ @ Override
162
+ protected void doGet (HttpServletRequest request , HttpServletResponse response ) throws ServletException , IOException {
163
+ response .setContentType ("text/plain" );
164
+ response .setCharacterEncoding ("UTF-8" );
131
165
132
- // Start asynchronous processing
133
- request .startAsync ();
134
- final PrintWriter writer = response .getWriter ();
135
- writer .println (result .getContent ());
136
- writer .close ();
137
- }
138
- }
166
+ GraphRepresentation result = compiledGraph .getGraph (GraphRepresentation .Type .MERMAID );
139
167
140
- }
168
+ // Start asynchronous processing
169
+ request .startAsync ();
170
+ final PrintWriter writer = response .getWriter ();
171
+ writer .println (result .getContent ());
172
+ writer .close ();
173
+ }
174
+ }
0 commit comments