Skip to content

Commit 33c551e

Browse files
lingyinwcopybara-github
authored and committed
feat: add per_crowding_attribute_neighbor_count, approx_num_neighbors, fraction_leaf_nodes_to_search_override, and return_full_datapoint to MatchingEngineIndexEndpoint find_neighbors
PiperOrigin-RevId: 579967420
1 parent a0103c5 commit 33c551e

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,10 @@ def find_neighbors(
956956
queries: List[List[float]],
957957
num_neighbors: int = 10,
958958
filter: Optional[List[Namespace]] = [],
959+
per_crowding_attribute_neighbor_count: Optional[int] = None,
960+
approx_num_neighbors: Optional[int] = None,
961+
fraction_leaf_nodes_to_search_override: Optional[float] = None,
962+
return_full_datapoint: bool = False,
959963
) -> List[List[MatchNeighbor]]:
960964
"""Retrieves nearest neighbors for the given embedding queries on the specified deployed index which is deployed to public endpoint.
961965
@@ -979,25 +983,58 @@ def find_neighbors(
979983
For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] will match datapoints
980984
that satisfy "red color" but not include datapoints with "squared shape".
981985
Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail.
986+
987+
per_crowding_attribute_neighbor_count (int):
988+
Optional. Crowding is a constraint on a neighbor list produced
989+
by nearest neighbor search requiring that no more than some
990+
value k' of the k neighbors returned have the same value of
991+
crowding_attribute. It's used for improving result diversity.
992+
This field is the maximum number of matches with the same crowding tag.
993+
994+
approx_num_neighbors (int):
995+
Optional. The number of neighbors to find via approximate search
996+
before exact reordering is performed. If not set, the default
997+
value from scam config is used; if set, this value must be > 0.
998+
999+
fraction_leaf_nodes_to_search_override (float):
1000+
Optional. The fraction of the number of leaves to search. Setting it at
1000+
query time allows the user to tune search performance. Increasing this
1001+
value increases both search accuracy and latency.
1003+
The value should be between 0.0 and 1.0.
1004+
1005+
return_full_datapoint (bool):
1006+
Optional. If set to true, the full datapoints (including all
1007+
vector values and restricts) of the nearest neighbors are returned.
1008+
Note that returning full datapoint will significantly increase the
1009+
latency and cost of the query.
1010+
9821011
Returns:
9831012
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
9841013
"""
9851014

9861015
if not self._public_match_client:
9871016
raise ValueError(
988-
"Please make sure index has been deployed to public endpoint, and follow the example usage to call this method."
1017+
"Please make sure index has been deployed to public endpoint,and follow the example usage to call this method."
9891018
)
9901019

9911020
# Create the FindNeighbors request
9921021
find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest()
9931022
find_neighbors_request.index_endpoint = self.resource_name
9941023
find_neighbors_request.deployed_index_id = deployed_index_id
1024+
find_neighbors_request.return_full_datapoint = return_full_datapoint
9951025

9961026
for query in queries:
9971027
find_neighbors_query = (
9981028
gca_match_service_v1beta1.FindNeighborsRequest.Query()
9991029
)
10001030
find_neighbors_query.neighbor_count = num_neighbors
1031+
find_neighbors_query.per_crowding_attribute_neighbor_count = (
1032+
per_crowding_attribute_neighbor_count
1033+
)
1034+
find_neighbors_query.approximate_neighbor_count = approx_num_neighbors
1035+
find_neighbors_query.fraction_leaf_nodes_to_search_override = (
1036+
fraction_leaf_nodes_to_search_override
1037+
)
10011038
datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=query)
10021039
for namespace in filter:
10031040
restrict = gca_index_v1beta1.IndexDatapoint.Restriction()

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

+10
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@
234234
_TEST_IDS = ["123", "456", "789"]
235235
_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS = 3
236236
_TEST_APPROX_NUM_NEIGHBORS = 2
237+
_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE = 0.8
238+
_TEST_RETURN_FULL_DATAPOINT = True
237239

238240

239241
def uuid_mock():
@@ -954,6 +956,10 @@ def test_index_public_endpoint_match_queries(
954956
queries=_TEST_QUERIES,
955957
num_neighbors=_TEST_NUM_NEIGHBOURS,
956958
filter=_TEST_FILTER,
959+
per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
960+
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
961+
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
962+
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
957963
)
958964

959965
find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest(
@@ -972,8 +978,12 @@ def test_index_public_endpoint_match_queries(
972978
)
973979
],
974980
),
981+
per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
982+
approximate_neighbor_count=_TEST_APPROX_NUM_NEIGHBORS,
983+
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
975984
)
976985
],
986+
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
977987
)
978988

979989
index_public_endpoint_match_queries_mock.assert_called_with(

0 commit comments

Comments
 (0)