Skip to content

Commit 83cb52d

Browse files
lingyinwcopybara-github
authored andcommitted
fix: Add restricts and crowding tag to MatchingEngineIndexEndpoint query response.
PiperOrigin-RevId: 607012218
1 parent 14b41b5 commit 83cb52d

File tree

2 files changed

+281
-42
lines changed

2 files changed

+281
-42
lines changed

Diff for: google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+138-32
Original file line numberDiff line numberDiff line change
@@ -42,24 +42,6 @@
4242
_LOGGER = base.Logger(__name__)
4343

4444

45-
@dataclass
46-
class MatchNeighbor:
47-
"""The id and distance of a nearest neighbor match for a given query embedding.
48-
49-
Args:
50-
id (str):
51-
Required. The id of the neighbor.
52-
distance (float):
53-
Required. The distance to the query embedding.
54-
feature_vector (List(float)):
55-
Optional. The feature vector of the matching datapoint.
56-
"""
57-
58-
id: str
59-
distance: float
60-
feature_vector: Optional[List[float]] = None
61-
62-
6345
@dataclass
6446
class Namespace:
6547
"""Namespace specifies the rules for determining the datapoints that are eligible for each matching query, overall query is an AND across namespaces.
@@ -156,14 +138,136 @@ def __post_init__(self):
156138
)
157139
# Check operator validity
158140
if (
159-
self.op
141+
self.op is not None
142+
and self.op
160143
not in gca_index_v1beta1.IndexDatapoint.NumericRestriction.Operator._member_names_
161144
):
162145
raise ValueError(
163146
f"Invalid operator '{self.op}'," " must be one of the valid operators."
164147
)
165148

166149

150+
@dataclass
151+
class MatchNeighbor:
152+
"""The id and distance of a nearest neighbor match for a given query embedding.
153+
154+
Args:
155+
id (str):
156+
Required. The id of the neighbor.
157+
distance (float):
158+
Required. The distance to the query embedding.
159+
feature_vector (List(float)):
160+
Optional. The feature vector of the matching datapoint.
161+
crowding_tag (Optional[str]):
162+
Optional. Crowding tag of the datapoint, the
163+
number of neighbors to return in each crowding,
164+
can be configured during query.
165+
restricts (List[Namespace]):
166+
Optional. The restricts of the matching datapoint.
167+
numeric_restricts:
168+
Optional. The numeric restricts of the matching datapoint.
169+
170+
"""
171+
172+
id: str
173+
distance: float
174+
feature_vector: Optional[List[float]] = None
175+
crowding_tag: Optional[str] = None
176+
restricts: Optional[List[Namespace]] = None
177+
numeric_restricts: Optional[List[NumericNamespace]] = None
178+
179+
def from_index_datapoint(
180+
self, index_datapoint: gca_index_v1beta1.IndexDatapoint
181+
) -> "MatchNeighbor":
182+
"""Copies MatchNeighbor fields from an IndexDatapoint.
183+
184+
Args:
185+
index_datapoint (gca_index_v1beta1.IndexDatapoint):
186+
Required. An index datapoint.
187+
188+
Returns:
189+
MatchNeighbor
190+
"""
191+
if not index_datapoint:
192+
return self
193+
self.feature_vector = index_datapoint.feature_vector
194+
if (
195+
index_datapoint.crowding_tag is not None
196+
and index_datapoint.crowding_tag.crowding_attribute is not None
197+
):
198+
self.crowding_tag = index_datapoint.crowding_tag.crowding_attribute
199+
self.restricts = [
200+
Namespace(
201+
name=restrict.namespace,
202+
allow_tokens=restrict.allow_list,
203+
deny_tokens=restrict.deny_list,
204+
)
205+
for restrict in index_datapoint.restricts
206+
]
207+
if index_datapoint.numeric_restricts is not None:
208+
self.numeric_restricts = []
209+
for restrict in index_datapoint.numeric_restricts:
210+
numeric_namespace = None
211+
restrict_value_type = restrict._pb.WhichOneof("Value")
212+
if restrict_value_type == "value_int":
213+
numeric_namespace = NumericNamespace(
214+
name=restrict.namespace, value_int=restrict.value_int
215+
)
216+
elif restrict_value_type == "value_float":
217+
numeric_namespace = NumericNamespace(
218+
name=restrict.namespace, value_float=restrict.value_float
219+
)
220+
elif restrict_value_type == "value_double":
221+
numeric_namespace = NumericNamespace(
222+
name=restrict.namespace, value_double=restrict.value_double
223+
)
224+
self.numeric_restricts.append(numeric_namespace)
225+
return self
226+
227+
def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighbor":
228+
"""Copies MatchNeighbor fields from an Embedding.
229+
230+
Args:
231+
embedding (gca_index_v1beta1.Embedding):
232+
Required. An embedding.
233+
234+
Returns:
235+
MatchNeighbor
236+
"""
237+
if not embedding:
238+
return self
239+
self.feature_vector = embedding.float_val
240+
if not self.crowding_tag and embedding.crowding_attribute is not None:
241+
self.crowding_tag = str(embedding.crowding_attribute)
242+
self.restricts = [
243+
Namespace(
244+
name=restrict.name,
245+
allow_tokens=restrict.allow_tokens,
246+
deny_tokens=restrict.deny_tokens,
247+
)
248+
for restrict in embedding.restricts
249+
]
250+
if embedding.numeric_restricts:
251+
self.numeric_restricts = []
252+
for restrict in embedding.numeric_restricts:
253+
numeric_namespace = None
254+
restrict_value_type = restrict.WhichOneof("Value")
255+
if restrict_value_type == "value_int":
256+
numeric_namespace = NumericNamespace(
257+
name=restrict.name, value_int=restrict.value_int
258+
)
259+
elif restrict_value_type == "value_float":
260+
numeric_namespace = NumericNamespace(
261+
name=restrict.name, value_float=restrict.value_float
262+
)
263+
elif restrict_value_type == "value_double":
264+
numeric_namespace = NumericNamespace(
265+
name=restrict.name, value_double=restrict.value_double
266+
)
267+
self.numeric_restricts.append(numeric_namespace)
268+
return self
269+
270+
167271
class MatchingEngineIndexEndpoint(base.VertexAiResourceNounWithFutureManager):
168272
"""Matching Engine index endpoint resource for Vertex AI."""
169273

@@ -1333,10 +1437,8 @@ def find_neighbors(
13331437
return [
13341438
[
13351439
MatchNeighbor(
1336-
id=neighbor.datapoint.datapoint_id,
1337-
distance=neighbor.distance,
1338-
feature_vector=neighbor.datapoint.feature_vector,
1339-
)
1440+
id=neighbor.datapoint.datapoint_id, distance=neighbor.distance
1441+
).from_index_datapoint(index_datapoint=neighbor.datapoint)
13401442
for neighbor in embedding_neighbors.neighbors
13411443
]
13421444
for embedding_neighbors in response.nearest_neighbors
@@ -1572,13 +1674,17 @@ def match(
15721674
response = stub.BatchMatch(batch_request)
15731675

15741676
# Wrap the results in MatchNeighbor objects and return
1575-
return [
1576-
[
1577-
MatchNeighbor(
1578-
id=embedding_neighbors.neighbor[i].id,
1579-
distance=embedding_neighbors.neighbor[i].distance,
1677+
match_neighbors_response = []
1678+
for resp in response.responses[0].responses:
1679+
match_neighbors_id_map = {}
1680+
for neighbor in resp.neighbor:
1681+
match_neighbors_id_map[neighbor.id] = MatchNeighbor(
1682+
id=neighbor.id, distance=neighbor.distance
15801683
)
1581-
for i in range(len(embedding_neighbors.neighbor))
1582-
]
1583-
for embedding_neighbors in response.responses[0].responses
1584-
]
1684+
for embedding in resp.embeddings:
1685+
if embedding.id in match_neighbors_id_map:
1686+
match_neighbors_id_map[embedding.id] = match_neighbors_id_map[
1687+
embedding.id
1688+
].from_embedding(embedding=embedding)
1689+
match_neighbors_response.append(list(match_neighbors_id_map.values()))
1690+
return match_neighbors_response

0 commit comments

Comments
 (0)