Skip to content

Commit 1f81cf2

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: LLM - include error code into blocked response from TextGenerationModel, ChatModel, CodeChatModel, and CodeGenerationModel.
PiperOrigin-RevId: 582832899
1 parent 469c595 commit 1f81cf2

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

Diff for: tests/unit/aiplatform/test_language_models.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,8 @@
221221
_TEST_TEXT_GENERATION_PREDICTION = {
222222
"safetyAttributes": {
223223
"categories": ["Violent"],
224-
"blocked": False,
224+
"blocked": True,
225+
"errors": [100],
225226
"scores": [0.10000000149011612],
226227
},
227228
"content": """
@@ -254,6 +255,7 @@
254255
},
255256
"safetyAttributes": {
256257
"blocked": True,
258+
"errors": [100],
257259
"categories": ["Finance"],
258260
"scores": [0.1],
259261
},
@@ -301,6 +303,7 @@
301303
"scores": [0.1],
302304
"categories": ["Finance"],
303305
"blocked": True,
306+
"errors": [100],
304307
},
305308
],
306309
"candidates": [
@@ -326,6 +329,7 @@
326329
"scores": [0.1],
327330
"categories": ["Finance"],
328331
"blocked": True,
332+
"errors": [100],
329333
},
330334
],
331335
"groundingMetadata": [
@@ -373,6 +377,7 @@
373377
"scores": [0.1],
374378
"categories": ["Finance"],
375379
"blocked": True,
380+
"errors": [100],
376381
},
377382
],
378383
"groundingMetadata": [
@@ -430,6 +435,7 @@
430435
"safetyAttributes": [
431436
{
432437
"blocked": True,
438+
"errors": [100],
433439
"categories": ["Finance"],
434440
"scores": [0.1],
435441
}
@@ -440,6 +446,7 @@
440446
_TEST_CODE_GENERATION_PREDICTION = {
441447
"safetyAttributes": {
442448
"blocked": True,
449+
"errors": [100],
443450
"categories": ["Finance"],
444451
"scores": [0.1],
445452
},
@@ -1478,13 +1485,15 @@ def test_text_generation_ga(self):
14781485
stop_sequences=["\n"],
14791486
)
14801487

1488+
expected_errors = (100,)
14811489
prediction_parameters = mock_predict.call_args[1]["parameters"]
14821490
assert prediction_parameters["maxDecodeSteps"] == 128
14831491
assert prediction_parameters["temperature"] == 0.0
14841492
assert prediction_parameters["topP"] == 1.0
14851493
assert prediction_parameters["topK"] == 5
14861494
assert prediction_parameters["stopSequences"] == ["\n"]
14871495
assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
1496+
assert response.errors == expected_errors
14881497

14891498
# Validating that unspecified parameters are not passed to the model
14901499
# (except `max_output_tokens`).
@@ -2893,12 +2902,16 @@ def test_chat_model_send_message_with_multiple_candidates(self):
28932902
)
28942903
expected_candidate_0 = expected_response_candidates[0]["content"]
28952904
expected_candidate_1 = expected_response_candidates[1]["content"]
2905+
expected_errors_0 = ()
2906+
expected_errors_1 = (100,)
28962907

28972908
response = chat.send_message(message_text1, candidate_count=2)
28982909
assert response.text == expected_candidate_0
28992910
assert len(response.candidates) == 2
29002911
assert response.candidates[0].text == expected_candidate_0
29012912
assert response.candidates[1].text == expected_candidate_1
2913+
assert response.candidates[0].errors == expected_errors_0
2914+
assert response.candidates[1].errors == expected_errors_1
29022915

29032916
assert len(chat.message_history) == 2
29042917
assert chat.message_history[0].author == chat.USER_AUTHOR

Diff for: vertexai/language_models/_language_models.py

+15
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Literal,
2626
Optional,
2727
Sequence,
28+
Tuple,
2829
Union,
2930
)
3031
import warnings
@@ -859,6 +860,9 @@ class TextGenerationResponse:
859860
Attributes:
860861
text: The generated text
861862
is_blocked: Whether the the request was blocked.
863+
errors: The error codes indicate why the response was blocked.
864+
Learn more information about safety errors here:
865+
this documentation https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_errors
862866
safety_attributes: Scores for safety attributes.
863867
Learn more about the safety attributes here:
864868
https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions
@@ -870,6 +874,7 @@ class TextGenerationResponse:
870874
text: str
871875
_prediction_response: Any
872876
is_blocked: bool = False
877+
errors: Tuple[int] = tuple()
873878
safety_attributes: Dict[str, float] = dataclasses.field(default_factory=dict)
874879
grounding_metadata: Optional[GroundingMetadata] = None
875880

@@ -882,6 +887,7 @@ def __repr__(self):
882887
"TextGenerationResponse("
883888
f"text={self.text!r}"
884889
f", is_blocked={self.is_blocked!r}"
890+
f", errors={self.errors!r}"
885891
f", safety_attributes={self.safety_attributes!r}"
886892
f", grounding_metadata={self.grounding_metadata!r}"
887893
")"
@@ -891,6 +897,7 @@ def __repr__(self):
891897
"TextGenerationResponse("
892898
f"text={self.text!r}"
893899
f", is_blocked={self.is_blocked!r}"
900+
f", errors={self.errors!r}"
894901
f", safety_attributes={self.safety_attributes!r}"
895902
")"
896903
)
@@ -1216,10 +1223,13 @@ def _parse_text_generation_model_response(
12161223
prediction = prediction_response.predictions[prediction_idx]
12171224
safety_attributes_dict = prediction.get("safetyAttributes", {})
12181225
grounding_metadata_dict = prediction.get("groundingMetadata", {})
1226+
errors_list = safety_attributes_dict.get("errors", [])
1227+
errors = tuple(map(int, errors_list))
12191228
return TextGenerationResponse(
12201229
text=prediction["content"],
12211230
_prediction_response=prediction_response,
12221231
is_blocked=safety_attributes_dict.get("blocked", False),
1232+
errors=errors,
12231233
safety_attributes=dict(
12241234
zip(
12251235
safety_attributes_dict.get("categories") or [],
@@ -1251,6 +1261,7 @@ def _parse_text_generation_model_multi_candidate_response(
12511261
text=candidates[0].text,
12521262
_prediction_response=prediction_response,
12531263
is_blocked=candidates[0].is_blocked,
1264+
errors=candidates[0].errors,
12541265
safety_attributes=candidates[0].safety_attributes,
12551266
grounding_metadata=candidates[0].grounding_metadata,
12561267
candidates=candidates,
@@ -2090,13 +2101,16 @@ def _parse_chat_prediction_response(
20902101
grounding_metadata_list = prediction.get("groundingMetadata")
20912102
for candidate_idx in range(candidate_count):
20922103
safety_attributes = prediction["safetyAttributes"][candidate_idx]
2104+
errors_list = safety_attributes.get("errors", [])
2105+
errors = tuple(map(int, errors_list))
20932106
grounding_metadata_dict = {}
20942107
if grounding_metadata_list and grounding_metadata_list[candidate_idx]:
20952108
grounding_metadata_dict = grounding_metadata_list[candidate_idx]
20962109
candidate_response = TextGenerationResponse(
20972110
text=prediction["candidates"][candidate_idx]["content"],
20982111
_prediction_response=prediction_response,
20992112
is_blocked=safety_attributes.get("blocked", False),
2113+
errors=errors,
21002114
safety_attributes=dict(
21012115
zip(
21022116
# Unlike with normal prediction, in streaming prediction
@@ -2112,6 +2126,7 @@ def _parse_chat_prediction_response(
21122126
text=candidates[0].text,
21132127
_prediction_response=prediction_response,
21142128
is_blocked=candidates[0].is_blocked,
2129+
errors=candidates[0].errors,
21152130
safety_attributes=candidates[0].safety_attributes,
21162131
grounding_metadata=candidates[0].grounding_metadata,
21172132
candidates=candidates,

0 commit comments

Comments
 (0)