Skip to content

Commit cbf9b6e

Browse files
Ark-kun authored and copybara-github committed
feat: LLM - TextEmbeddingModel - Added support for structural inputs (TextEmbeddingInput), auto_truncate parameter and result statistics
PiperOrigin-RevId: 558465128
1 parent 76b95b9 commit cbf9b6e

File tree

5 files changed

+139
-24
lines changed

5 files changed

+139
-24
lines changed

Diff for: tests/system/aiplatform/test_language_models.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,17 @@ def test_text_embedding(self):
143143
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
144144

145145
model = TextEmbeddingModel.from_pretrained("google/textembedding-gecko@001")
146-
embeddings = model.get_embeddings(["What is life?"])
147-
assert embeddings
148-
for embedding in embeddings:
149-
vector = embedding.values
150-
assert len(vector) == 768
146+
# One short text, one long text (to check truncation)
147+
texts = ["What is life?", "What is life?" * 1000]
148+
embeddings = model.get_embeddings(texts)
149+
assert len(embeddings) == 2
150+
assert len(embeddings[0].values) == 768
151+
assert embeddings[0].statistics.token_count > 0
152+
assert not embeddings[0].statistics.truncated
153+
154+
assert len(embeddings[1].values) == 768
155+
assert embeddings[1].statistics.token_count > 1000
156+
assert embeddings[1].statistics.truncated
151157

152158
def test_tuning(self, shared_state):
153159
"""Test tuning, listing and loading models."""

Diff for: tests/unit/aiplatform/test_language_models.py

+43-3
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ def reverse_string_2(s):""",
298298
_TEST_TEXT_EMBEDDING_PREDICTION = {
299299
"embeddings": {
300300
"values": list([1.0] * _TEXT_EMBEDDING_VECTOR_LENGTH),
301+
"statistics": {"truncated": False, "token_count": 4.0},
301302
}
302303
}
303304

@@ -2170,18 +2171,57 @@ def test_text_embedding(self):
21702171

21712172
gca_predict_response = gca_prediction_service.PredictResponse()
21722173
gca_predict_response.predictions.append(_TEST_TEXT_EMBEDDING_PREDICTION)
2174+
gca_predict_response.predictions.append(_TEST_TEXT_EMBEDDING_PREDICTION)
21732175

2176+
expected_embedding = _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]
21742177
with mock.patch.object(
21752178
target=prediction_service_client.PredictionServiceClient,
21762179
attribute="predict",
21772180
return_value=gca_predict_response,
2178-
):
2179-
embeddings = model.get_embeddings(["What is life?"])
2181+
) as mock_predict:
2182+
embeddings = model.get_embeddings(
2183+
[
2184+
"What is life?",
2185+
language_models.TextEmbeddingInput(
2186+
text="Foo",
2187+
task_type="RETRIEVAL_DOCUMENT",
2188+
title="Bar",
2189+
),
2190+
language_models.TextEmbeddingInput(
2191+
text="Baz",
2192+
task_type="CLASSIFICATION",
2193+
),
2194+
],
2195+
auto_truncate=False,
2196+
)
2197+
prediction_instances = mock_predict.call_args[1]["instances"]
2198+
assert prediction_instances == [
2199+
{"content": "What is life?"},
2200+
{
2201+
"content": "Foo",
2202+
"taskType": "RETRIEVAL_DOCUMENT",
2203+
"title": "Bar",
2204+
},
2205+
{
2206+
"content": "Baz",
2207+
"taskType": "CLASSIFICATION",
2208+
},
2209+
]
2210+
prediction_parameters = mock_predict.call_args[1]["parameters"]
2211+
assert not prediction_parameters["autoTruncate"]
21802212
assert embeddings
21812213
for embedding in embeddings:
21822214
vector = embedding.values
21832215
assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH
2184-
assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"]
2216+
assert vector == expected_embedding["values"]
2217+
assert (
2218+
embedding.statistics.token_count
2219+
== expected_embedding["statistics"]["token_count"]
2220+
)
2221+
assert (
2222+
embedding.statistics.truncated
2223+
== expected_embedding["statistics"]["truncated"]
2224+
)
21852225

21862226
def test_text_embedding_ga(self):
21872227
"""Tests the text embedding model."""

Diff for: vertexai/language_models/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
CodeGenerationModel,
2424
InputOutputTextPair,
2525
TextEmbedding,
26+
TextEmbeddingInput,
2627
TextEmbeddingModel,
2728
TextGenerationModel,
2829
TextGenerationResponse,
@@ -37,6 +38,7 @@
3738
"CodeGenerationModel",
3839
"InputOutputTextPair",
3940
"TextEmbedding",
41+
"TextEmbeddingInput",
4042
"TextEmbeddingModel",
4143
"TextGenerationModel",
4244
"TextGenerationResponse",

Diff for: vertexai/language_models/_language_models.py

+81-16
Original file line numberDiff line numberDiff line change
@@ -692,8 +692,33 @@ def send_message(
692692
return response_obj
693693

694694

695+
@dataclasses.dataclass
class TextEmbeddingInput:
    """A single structured input for text embedding.

    Wraps the text to embed together with optional metadata that the
    embedding service uses to tailor the resulting vector.

    Attributes:
        text: The main text content to embed.
        task_type: The name of the downstream task the embeddings will be used for.
            Valid values:
            RETRIEVAL_QUERY
                Specifies the given text is a query in a search/retrieval setting.
            RETRIEVAL_DOCUMENT
                Specifies the given text is a document from the corpus being searched.
            SEMANTIC_SIMILARITY
                Specifies the given text will be used for STS.
            CLASSIFICATION
                Specifies that the given text will be classified.
            CLUSTERING
                Specifies that the embeddings will be used for clustering.
        title: Optional identifier of the text content.
    """

    # Required content; task_type and title are optional hints for the service.
    text: str
    task_type: Optional[str] = None
    title: Optional[str] = None
719+
695720
class TextEmbeddingModel(_LanguageModel):
696-
"""TextEmbeddingModel converts text into a vector of floating-point numbers.
721+
"""TextEmbeddingModel class calculates embeddings for the given texts.
697722
698723
Examples::
699724
@@ -711,36 +736,76 @@ class TextEmbeddingModel(_LanguageModel):
711736
"gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml"
712737
)
713738

714-
def get_embeddings(self, texts: List[str]) -> List["TextEmbedding"]:
715-
instances = [{"content": str(text)} for text in texts]
739+
def get_embeddings(
    self,
    texts: List[Union[str, TextEmbeddingInput]],
    *,
    auto_truncate: bool = True,
) -> List["TextEmbedding"]:
    """Calculates embeddings for the given texts.

    Args:
        texts (List[Union[str, TextEmbeddingInput]]): A list of texts or
            `TextEmbeddingInput` objects to embed.
        auto_truncate (bool): Whether to automatically truncate long texts.
            Default: True.

    Returns:
        A list of `TextEmbedding` objects, one per input, in input order.

    Raises:
        TypeError: If an element of `texts` is neither a `str` nor a
            `TextEmbeddingInput`.
    """
    # Build one prediction instance per input. Structured inputs carry
    # optional task-type/title hints; plain strings become bare content.
    instances = []
    for text in texts:
        if isinstance(text, TextEmbeddingInput):
            instance = {"content": text.text}
            if text.task_type:
                instance["taskType"] = text.task_type
            if text.title:
                instance["title"] = text.title
        elif isinstance(text, str):
            instance = {"content": text}
        else:
            raise TypeError(f"Unsupported text embedding input type: {text}.")
        instances.append(instance)
    parameters = {"autoTruncate": auto_truncate}

    prediction_response = self._endpoint.predict(
        instances=instances,
        parameters=parameters,
    )

    # Each prediction carries the embedding vector plus per-input statistics
    # (token count and whether the service truncated the input).
    results = []
    for prediction in prediction_response.predictions:
        embeddings = prediction["embeddings"]
        statistics = embeddings["statistics"]
        result = TextEmbedding(
            values=embeddings["values"],
            statistics=TextEmbeddingStatistics(
                token_count=statistics["token_count"],
                truncated=statistics["truncated"],
            ),
            _prediction_response=prediction_response,
        )
        results.append(result)

    return results
728788

729789

730790
class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict):
731791
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
732792

733793

794+
@dataclasses.dataclass
class TextEmbeddingStatistics:
    """Per-input statistics reported alongside a text embedding."""

    # Number of tokens the service counted for the input text.
    token_count: int
    # True when the input was cut to fit the model's length limit.
    truncated: bool
800+
801+
802+
@dataclasses.dataclass
class TextEmbedding:
    """Text embedding vector and statistics.

    Attributes:
        values: The embedding vector.
        statistics: Optional per-input statistics (token count, truncation flag).
            Defaults to None so existing callers that construct
            `TextEmbedding(values=...)` keep working.
        _prediction_response: The raw prediction response this embedding was
            extracted from (internal).
    """

    values: List[float]
    # NOTE(review): `statistics` was a new required field; defaulting it to None
    # preserves backward compatibility with the previous __init__ signature,
    # where only `values` was required. String annotations avoid import-time
    # evaluation of names defined elsewhere in the module.
    statistics: Optional["TextEmbeddingStatistics"] = None
    _prediction_response: Optional["aiplatform.models.Prediction"] = None
744809

745810

746811
@dataclasses.dataclass

Diff for: vertexai/preview/language_models.py

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
CodeChatSession,
2727
InputOutputTextPair,
2828
TextEmbedding,
29+
TextEmbeddingInput,
2930
TextGenerationResponse,
3031
)
3132

@@ -60,6 +61,7 @@
6061
"EvaluationTextClassificationSpec",
6162
"InputOutputTextPair",
6263
"TextEmbedding",
64+
"TextEmbeddingInput",
6365
"TextEmbeddingModel",
6466
"TextGenerationModel",
6567
"TextGenerationResponse",

0 commit comments

Comments (0)