Skip to content

Commit 5a10656

Browse files
feat: [vertexai] add FunctionDeclarationMaker.fromFunc to create FunctionDeclaration from a Java static method (#10915)
PiperOrigin-RevId: 639154403 Co-authored-by: Jaycee Li <jayceeli@google.com>
1 parent 5ebfc33 commit 5a10656

File tree

5 files changed

+360
-178
lines changed

5 files changed

+360
-178
lines changed

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java

+35-67
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import com.google.cloud.vertexai.api.GenerationConfig;
3030
import com.google.cloud.vertexai.api.SafetySetting;
3131
import com.google.cloud.vertexai.api.Tool;
32-
import com.google.cloud.vertexai.api.ToolConfig;
3332
import com.google.common.collect.ImmutableList;
3433
import java.io.IOException;
3534
import java.util.ArrayList;
@@ -41,8 +40,8 @@ public final class ChatSession {
4140
private final GenerativeModel model;
4241
private final Optional<ChatSession> rootChatSession;
4342
private final Optional<AutomaticFunctionCallingResponder> automaticFunctionCallingResponder;
44-
private List<Content> history;
45-
private int previousHistorySize;
43+
private List<Content> history = new ArrayList<>();
44+
private int previousHistorySize = 0;
4645
private Optional<ResponseStream<GenerateContentResponse>> currentResponseStream;
4746
private Optional<GenerateContentResponse> currentResponse;
4847

@@ -51,17 +50,14 @@ public final class ChatSession {
5150
* GenerationConfig) inherits from the model.
5251
*/
5352
public ChatSession(GenerativeModel model) {
54-
this(model, new ArrayList<>(), 0, Optional.empty(), Optional.empty());
53+
this(model, Optional.empty(), Optional.empty());
5554
}
5655

5756
/**
5857
* Creates a new chat session given a GenerativeModel instance and a root chat session.
5958
* Configurations of the chat (e.g., GenerationConfig) inherits from the model.
6059
*
6160
* @param model a {@link GenerativeModel} instance that generates contents in the chat.
62-
* @param history a list of {@link Content} containing interleaving conversation between "user"
63-
* and "model".
64-
* @param previousHistorySize the size of the previous history.
6561
* @param rootChatSession a root {@link ChatSession} instance. All the chat history in the current
6662
* chat session will be merged to the root chat session.
6763
* @param automaticFunctionCallingResponder an {@link AutomaticFunctionCallingResponder} instance
@@ -70,14 +66,10 @@ public ChatSession(GenerativeModel model) {
7066
*/
7167
private ChatSession(
7268
GenerativeModel model,
73-
List<Content> history,
74-
int previousHistorySize,
7569
Optional<ChatSession> rootChatSession,
7670
Optional<AutomaticFunctionCallingResponder> automaticFunctionCallingResponder) {
7771
checkNotNull(model, "model should not be null");
7872
this.model = model;
79-
this.history = history;
80-
this.previousHistorySize = previousHistorySize;
8173
this.rootChatSession = rootChatSession;
8274
this.automaticFunctionCallingResponder = automaticFunctionCallingResponder;
8375
currentResponseStream = Optional.empty();
@@ -92,12 +84,15 @@ private ChatSession(
9284
* @return a new {@link ChatSession} instance with the specified GenerationConfig.
9385
*/
9486
public ChatSession withGenerationConfig(GenerationConfig generationConfig) {
95-
return new ChatSession(
96-
model.withGenerationConfig(generationConfig),
97-
history,
98-
previousHistorySize,
99-
Optional.of(rootChatSession.orElse(this)),
100-
automaticFunctionCallingResponder);
87+
ChatSession rootChat = rootChatSession.orElse(this);
88+
ChatSession newChatSession =
89+
new ChatSession(
90+
model.withGenerationConfig(generationConfig),
91+
Optional.of(rootChat),
92+
automaticFunctionCallingResponder);
93+
newChatSession.history = history;
94+
newChatSession.previousHistorySize = previousHistorySize;
95+
return newChatSession;
10196
}
10297

10398
/**
@@ -108,12 +103,15 @@ public ChatSession withGenerationConfig(GenerationConfig generationConfig) {
108103
* @return a new {@link ChatSession} instance with the specified SafetySettings.
109104
*/
110105
public ChatSession withSafetySettings(List<SafetySetting> safetySettings) {
111-
return new ChatSession(
112-
model.withSafetySettings(safetySettings),
113-
history,
114-
previousHistorySize,
115-
Optional.of(rootChatSession.orElse(this)),
116-
automaticFunctionCallingResponder);
106+
ChatSession rootChat = rootChatSession.orElse(this);
107+
ChatSession newChatSession =
108+
new ChatSession(
109+
model.withSafetySettings(safetySettings),
110+
Optional.of(rootChat),
111+
automaticFunctionCallingResponder);
112+
newChatSession.history = history;
113+
newChatSession.previousHistorySize = previousHistorySize;
114+
return newChatSession;
117115
}
118116

119117
/**
@@ -124,44 +122,13 @@ public ChatSession withSafetySettings(List<SafetySetting> safetySettings) {
124122
* @return a new {@link ChatSession} instance with the specified Tools.
125123
*/
126124
public ChatSession withTools(List<Tool> tools) {
127-
return new ChatSession(
128-
model.withTools(tools),
129-
history,
130-
previousHistorySize,
131-
Optional.of(rootChatSession.orElse(this)),
132-
automaticFunctionCallingResponder);
133-
}
134-
135-
/**
136-
* Creates a copy of the current ChatSession with updated ToolConfig.
137-
*
138-
* @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used in the
139-
* new ChatSession.
140-
* @return a new {@link ChatSession} instance with the specified ToolConfigs.
141-
*/
142-
public ChatSession withToolConfig(ToolConfig toolConfig) {
143-
return new ChatSession(
144-
model.withToolConfig(toolConfig),
145-
history,
146-
previousHistorySize,
147-
Optional.of(rootChatSession.orElse(this)),
148-
automaticFunctionCallingResponder);
149-
}
150-
151-
/**
152-
* Creates a copy of the current ChatSession with updated SystemInstruction.
153-
*
154-
* @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} containing system
155-
* instructions.
156-
* @return a new {@link ChatSession} instance with the specified ToolConfigs.
157-
*/
158-
public ChatSession withSystemInstruction(Content systemInstruction) {
159-
return new ChatSession(
160-
model.withSystemInstruction(systemInstruction),
161-
history,
162-
previousHistorySize,
163-
Optional.of(rootChatSession.orElse(this)),
164-
automaticFunctionCallingResponder);
125+
ChatSession rootChat = rootChatSession.orElse(this);
126+
ChatSession newChatSession =
127+
new ChatSession(
128+
model.withTools(tools), Optional.of(rootChat), automaticFunctionCallingResponder);
129+
newChatSession.history = history;
130+
newChatSession.previousHistorySize = previousHistorySize;
131+
return newChatSession;
165132
}
166133

167134
/**
@@ -174,12 +141,13 @@ public ChatSession withSystemInstruction(Content systemInstruction) {
174141
*/
175142
public ChatSession withAutomaticFunctionCallingResponder(
176143
AutomaticFunctionCallingResponder automaticFunctionCallingResponder) {
177-
return new ChatSession(
178-
model,
179-
history,
180-
previousHistorySize,
181-
Optional.of(rootChatSession.orElse(this)),
182-
Optional.of(automaticFunctionCallingResponder));
144+
ChatSession rootChat = rootChatSession.orElse(this);
145+
ChatSession newChatSession =
146+
new ChatSession(
147+
model, Optional.of(rootChat), Optional.of(automaticFunctionCallingResponder));
148+
newChatSession.history = history;
149+
newChatSession.previousHistorySize = previousHistorySize;
150+
return newChatSession;
183151
}
184152

185153
/**

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/FunctionDeclarationMaker.java

+93
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,15 @@
1919
import static com.google.common.base.Preconditions.checkNotNull;
2020

2121
import com.google.cloud.vertexai.api.FunctionDeclaration;
22+
import com.google.cloud.vertexai.api.Schema;
23+
import com.google.cloud.vertexai.api.Type;
2224
import com.google.common.base.Strings;
2325
import com.google.gson.JsonObject;
2426
import com.google.protobuf.InvalidProtocolBufferException;
2527
import com.google.protobuf.util.JsonFormat;
28+
import java.lang.reflect.Method;
29+
import java.lang.reflect.Modifier;
30+
import java.lang.reflect.Parameter;
2631

2732
/** Helper class to create {@link com.google.cloud.vertexai.api.FunctionDeclaration} */
2833
public final class FunctionDeclarationMaker {
@@ -60,4 +65,92 @@ public static FunctionDeclaration fromJsonObject(JsonObject jsonObject)
6065
checkNotNull(jsonObject, "JsonObject can't be null.");
6166
return fromJsonString(jsonObject.toString());
6267
}
68+
69+
/**
70+
* Creates a FunctionDeclaration from a Java static method
71+
*
72+
* <p><b>Note:</b>: If you don't want to manually provide parameter names, you can ignore
73+
* `orderedParameterNames` and compile your code with the "-parameters" flag. In this case, the
74+
* parameter names can be auto retrieved from reflection.
75+
*
76+
* @param functionDescription A description of the method.
77+
* @param function A Java static method.
78+
* @param orderedParameterNames A list of parameter names in the order they are passed to the
79+
* method.
80+
* @return a {@link com.google.cloud.vertexai.api.FunctionDeclaration} instance.
81+
* @throws IllegalArgumentException if the method is not a static method or the number of provided
82+
* parameter names doesn't match the number of parameters in the callable function or
83+
* parameter types in this method are not String, boolean, int, double, or float.
84+
* @throws IllegalStateException if the parameter names are not provided and cannot be retrieved
85+
* from reflection
86+
*/
87+
public static FunctionDeclaration fromFunc(
88+
String functionDescription, Method function, String... orderedParameterNames) {
89+
if (!Modifier.isStatic(function.getModifiers())) {
90+
throw new IllegalArgumentException(
91+
"Instance methods are not supported. Please use static methods.");
92+
}
93+
Schema.Builder parametersBuilder = Schema.newBuilder().setType(Type.OBJECT);
94+
95+
Parameter[] parameters = function.getParameters();
96+
// If parameter names are provided, the number of parameter names should match the number of
97+
// parameters in the method.
98+
if (orderedParameterNames.length > 0 && orderedParameterNames.length != parameters.length) {
99+
throw new IllegalArgumentException(
100+
"The number of parameter names does not match the number of parameters in the method.");
101+
}
102+
103+
for (int i = 0; i < parameters.length; i++) {
104+
if (orderedParameterNames.length == 0) {
105+
// If parameter names are not provided, try to retrieve them from reflection.
106+
if (!parameters[i].isNamePresent()) {
107+
throw new IllegalStateException(
108+
"Failed to retrieve the parameter name from reflection. Please compile your"
109+
+ " code with \"-parameters\" flag or use `fromFunc(String, Method,"
110+
+ " String...)` to manually enter parameter names");
111+
}
112+
addParameterToParametersBuilder(
113+
parametersBuilder, parameters[i].getName(), parameters[i].getType());
114+
} else {
115+
addParameterToParametersBuilder(
116+
parametersBuilder, orderedParameterNames[i], parameters[i].getType());
117+
}
118+
}
119+
120+
return FunctionDeclaration.newBuilder()
121+
.setName(function.getName())
122+
.setDescription(functionDescription)
123+
.setParameters(parametersBuilder)
124+
.build();
125+
}
126+
127+
/** Adds a parameter to the parameters builder. */
128+
private static void addParameterToParametersBuilder(
129+
Schema.Builder parametersBuilder, String parameterName, Class<?> parameterType) {
130+
Schema.Builder parameterBuilder = Schema.newBuilder().setDescription(parameterName);
131+
switch (parameterType.getName()) {
132+
case "java.lang.String":
133+
parameterBuilder.setType(Type.STRING);
134+
break;
135+
case "boolean":
136+
parameterBuilder.setType(Type.BOOLEAN);
137+
break;
138+
case "int":
139+
parameterBuilder.setType(Type.INTEGER);
140+
break;
141+
case "double":
142+
case "float":
143+
parameterBuilder.setType(Type.NUMBER);
144+
break;
145+
default:
146+
throw new IllegalArgumentException(
147+
"Unsupported parameter type "
148+
+ parameterType.getName()
149+
+ " for parameter "
150+
+ parameterName);
151+
}
152+
parametersBuilder
153+
.addRequired(parameterName)
154+
.putProperties(parameterName, parameterBuilder.build());
155+
}
63156
}

java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java

+1-19
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import com.google.cloud.vertexai.api.Candidate.FinishReason;
3030
import com.google.cloud.vertexai.api.Content;
3131
import com.google.cloud.vertexai.api.FunctionCall;
32-
import com.google.cloud.vertexai.api.FunctionCallingConfig;
3332
import com.google.cloud.vertexai.api.FunctionDeclaration;
3433
import com.google.cloud.vertexai.api.GenerateContentRequest;
3534
import com.google.cloud.vertexai.api.GenerateContentResponse;
@@ -41,7 +40,6 @@
4140
import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold;
4241
import com.google.cloud.vertexai.api.Schema;
4342
import com.google.cloud.vertexai.api.Tool;
44-
import com.google.cloud.vertexai.api.ToolConfig;
4543
import com.google.cloud.vertexai.api.Type;
4644
import com.google.protobuf.Struct;
4745
import com.google.protobuf.Value;
@@ -176,16 +174,6 @@ public final class ChatSessionTest {
176174
.build())
177175
.addRequired("location")))
178176
.build();
179-
private static final ToolConfig TOOL_CONFIG =
180-
ToolConfig.newBuilder()
181-
.setFunctionCallingConfig(
182-
FunctionCallingConfig.newBuilder()
183-
.setMode(FunctionCallingConfig.Mode.ANY)
184-
.addAllowedFunctionNames("getCurrentWeather"))
185-
.build();
186-
private static final Content SYSTEM_INSTRUCTION =
187-
ContentMaker.fromString(
188-
"You're a helpful assistant that starts all its answers with: \"COOL\"");
189177

190178
@Rule public final MockitoRule mocksRule = MockitoJUnit.rule();
191179

@@ -530,9 +518,7 @@ public void testChatSessionMergeHistoryToRootChatSession() throws Exception {
530518
rootChat
531519
.withGenerationConfig(GENERATION_CONFIG)
532520
.withSafetySettings(Arrays.asList(SAFETY_SETTING))
533-
.withTools(Arrays.asList(TOOL))
534-
.withToolConfig(TOOL_CONFIG)
535-
.withSystemInstruction(SYSTEM_INSTRUCTION);
521+
.withTools(Arrays.asList(TOOL));
536522
response = childChat.sendMessage(SAMPLE_MESSAGE_2);
537523

538524
// (Assert) root chat history should contain all 4 contents
@@ -546,12 +532,8 @@ public void testChatSessionMergeHistoryToRootChatSession() throws Exception {
546532
ArgumentCaptor<GenerateContentRequest> request =
547533
ArgumentCaptor.forClass(GenerateContentRequest.class);
548534
verify(mockUnaryCallable, times(2)).call(request.capture());
549-
Content expectedSystemInstruction = SYSTEM_INSTRUCTION.toBuilder().clearRole().build();
550535
assertThat(request.getAllValues().get(1).getGenerationConfig()).isEqualTo(GENERATION_CONFIG);
551536
assertThat(request.getAllValues().get(1).getSafetySettings(0)).isEqualTo(SAFETY_SETTING);
552537
assertThat(request.getAllValues().get(1).getTools(0)).isEqualTo(TOOL);
553-
assertThat(request.getAllValues().get(1).getToolConfig()).isEqualTo(TOOL_CONFIG);
554-
assertThat(request.getAllValues().get(1).getSystemInstruction())
555-
.isEqualTo(expectedSystemInstruction);
556538
}
557539
}

0 commit comments

Comments
 (0)