|
42 | 42 | _LOGGER = base.Logger(__name__)
|
43 | 43 |
|
44 | 44 |
|
45 |
| -@dataclass |
46 |
| -class MatchNeighbor: |
47 |
| - """The id and distance of a nearest neighbor match for a given query embedding. |
48 |
| -
|
49 |
| - Args: |
50 |
| - id (str): |
51 |
| - Required. The id of the neighbor. |
52 |
| - distance (float): |
53 |
| - Required. The distance to the query embedding. |
54 |
| - feature_vector (List(float)): |
55 |
| - Optional. The feature vector of the matching datapoint. |
56 |
| - """ |
57 |
| - |
58 |
| - id: str |
59 |
| - distance: float |
60 |
| - feature_vector: Optional[List[float]] = None |
61 |
| - |
62 |
| - |
63 | 45 | @dataclass
|
64 | 46 | class Namespace:
|
65 | 47 | """Namespace specifies the rules for determining the datapoints that are eligible for each matching query, overall query is an AND across namespaces.
|
@@ -156,14 +138,136 @@ def __post_init__(self):
|
156 | 138 | )
|
157 | 139 | # Check operator validity
|
158 | 140 | if (
|
159 |
| - self.op |
| 141 | + self.op is not None |
| 142 | + and self.op |
160 | 143 | not in gca_index_v1beta1.IndexDatapoint.NumericRestriction.Operator._member_names_
|
161 | 144 | ):
|
162 | 145 | raise ValueError(
|
163 | 146 | f"Invalid operator '{self.op}'," " must be one of the valid operators."
|
164 | 147 | )
|
165 | 148 |
|
166 | 149 |
|
| 150 | +@dataclass |
| 151 | +class MatchNeighbor: |
| 152 | + """The id and distance of a nearest neighbor match for a given query embedding. |
| 153 | +
|
| 154 | + Args: |
| 155 | + id (str): |
| 156 | + Required. The id of the neighbor. |
| 157 | + distance (float): |
| 158 | + Required. The distance to the query embedding. |
| 159 | + feature_vector (List(float)): |
| 160 | + Optional. The feature vector of the matching datapoint. |
| 161 | + crowding_tag (Optional[str]): |
| 162 | + Optional. Crowding tag of the datapoint, the |
| 163 | + number of neighbors to return in each crowding, |
| 164 | + can be configured during query. |
| 165 | + restricts (List[Namespace]): |
| 166 | + Optional. The restricts of the matching datapoint. |
| 167 | + numeric_restricts: |
| 168 | + Optional. The numeric restricts of the matching datapoint. |
| 169 | +
|
| 170 | + """ |
| 171 | + |
| 172 | + id: str |
| 173 | + distance: float |
| 174 | + feature_vector: Optional[List[float]] = None |
| 175 | + crowding_tag: Optional[str] = None |
| 176 | + restricts: Optional[List[Namespace]] = None |
| 177 | + numeric_restricts: Optional[List[NumericNamespace]] = None |
| 178 | + |
| 179 | + def from_index_datapoint( |
| 180 | + self, index_datapoint: gca_index_v1beta1.IndexDatapoint |
| 181 | + ) -> "MatchNeighbor": |
| 182 | + """Copies MatchNeighbor fields from an IndexDatapoint. |
| 183 | +
|
| 184 | + Args: |
| 185 | + index_datapoint (gca_index_v1beta1.IndexDatapoint): |
| 186 | + Required. An index datapoint. |
| 187 | +
|
| 188 | + Returns: |
| 189 | + MatchNeighbor |
| 190 | + """ |
| 191 | + if not index_datapoint: |
| 192 | + return self |
| 193 | + self.feature_vector = index_datapoint.feature_vector |
| 194 | + if ( |
| 195 | + index_datapoint.crowding_tag is not None |
| 196 | + and index_datapoint.crowding_tag.crowding_attribute is not None |
| 197 | + ): |
| 198 | + self.crowding_tag = index_datapoint.crowding_tag.crowding_attribute |
| 199 | + self.restricts = [ |
| 200 | + Namespace( |
| 201 | + name=restrict.namespace, |
| 202 | + allow_tokens=restrict.allow_list, |
| 203 | + deny_tokens=restrict.deny_list, |
| 204 | + ) |
| 205 | + for restrict in index_datapoint.restricts |
| 206 | + ] |
| 207 | + if index_datapoint.numeric_restricts is not None: |
| 208 | + self.numeric_restricts = [] |
| 209 | + for restrict in index_datapoint.numeric_restricts: |
| 210 | + numeric_namespace = None |
| 211 | + restrict_value_type = restrict._pb.WhichOneof("Value") |
| 212 | + if restrict_value_type == "value_int": |
| 213 | + numeric_namespace = NumericNamespace( |
| 214 | + name=restrict.namespace, value_int=restrict.value_int |
| 215 | + ) |
| 216 | + elif restrict_value_type == "value_float": |
| 217 | + numeric_namespace = NumericNamespace( |
| 218 | + name=restrict.namespace, value_float=restrict.value_float |
| 219 | + ) |
| 220 | + elif restrict_value_type == "value_double": |
| 221 | + numeric_namespace = NumericNamespace( |
| 222 | + name=restrict.namespace, value_double=restrict.value_double |
| 223 | + ) |
| 224 | + self.numeric_restricts.append(numeric_namespace) |
| 225 | + return self |
| 226 | + |
| 227 | + def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighbor": |
| 228 | + """Copies MatchNeighbor fields from an Embedding. |
| 229 | +
|
| 230 | + Args: |
| 231 | + embedding (gca_index_v1beta1.Embedding): |
| 232 | + Required. An embedding. |
| 233 | +
|
| 234 | + Returns: |
| 235 | + MatchNeighbor |
| 236 | + """ |
| 237 | + if not embedding: |
| 238 | + return self |
| 239 | + self.feature_vector = embedding.float_val |
| 240 | + if not self.crowding_tag and embedding.crowding_attribute is not None: |
| 241 | + self.crowding_tag = str(embedding.crowding_attribute) |
| 242 | + self.restricts = [ |
| 243 | + Namespace( |
| 244 | + name=restrict.name, |
| 245 | + allow_tokens=restrict.allow_tokens, |
| 246 | + deny_tokens=restrict.deny_tokens, |
| 247 | + ) |
| 248 | + for restrict in embedding.restricts |
| 249 | + ] |
| 250 | + if embedding.numeric_restricts: |
| 251 | + self.numeric_restricts = [] |
| 252 | + for restrict in embedding.numeric_restricts: |
| 253 | + numeric_namespace = None |
| 254 | + restrict_value_type = restrict.WhichOneof("Value") |
| 255 | + if restrict_value_type == "value_int": |
| 256 | + numeric_namespace = NumericNamespace( |
| 257 | + name=restrict.name, value_int=restrict.value_int |
| 258 | + ) |
| 259 | + elif restrict_value_type == "value_float": |
| 260 | + numeric_namespace = NumericNamespace( |
| 261 | + name=restrict.name, value_float=restrict.value_float |
| 262 | + ) |
| 263 | + elif restrict_value_type == "value_double": |
| 264 | + numeric_namespace = NumericNamespace( |
| 265 | + name=restrict.name, value_double=restrict.value_double |
| 266 | + ) |
| 267 | + self.numeric_restricts.append(numeric_namespace) |
| 268 | + return self |
| 269 | + |
| 270 | + |
167 | 271 | class MatchingEngineIndexEndpoint(base.VertexAiResourceNounWithFutureManager):
|
168 | 272 | """Matching Engine index endpoint resource for Vertex AI."""
|
169 | 273 |
|
@@ -1333,10 +1437,8 @@ def find_neighbors(
|
1333 | 1437 | return [
|
1334 | 1438 | [
|
1335 | 1439 | MatchNeighbor(
|
1336 |
| - id=neighbor.datapoint.datapoint_id, |
1337 |
| - distance=neighbor.distance, |
1338 |
| - feature_vector=neighbor.datapoint.feature_vector, |
1339 |
| - ) |
| 1440 | + id=neighbor.datapoint.datapoint_id, distance=neighbor.distance |
| 1441 | + ).from_index_datapoint(index_datapoint=neighbor.datapoint) |
1340 | 1442 | for neighbor in embedding_neighbors.neighbors
|
1341 | 1443 | ]
|
1342 | 1444 | for embedding_neighbors in response.nearest_neighbors
|
@@ -1572,13 +1674,17 @@ def match(
|
1572 | 1674 | response = stub.BatchMatch(batch_request)
|
1573 | 1675 |
|
1574 | 1676 | # Wrap the results in MatchNeighbor objects and return
|
1575 |
| - return [ |
1576 |
| - [ |
1577 |
| - MatchNeighbor( |
1578 |
| - id=embedding_neighbors.neighbor[i].id, |
1579 |
| - distance=embedding_neighbors.neighbor[i].distance, |
| 1677 | + match_neighbors_response = [] |
| 1678 | + for resp in response.responses[0].responses: |
| 1679 | + match_neighbors_id_map = {} |
| 1680 | + for neighbor in resp.neighbor: |
| 1681 | + match_neighbors_id_map[neighbor.id] = MatchNeighbor( |
| 1682 | + id=neighbor.id, distance=neighbor.distance |
1580 | 1683 | )
|
1581 |
| - for i in range(len(embedding_neighbors.neighbor)) |
1582 |
| - ] |
1583 |
| - for embedding_neighbors in response.responses[0].responses |
1584 |
| - ] |
| 1684 | + for embedding in resp.embeddings: |
| 1685 | + if embedding.id in match_neighbors_id_map: |
| 1686 | + match_neighbors_id_map[embedding.id] = match_neighbors_id_map[ |
| 1687 | + embedding.id |
| 1688 | + ].from_embedding(embedding=embedding) |
| 1689 | + match_neighbors_response.append(list(match_neighbors_id_map.values())) |
| 1690 | + return match_neighbors_response |
0 commit comments