@@ -691,6 +691,35 @@ def raw_prediction_response(self) -> aiplatform.models.Prediction:
        return self._prediction_response


+@dataclasses.dataclass
+class MultiCandidateTextGenerationResponse(TextGenerationResponse):
+    """Represents a multi-candidate response of a language model.
+
+    Attributes:
+        text: The generated text for the first candidate.
+        is_blocked: Whether the first candidate response was blocked.
+        safety_attributes: Scores for safety attributes for the first candidate.
+            Learn more about the safety attributes here:
+            https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions
+        candidates: The candidate responses.
+            Usually contains a single candidate unless `candidate_count` is used.
+    """
+
+    __module__ = "vertexai.language_models"
+
+    candidates: List[TextGenerationResponse] = dataclasses.field(default_factory=list)
+
+    def _repr_pretty_(self, p, cycle):
+        """Pretty prints self in IPython environments."""
+        if cycle:
+            p.text(f"{self.__class__.__name__}(...)")
+        else:
+            if len(self.candidates) == 1:
+                p.text(repr(self.candidates[0]))
+            else:
+                p.text(repr(self))
+
+
class _TextGenerationModel(_LanguageModel):
    """TextGenerationModel represents a general language model.
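A minimal usage sketch of the new surface (illustrative, not part of the diff; the model name and prompt are made up) showing how the `candidate_count` parameter and the `candidates` attribute introduced above are meant to work together:

    from vertexai.language_models import TextGenerationModel

    model = TextGenerationModel.from_pretrained("text-bison@001")
    response = model.predict(
        "Suggest a tagline for a coffee shop.",
        candidate_count=3,  # new parameter from this change
    )

    # Backward compatible: `text`, `is_blocked`, and `safety_attributes`
    # still describe the first candidate.
    print(response.text)

    # New: iterate over all returned candidates.
    for candidate in response.candidates:
        print(candidate.text, candidate.is_blocked)

Because `MultiCandidateTextGenerationResponse` subclasses `TextGenerationResponse`, existing callers that only read `response.text` keep working unchanged.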
@@ -716,7 +745,8 @@ def predict(
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        stop_sequences: Optional[List[str]] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
        """Gets model response for a single prompt.

        Args:
@@ -726,9 +756,10 @@ def predict(
            top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
            top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
            stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of response candidates to return.

        Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the text produced by the model.
        """
        prediction_request = _create_text_generation_prediction_request(
            prompt=prompt,
@@ -737,14 +768,15 @@ def predict(
            top_k=top_k,
            top_p=top_p,
            stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
        )

        prediction_response = self._endpoint.predict(
            instances=[prediction_request.instance],
            parameters=prediction_request.parameters,
        )

-        return _parse_text_generation_model_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(prediction_response)
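The async path below mirrors this exactly. A hedged sketch of the call, assuming the same illustrative model object as above and a standard event loop:

    import asyncio

    async def main():
        response = await model.predict_async(
            "Suggest a tagline for a coffee shop.",
            candidate_count=2,
        )
        for candidate in response.candidates:
            print(candidate.text)

    asyncio.run(main())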

    async def predict_async(
        self,
@@ -755,7 +787,8 @@ async def predict_async(
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        stop_sequences: Optional[List[str]] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
        """Asynchronously gets model response for a single prompt.

        Args:
@@ -765,9 +798,10 @@ async def predict_async(
            top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
            top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
            stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of response candidates to return.

        Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the text produced by the model.
        """
        prediction_request = _create_text_generation_prediction_request(
            prompt=prompt,
@@ -776,14 +810,15 @@ async def predict_async(
            top_k=top_k,
            top_p=top_p,
            stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
        )

        prediction_response = await self._endpoint.predict_async(
            instances=[prediction_request.instance],
            parameters=prediction_request.parameters,
        )

-        return _parse_text_generation_model_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(prediction_response)

    def predict_streaming(
        self,
@@ -844,6 +879,7 @@ def _create_text_generation_prediction_request(
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    stop_sequences: Optional[List[str]] = None,
+    candidate_count: Optional[int] = None,
) -> "_PredictionRequest":
    """Prepares the text generation request for a single prompt.

@@ -854,6 +890,7 @@ def _create_text_generation_prediction_request(
        top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
        top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
        stop_sequences: Customized stop sequences to stop the decoding process.
+        candidate_count: Number of candidates to return.

    Returns:
        A `_PredictionRequest` object that contains prediction instance and parameters.
@@ -880,6 +917,9 @@ def _create_text_generation_prediction_request(
    if stop_sequences:
        prediction_parameters["stopSequences"] = stop_sequences

+    if candidate_count is not None:
+        prediction_parameters["candidateCount"] = candidate_count
+
    return _PredictionRequest(
        instance=instance,
        parameters=prediction_parameters,
@@ -906,6 +946,32 @@ def _parse_text_generation_model_response(
    )


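To make the request-side change above concrete: `_create_text_generation_prediction_request` is a private helper, but its output can be checked directly. A small illustrative sketch (prompt and count are made up):

    request = _create_text_generation_prediction_request(
        prompt="Hello",
        candidate_count=2,
    )
    # "candidateCount" is set only when candidate_count is not None;
    # with the default of None the key is omitted entirely, preserving
    # the previous wire format.
    assert request.parameters["candidateCount"] == 2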
+def _parse_text_generation_model_multi_candidate_response(
+    prediction_response: aiplatform.models.Prediction,
+) -> MultiCandidateTextGenerationResponse:
+    """Converts the raw text_generation model response to `MultiCandidateTextGenerationResponse`."""
+    # The contract for the PredictionService is that there is a 1:1 mapping
+    # between request `instances` and response `predictions`.
+    # Unfortunately, the text-bison models violate this contract.
+
+    prediction_count = len(prediction_response.predictions)
+    candidates = []
+    for prediction_idx in range(prediction_count):
+        candidate = _parse_text_generation_model_response(
+            prediction_response=prediction_response,
+            prediction_idx=prediction_idx,
+        )
+        candidates.append(candidate)
+
+    return MultiCandidateTextGenerationResponse(
+        text=candidates[0].text,
+        _prediction_response=prediction_response,
+        is_blocked=candidates[0].is_blocked,
+        safety_attributes=candidates[0].safety_attributes,
+        candidates=candidates,
+    )
+
+
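A hedged sketch of what the new parser does with a multi-prediction response. `aiplatform.models.Prediction` is normally constructed by the endpoint, so a simple stand-in is used here; the payload keys ("content", "safetyAttributes") follow the text-bison response format but are illustrative assumptions:

    from types import SimpleNamespace

    # Stand-in for the raw Prediction returned by the PredictionService,
    # carrying two predictions for a single request instance.
    fake_response = SimpleNamespace(
        predictions=[
            {"content": "Candidate A", "safetyAttributes": {}},
            {"content": "Candidate B", "safetyAttributes": {}},
        ],
    )

    parsed = _parse_text_generation_model_multi_candidate_response(fake_response)
    assert len(parsed.candidates) == 2
    assert parsed.text == "Candidate A"  # the first candidate stays the "primary" one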
class _ModelWithBatchPredict(_LanguageModel):
    """Model that supports batch prediction."""