Skip to content

Commit 0801812

Browse files
feat: [vertexai] support ToolConfig in GenerativeModel (#10950)
PiperOrigin-RevId: 642059737 Co-authored-by: Jaycee Li <jayceeli@google.com>
1 parent c79bfb5 commit 0801812

File tree

2 files changed

+124
-2
lines changed

2 files changed

+124
-2
lines changed

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java

+74-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import com.google.cloud.vertexai.api.GenerationConfig;
3131
import com.google.cloud.vertexai.api.SafetySetting;
3232
import com.google.cloud.vertexai.api.Tool;
33+
import com.google.cloud.vertexai.api.ToolConfig;
3334
import com.google.common.base.Strings;
3435
import com.google.common.collect.ImmutableList;
3536
import com.google.errorprone.annotations.CanIgnoreReturnValue;
@@ -46,6 +47,7 @@ public final class GenerativeModel {
4647
private final GenerationConfig generationConfig;
4748
private final ImmutableList<SafetySetting> safetySettings;
4849
private final ImmutableList<Tool> tools;
50+
private final Optional<ToolConfig> toolConfig;
4951
private final Optional<Content> systemInstruction;
5052

5153
/**
@@ -65,6 +67,7 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
6567
ImmutableList.of(),
6668
ImmutableList.of(),
6769
Optional.empty(),
70+
Optional.empty(),
6871
vertexAi);
6972
}
7073

@@ -79,6 +82,10 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
7982
* that will be used by default for generating response
8083
* @param tools a list of {@link com.google.cloud.vertexai.api.Tool} instances that can be used by
8184
* the model as auxiliary tools to generate content.
85+
* @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} instance that will be used
86+
* to specify the tool configuration.
87+
* @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} instance that will be
88+
* used by default for generating response.
8289
* @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
8390
* for the generative model
8491
*/
@@ -87,6 +94,7 @@ private GenerativeModel(
8794
GenerationConfig generationConfig,
8895
ImmutableList<SafetySetting> safetySettings,
8996
ImmutableList<Tool> tools,
97+
Optional<ToolConfig> toolConfig,
9098
Optional<Content> systemInstruction,
9199
VertexAI vertexAi) {
92100
checkArgument(
@@ -98,6 +106,8 @@ private GenerativeModel(
98106
checkNotNull(generationConfig, "GenerationConfig can't be null.");
99107
checkNotNull(safetySettings, "ImmutableList<SafetySettings> can't be null.");
100108
checkNotNull(tools, "ImmutableList<Tool> can't be null.");
109+
checkNotNull(toolConfig, "Optional<ToolConfig> can't be null.");
110+
checkNotNull(systemInstruction, "Optional<Content> can't be null.");
101111

102112
this.resourceName = getResourceName(modelName, vertexAi);
103113
// reconcileModelName should be called after getResourceName.
@@ -106,6 +116,7 @@ private GenerativeModel(
106116
this.generationConfig = generationConfig;
107117
this.safetySettings = safetySettings;
108118
this.tools = tools;
119+
this.toolConfig = toolConfig;
109120
// We remove the role in the system instruction content because it's officially documented
110121
// to be used without role specified:
111122
// https://cloud.google.com/vertex-ai/generative-ai/docs/samples/generativeaionvertexai-gemini-system-instruction
@@ -128,6 +139,7 @@ public static class Builder {
128139
private GenerationConfig generationConfig = GenerationConfig.getDefaultInstance();
129140
private ImmutableList<SafetySetting> safetySettings = ImmutableList.of();
130141
private ImmutableList<Tool> tools = ImmutableList.of();
142+
private Optional<ToolConfig> toolConfig = Optional.empty();
131143
private Optional<Content> systemInstruction = Optional.empty();
132144

133145
public GenerativeModel build() {
@@ -136,7 +148,13 @@ public GenerativeModel build() {
136148
"modelName is required. Please call setModelName() before building.");
137149
checkNotNull(vertexAi, "vertexAi is required. Please call setVertexAi() before building.");
138150
return new GenerativeModel(
139-
modelName, generationConfig, safetySettings, tools, systemInstruction, vertexAi);
151+
modelName,
152+
generationConfig,
153+
safetySettings,
154+
tools,
155+
toolConfig,
156+
systemInstruction,
157+
vertexAi);
140158
}
141159

142160
/**
@@ -204,6 +222,19 @@ public Builder setTools(List<Tool> tools) {
204222
return this;
205223
}
206224

225+
/**
226+
* Sets a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used by default to
227+
* interact with the generative model.
228+
*/
229+
@CanIgnoreReturnValue
230+
public Builder setToolConfig(ToolConfig toolConfig) {
231+
checkNotNull(
232+
toolConfig,
233+
"toolConfig can't be null. Use Optional.empty() if no tool config is intended.");
234+
this.toolConfig = Optional.of(toolConfig);
235+
return this;
236+
}
237+
207238
/**
208239
* Sets a system instruction that will be used by default to interact with the generative model.
209240
*/
@@ -228,7 +259,13 @@ public Builder setSystemInstruction(Content systemInstruction) {
228259
public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
229260
checkNotNull(generationConfig, "GenerationConfig can't be null.");
230261
return new GenerativeModel(
231-
modelName, generationConfig, safetySettings, tools, systemInstruction, vertexAi);
262+
modelName,
263+
generationConfig,
264+
safetySettings,
265+
tools,
266+
toolConfig,
267+
systemInstruction,
268+
vertexAi);
232269
}
233270

234271
/**
@@ -247,6 +284,7 @@ public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
247284
generationConfig,
248285
ImmutableList.copyOf(safetySettings),
249286
tools,
287+
toolConfig,
250288
systemInstruction,
251289
vertexAi);
252290
}
@@ -265,6 +303,28 @@ public GenerativeModel withTools(List<Tool> tools) {
265303
generationConfig,
266304
safetySettings,
267305
ImmutableList.copyOf(tools),
306+
toolConfig,
307+
systemInstruction,
308+
vertexAi);
309+
}
310+
311+
/**
312+
* Creates a copy of the current model with updated tool config.
313+
*
314+
* @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used in the
315+
* new model.
316+
* @return a new {@link GenerativeModel} instance with the specified tool config.
317+
*/
318+
public GenerativeModel withToolConfig(ToolConfig toolConfig) {
319+
checkNotNull(
320+
toolConfig,
321+
"toolConfig can't be null. Use Optional.empty() if no tool config is intended.");
322+
return new GenerativeModel(
323+
modelName,
324+
generationConfig,
325+
safetySettings,
326+
tools,
327+
Optional.of(toolConfig),
268328
systemInstruction,
269329
vertexAi);
270330
}
@@ -286,6 +346,7 @@ public GenerativeModel withSystemInstruction(Content systemInstruction) {
286346
generationConfig,
287347
safetySettings,
288348
tools,
349+
toolConfig,
289350
Optional.of(systemInstruction),
290351
vertexAi);
291352
}
@@ -537,6 +598,10 @@ private GenerateContentRequest buildGenerateContentRequest(List<Content> content
537598
.addAllSafetySettings(safetySettings)
538599
.addAllTools(tools);
539600

601+
if (toolConfig.isPresent()) {
602+
requestBuilder.setToolConfig(toolConfig.get());
603+
}
604+
540605
if (systemInstruction.isPresent()) {
541606
requestBuilder.setSystemInstruction(systemInstruction.get());
542607
}
@@ -568,6 +633,13 @@ public ImmutableList<Tool> getTools() {
568633
return tools;
569634
}
570635

636+
/**
637+
* Returns the optional {@link com.google.cloud.vertexai.api.ToolConfig} of this generative model.
638+
*/
639+
public Optional<ToolConfig> getToolConfig() {
640+
return toolConfig;
641+
}
642+
571643
/** Returns the optional system instruction of this generative model. */
572644
public Optional<Content> getSystemInstruction() {
573645
return systemInstruction;

java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java

+50
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import com.google.cloud.vertexai.api.Content;
3232
import com.google.cloud.vertexai.api.CountTokensRequest;
3333
import com.google.cloud.vertexai.api.CountTokensResponse;
34+
import com.google.cloud.vertexai.api.FunctionCallingConfig;
3435
import com.google.cloud.vertexai.api.FunctionDeclaration;
3536
import com.google.cloud.vertexai.api.GenerateContentRequest;
3637
import com.google.cloud.vertexai.api.GenerateContentResponse;
@@ -44,6 +45,7 @@
4445
import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold;
4546
import com.google.cloud.vertexai.api.Schema;
4647
import com.google.cloud.vertexai.api.Tool;
48+
import com.google.cloud.vertexai.api.ToolConfig;
4749
import com.google.cloud.vertexai.api.Type;
4850
import com.google.cloud.vertexai.api.VertexAISearch;
4951
import java.util.ArrayList;
@@ -96,6 +98,13 @@ public final class GenerativeModelTest {
9698
.build())
9799
.addRequired("location")))
98100
.build();
101+
private static final ToolConfig DEFAULT_TOOL_CONFIG =
102+
ToolConfig.newBuilder()
103+
.setFunctionCallingConfig(
104+
FunctionCallingConfig.newBuilder()
105+
.setMode(FunctionCallingConfig.Mode.ANY)
106+
.addAllowedFunctionNames("getCurrentWeather"))
107+
.build();
99108
private static final Content DEFAULT_SYSTEM_INSTRUCTION =
100109
ContentMaker.fromString(
101110
"You're a helpful assistant that starts all its answers with: \"COOL\"");
@@ -404,6 +413,25 @@ public void generateContent_withDefaultTools_requestHasCorrectToolsAndText() thr
404413
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
405414
}
406415

416+
@Test
417+
public void generateContent_withDefaultToolConfig_requestHasCorrectToolConfigAndText()
418+
throws Exception {
419+
model =
420+
new GenerativeModel.Builder()
421+
.setModelName(MODEL_NAME)
422+
.setVertexAi(vertexAi)
423+
.setToolConfig(DEFAULT_TOOL_CONFIG)
424+
.build();
425+
426+
GenerateContentResponse unused = model.generateContent(TEXT);
427+
428+
ArgumentCaptor<GenerateContentRequest> request =
429+
ArgumentCaptor.forClass(GenerateContentRequest.class);
430+
verify(mockUnaryCallable).call(request.capture());
431+
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
432+
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
433+
}
434+
407435
@Test
408436
public void
409437
generateContent_withDefaultSystemInstruction_requestHasCorrectSystemInstructionAndText()
@@ -433,6 +461,7 @@ public void generateContent_withAllConfigsInFluentApi_requestHasCorrectFields()
433461
.withGenerationConfig(GENERATION_CONFIG)
434462
.withSafetySettings(safetySettings)
435463
.withTools(tools)
464+
.withToolConfig(DEFAULT_TOOL_CONFIG)
436465
.withSystemInstruction(DEFAULT_SYSTEM_INSTRUCTION)
437466
.generateContent(TEXT);
438467

@@ -444,6 +473,7 @@ public void generateContent_withAllConfigsInFluentApi_requestHasCorrectFields()
444473
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
445474
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
446475
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
476+
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
447477
assertThat(request.getValue().getSystemInstruction()).isEqualTo(expectedSystemInstruction);
448478
}
449479

@@ -546,6 +576,24 @@ public void generateContentStream_withDefaultTools_requestHasCorrectTools() thro
546576
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
547577
}
548578

579+
@Test
580+
public void generateContentStream_withDefaultToolConfig_requestHasCorrectToolConfig()
581+
throws Exception {
582+
model =
583+
new GenerativeModel.Builder()
584+
.setModelName(MODEL_NAME)
585+
.setVertexAi(vertexAi)
586+
.setToolConfig(DEFAULT_TOOL_CONFIG)
587+
.build();
588+
589+
ResponseStream unused = model.generateContentStream(TEXT);
590+
591+
ArgumentCaptor<GenerateContentRequest> request =
592+
ArgumentCaptor.forClass(GenerateContentRequest.class);
593+
verify(mockServerStreamCallable).call(request.capture());
594+
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
595+
}
596+
549597
@Test
550598
public void
551599
generateContentStream_withDefaultSystemInstruction_requestHasCorrectSystemInstruction()
@@ -576,6 +624,7 @@ public void generateContentStream_withAllConfigsInFluentApi_requestHasCorrectFie
576624
.withGenerationConfig(GENERATION_CONFIG)
577625
.withSafetySettings(safetySettings)
578626
.withTools(tools)
627+
.withToolConfig(DEFAULT_TOOL_CONFIG)
579628
.withSystemInstruction(DEFAULT_SYSTEM_INSTRUCTION)
580629
.generateContentStream(TEXT);
581630

@@ -587,6 +636,7 @@ public void generateContentStream_withAllConfigsInFluentApi_requestHasCorrectFie
587636
assertThat(request.getValue().getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
588637
assertThat(request.getValue().getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
589638
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
639+
assertThat(request.getValue().getToolConfig()).isEqualTo(DEFAULT_TOOL_CONFIG);
590640
assertThat(request.getValue().getSystemInstruction()).isEqualTo(expectedSystemInstruction);
591641
}
592642

0 commit comments

Comments
 (0)