30
30
import com .google .cloud .vertexai .api .GenerationConfig ;
31
31
import com .google .cloud .vertexai .api .SafetySetting ;
32
32
import com .google .cloud .vertexai .api .Tool ;
33
+ import com .google .cloud .vertexai .api .ToolConfig ;
33
34
import com .google .common .base .Strings ;
34
35
import com .google .common .collect .ImmutableList ;
35
36
import com .google .errorprone .annotations .CanIgnoreReturnValue ;
@@ -46,6 +47,7 @@ public final class GenerativeModel {
46
47
private final GenerationConfig generationConfig ;
47
48
private final ImmutableList <SafetySetting > safetySettings ;
48
49
private final ImmutableList <Tool > tools ;
50
+ private final Optional <ToolConfig > toolConfig ;
49
51
private final Optional <Content > systemInstruction ;
50
52
51
53
/**
@@ -65,6 +67,7 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
65
67
ImmutableList .of (),
66
68
ImmutableList .of (),
67
69
Optional .empty (),
70
+ Optional .empty (),
68
71
vertexAi );
69
72
}
70
73
@@ -79,6 +82,10 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
79
82
* that will be used by default for generating response
80
83
* @param tools a list of {@link com.google.cloud.vertexai.api.Tool} instances that can be used by
81
84
* the model as auxiliary tools to generate content.
85
+ * @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} instance that will be used
86
+ * to specify the tool configuration.
87
+ * @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} instance that will be
88
+ * used by default for generating response.
82
89
* @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
83
90
* for the generative model
84
91
*/
@@ -87,6 +94,7 @@ private GenerativeModel(
87
94
GenerationConfig generationConfig ,
88
95
ImmutableList <SafetySetting > safetySettings ,
89
96
ImmutableList <Tool > tools ,
97
+ Optional <ToolConfig > toolConfig ,
90
98
Optional <Content > systemInstruction ,
91
99
VertexAI vertexAi ) {
92
100
checkArgument (
@@ -98,6 +106,8 @@ private GenerativeModel(
98
106
checkNotNull (generationConfig , "GenerationConfig can't be null." );
99
107
checkNotNull (safetySettings , "ImmutableList<SafetySettings> can't be null." );
100
108
checkNotNull (tools , "ImmutableList<Tool> can't be null." );
109
+ checkNotNull (toolConfig , "Optional<ToolConfig> can't be null." );
110
+ checkNotNull (systemInstruction , "Optional<Content> can't be null." );
101
111
102
112
this .resourceName = getResourceName (modelName , vertexAi );
103
113
// reconcileModelName should be called after getResourceName.
@@ -106,6 +116,7 @@ private GenerativeModel(
106
116
this .generationConfig = generationConfig ;
107
117
this .safetySettings = safetySettings ;
108
118
this .tools = tools ;
119
+ this .toolConfig = toolConfig ;
109
120
// We remove the role in the system instruction content because it's officially documented
110
121
// to be used without role specified:
111
122
// https://cloud.google.com/vertex-ai/generative-ai/docs/samples/generativeaionvertexai-gemini-system-instruction
@@ -128,6 +139,7 @@ public static class Builder {
128
139
private GenerationConfig generationConfig = GenerationConfig .getDefaultInstance ();
129
140
private ImmutableList <SafetySetting > safetySettings = ImmutableList .of ();
130
141
private ImmutableList <Tool > tools = ImmutableList .of ();
142
+ private Optional <ToolConfig > toolConfig = Optional .empty ();
131
143
private Optional <Content > systemInstruction = Optional .empty ();
132
144
133
145
public GenerativeModel build () {
@@ -136,7 +148,13 @@ public GenerativeModel build() {
136
148
"modelName is required. Please call setModelName() before building." );
137
149
checkNotNull (vertexAi , "vertexAi is required. Please call setVertexAi() before building." );
138
150
return new GenerativeModel (
139
- modelName , generationConfig , safetySettings , tools , systemInstruction , vertexAi );
151
+ modelName ,
152
+ generationConfig ,
153
+ safetySettings ,
154
+ tools ,
155
+ toolConfig ,
156
+ systemInstruction ,
157
+ vertexAi );
140
158
}
141
159
142
160
/**
@@ -204,6 +222,19 @@ public Builder setTools(List<Tool> tools) {
204
222
return this ;
205
223
}
206
224
225
+ /**
226
+ * Sets a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used by default to
227
+ * interact with the generative model.
228
+ */
229
+ @ CanIgnoreReturnValue
230
+ public Builder setToolConfig (ToolConfig toolConfig ) {
231
+ checkNotNull (
232
+ toolConfig ,
233
+ "toolConfig can't be null. Use Optional.empty() if no tool config is intended." );
234
+ this .toolConfig = Optional .of (toolConfig );
235
+ return this ;
236
+ }
237
+
207
238
/**
208
239
* Sets a system instruction that will be used by default to interact with the generative model.
209
240
*/
@@ -228,7 +259,13 @@ public Builder setSystemInstruction(Content systemInstruction) {
228
259
public GenerativeModel withGenerationConfig (GenerationConfig generationConfig ) {
229
260
checkNotNull (generationConfig , "GenerationConfig can't be null." );
230
261
return new GenerativeModel (
231
- modelName , generationConfig , safetySettings , tools , systemInstruction , vertexAi );
262
+ modelName ,
263
+ generationConfig ,
264
+ safetySettings ,
265
+ tools ,
266
+ toolConfig ,
267
+ systemInstruction ,
268
+ vertexAi );
232
269
}
233
270
234
271
/**
@@ -247,6 +284,7 @@ public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
247
284
generationConfig ,
248
285
ImmutableList .copyOf (safetySettings ),
249
286
tools ,
287
+ toolConfig ,
250
288
systemInstruction ,
251
289
vertexAi );
252
290
}
@@ -265,6 +303,28 @@ public GenerativeModel withTools(List<Tool> tools) {
265
303
generationConfig ,
266
304
safetySettings ,
267
305
ImmutableList .copyOf (tools ),
306
+ toolConfig ,
307
+ systemInstruction ,
308
+ vertexAi );
309
+ }
310
+
311
+ /**
312
+ * Creates a copy of the current model with updated tool config.
313
+ *
314
+ * @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used in the
315
+ * new model.
316
+ * @return a new {@link GenerativeModel} instance with the specified tool config.
317
+ */
318
+ public GenerativeModel withToolConfig (ToolConfig toolConfig ) {
319
+ checkNotNull (
320
+ toolConfig ,
321
+ "toolConfig can't be null. Use Optional.empty() if no tool config is intended." );
322
+ return new GenerativeModel (
323
+ modelName ,
324
+ generationConfig ,
325
+ safetySettings ,
326
+ tools ,
327
+ Optional .of (toolConfig ),
268
328
systemInstruction ,
269
329
vertexAi );
270
330
}
@@ -286,6 +346,7 @@ public GenerativeModel withSystemInstruction(Content systemInstruction) {
286
346
generationConfig ,
287
347
safetySettings ,
288
348
tools ,
349
+ toolConfig ,
289
350
Optional .of (systemInstruction ),
290
351
vertexAi );
291
352
}
@@ -537,6 +598,10 @@ private GenerateContentRequest buildGenerateContentRequest(List<Content> content
537
598
.addAllSafetySettings (safetySettings )
538
599
.addAllTools (tools );
539
600
601
+ if (toolConfig .isPresent ()) {
602
+ requestBuilder .setToolConfig (toolConfig .get ());
603
+ }
604
+
540
605
if (systemInstruction .isPresent ()) {
541
606
requestBuilder .setSystemInstruction (systemInstruction .get ());
542
607
}
@@ -568,6 +633,13 @@ public ImmutableList<Tool> getTools() {
568
633
return tools ;
569
634
}
570
635
636
+ /**
637
+ * Returns the optional {@link com.google.cloud.vertexai.api.ToolConfig} of this generative model.
638
+ */
639
+ public Optional <ToolConfig > getToolConfig () {
640
+ return toolConfig ;
641
+ }
642
+
571
643
/** Returns the optional system instruction of this generative model. */
572
644
public Optional <Content > getSystemInstruction () {
573
645
return systemInstruction ;
0 commit comments