
Commit c3ae475

Ark-kun authored and copybara-github committed
feat: LLM - Added support for multiple text generation response candidates
PiperOrigin-RevId: 570261595
1 parent 603cdbe commit c3ae475
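
For orientation, a minimal usage sketch of the API this commit adds. The prompt and model name are illustrative (taken from the unit test below); actual output depends on the service.

from vertexai.language_models import TextGenerationModel

model = TextGenerationModel.from_pretrained("text-bison@001")
response = model.predict(
    "What is the best recipe for banana bread? Recipe:",
    candidate_count=2,
)
# `response.text` still holds the first candidate, preserving the old behavior;
# all candidates are available on the new `candidates` attribute.
for candidate in response.candidates:
    print(candidate.text)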

File tree: 2 files changed (+109, −6 lines)

Diff for: tests/unit/aiplatform/test_language_models.py

+37 lines

@@ -1326,6 +1326,43 @@ def test_text_generation_ga(self):
         assert "topP" not in prediction_parameters
         assert "topK" not in prediction_parameters
 
+    def test_text_generation_multiple_candidates(self):
+        """Tests the text generation model with multiple candidates."""
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _TEXT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = language_models.TextGenerationModel.from_pretrained(
+                "text-bison@001"
+            )
+
+        gca_predict_response = gca_prediction_service.PredictResponse()
+        # Discrepancy between the number of `instances` and the number of `predictions`
+        # is a violation of the prediction service invariant, but the service does this.
+        gca_predict_response.predictions.append(_TEST_TEXT_GENERATION_PREDICTION)
+        gca_predict_response.predictions.append(_TEST_TEXT_GENERATION_PREDICTION)
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response,
+        ) as mock_predict:
+            response = model.predict(
+                "What is the best recipe for banana bread? Recipe:",
+                candidate_count=2,
+            )
+        prediction_parameters = mock_predict.call_args[1]["parameters"]
+        assert prediction_parameters["candidateCount"] == 2
+
+        assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
+        assert len(response.candidates) == 2
+        assert (
+            response.candidates[0].text == _TEST_TEXT_GENERATION_PREDICTION["content"]
+        )
+
     @pytest.mark.asyncio
     async def test_text_generation_async(self):
         """Tests the text generation model."""

Diff for: vertexai/language_models/_language_models.py

+72 −6 lines

@@ -691,6 +691,35 @@ def raw_prediction_response(self) -> aiplatform.models.Prediction:
         return self._prediction_response
 
 
+@dataclasses.dataclass
+class MultiCandidateTextGenerationResponse(TextGenerationResponse):
+    """Represents a multi-candidate response of a language model.
+
+    Attributes:
+        text: The generated text for the first candidate.
+        is_blocked: Whether the first candidate response was blocked.
+        safety_attributes: Scores for safety attributes for the first candidate.
+            Learn more about the safety attributes here:
+            https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions
+        candidates: The candidate responses.
+            Usually contains a single candidate unless `candidate_count` is used.
+    """
+
+    __module__ = "vertexai.language_models"
+
+    candidates: List[TextGenerationResponse] = dataclasses.field(default_factory=list)
+
+    def _repr_pretty_(self, p, cycle):
+        """Pretty prints self in IPython environments."""
+        if cycle:
+            p.text(f"{self.__class__.__name__}(...)")
+        else:
+            if len(self.candidates) == 1:
+                p.text(repr(self.candidates[0]))
+            else:
+                p.text(repr(self))
+
+
 class _TextGenerationModel(_LanguageModel):
     """TextGenerationModel represents a general language model.
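
A short sketch of the backward-compatible behavior this subclassing implies, assuming `predict` is called without `candidate_count` as it was before this change (whether the class is re-exported from the `vertexai.language_models` namespace is not shown in this diff):

response = model.predict("What is the best recipe for banana bread? Recipe:")
# Still a TextGenerationResponse, so code that only reads `response.text` keeps working.
print(response.text)                  # first candidate, same as before this change
print(len(response.candidates))       # usually 1 when candidate_count is not set
print(type(response).__name__)        # MultiCandidateTextGenerationResponse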
@@ -716,7 +745,8 @@ def predict(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
         """Gets model response for a single prompt.
 
         Args:
@@ -726,9 +756,10 @@
             top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
             stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of response candidates to return.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the text produced by the model.
         """
         prediction_request = _create_text_generation_prediction_request(
             prompt=prompt,
@@ -737,14 +768,15 @@
             top_k=top_k,
             top_p=top_p,
             stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
         )
 
         prediction_response = self._endpoint.predict(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )
 
-        return _parse_text_generation_model_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(prediction_response)
 
     async def predict_async(
         self,
@@ -755,7 +787,8 @@ async def predict_async(
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
         """Asynchronously gets model response for a single prompt.
 
         Args:
@@ -765,9 +798,10 @@
             top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
             stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of response candidates to return.
 
         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the text produced by the model.
         """
         prediction_request = _create_text_generation_prediction_request(
             prompt=prompt,
@@ -776,14 +810,15 @@
             top_k=top_k,
             top_p=top_p,
             stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
        )
 
         prediction_response = await self._endpoint.predict_async(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )
 
-        return _parse_text_generation_model_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(prediction_response)
 
     def predict_streaming(
         self,
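
The async path mirrors the sync one. A hedged usage sketch, assuming `model` was loaded as in the example near the top and an event loop can be started here:

import asyncio

async def main() -> None:
    response = await model.predict_async(
        "What is the best recipe for banana bread? Recipe:",
        candidate_count=3,
    )
    for i, candidate in enumerate(response.candidates):
        print(i, candidate.text)

asyncio.run(main())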
@@ -844,6 +879,7 @@ def _create_text_generation_prediction_request(
     top_k: Optional[int] = None,
     top_p: Optional[float] = None,
     stop_sequences: Optional[List[str]] = None,
+    candidate_count: Optional[int] = None,
 ) -> "_PredictionRequest":
     """Prepares the text generation request for a single prompt.
 
@@ -854,6 +890,7 @@
         top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
         top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
         stop_sequences: Customized stop sequences to stop the decoding process.
+        candidate_count: Number of candidates to return.
 
     Returns:
         A `_PredictionRequest` object that contains prediction instance and parameters.
@@ -880,6 +917,9 @@
     if stop_sequences:
         prediction_parameters["stopSequences"] = stop_sequences
 
+    if candidate_count is not None:
+        prediction_parameters["candidateCount"] = candidate_count
+
     return _PredictionRequest(
         instance=instance,
         parameters=prediction_parameters,
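
A minimal sketch of what the new branch contributes to the request, using only names that appear in this diff:

request = _create_text_generation_prediction_request(
    prompt="What is the best recipe for banana bread? Recipe:",
    candidate_count=2,
)
assert request.parameters["candidateCount"] == 2
# When candidate_count is None (the default), the key is omitted entirely,
# in the same spirit as the `"topP" not in ...` / `"topK" not in ...` assertions
# in the existing unit test.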
@@ -906,6 +946,32 @@ def _parse_text_generation_model_response(
     )
 
 
+def _parse_text_generation_model_multi_candidate_response(
+    prediction_response: aiplatform.models.Prediction,
+) -> MultiCandidateTextGenerationResponse:
+    """Converts the raw text_generation model response to `MultiCandidateTextGenerationResponse`."""
+    # The contract for the PredictionService is that there is a 1:1 mapping
+    # between request `instances` and response `predictions`.
+    # Unfortunately, the text-bison models violate this contract.
+
+    prediction_count = len(prediction_response.predictions)
+    candidates = []
+    for prediction_idx in range(prediction_count):
+        candidate = _parse_text_generation_model_response(
+            prediction_response=prediction_response,
+            prediction_idx=prediction_idx,
+        )
+        candidates.append(candidate)
+
+    return MultiCandidateTextGenerationResponse(
+        text=candidates[0].text,
+        _prediction_response=prediction_response,
+        is_blocked=candidates[0].is_blocked,
+        safety_attributes=candidates[0].safety_attributes,
+        candidates=candidates,
+    )
+
+
 class _ModelWithBatchPredict(_LanguageModel):
     """Model that supports batch prediction."""
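
Restating the parser's behavior against the mocked response from the unit test above (two identical predictions for a single instance); a sketch, not additional library code:

response = _parse_text_generation_model_multi_candidate_response(prediction_response)
assert len(response.candidates) == len(prediction_response.predictions)  # one per prediction
# The top-level fields mirror the first candidate:
assert response.text == response.candidates[0].text
assert response.is_blocked == response.candidates[0].is_blocked
assert response.safety_attributes == response.candidates[0].safety_attributes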
