Skip to content

Commit

Permalink
Do not raise when result set is empty
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon committed Oct 22, 2024
1 parent fbf4165 commit 4345ece
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 7 deletions.
17 changes: 11 additions & 6 deletions pinecone/grpc/query_results_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,22 @@ def __init__(self, namespace: str):


class QueryResultsAggregregatorNotEnoughResultsError(Exception):
def __init__(self, top_k: int, num_results: int):
def __init__(self, num_results: int):
super().__init__(
f"Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores. Expected at least {top_k} results but got {num_results}."
"Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
)


class QueryResultsAggregatorInvalidTopKError(Exception):
def __init__(self, top_k: int):
super().__init__(f"Invalid top_k value {top_k}. top_k must be a positive integer.")
super().__init__(
f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2."
)


class QueryResultsAggregator:
def __init__(self, top_k: int):
if top_k < 1:
if top_k < 2:
raise QueryResultsAggregatorInvalidTopKError(top_k)
self.top_k = top_k
self.usage_read_units = 0
Expand Down Expand Up @@ -155,11 +157,14 @@ def add_results(self, results: Dict[str, Any]):
self.usage_read_units += results.get("usage", {}).get("readUnits", 0)

if len(matches) == 0:
raise QueryResultsAggregationEmptyResultsError(ns)
return

if self.is_dotproduct is None:
if len(matches) == 1:
raise QueryResultsAggregregatorNotEnoughResultsError(self.top_k, len(matches))
# This condition should match the second time we add results containing
# only one match. We need at least two matches in a single response in order
# to infer the similarity metric
raise QueryResultsAggregregatorNotEnoughResultsError(len(matches))
self.is_dotproduct = self._is_dotproduct_index(matches)

if self.is_dotproduct:
Expand Down
122 changes: 121 additions & 1 deletion tests/unit_grpc/test_query_results_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def test_still_correct_with_early_return_generated_dotproduct(self):

class TestQueryResultsAggregatorOutputUX:
def test_can_interact_with_attributes(self):
aggregator = QueryResultsAggregator(top_k=1)
aggregator = QueryResultsAggregator(top_k=2)
results1 = {
"matches": [
{
Expand Down Expand Up @@ -414,6 +414,8 @@ class TestQueryAggregatorEdgeCases:
def test_topK_too_small(self):
with pytest.raises(QueryResultsAggregatorInvalidTopKError):
QueryResultsAggregator(top_k=0)
with pytest.raises(QueryResultsAggregatorInvalidTopKError):
QueryResultsAggregator(top_k=1)

def test_matches_too_small(self):
aggregator = QueryResultsAggregator(top_k=3)
Expand All @@ -431,3 +433,121 @@ def test_empty_results(self):
assert results is not None
assert results.usage.read_units == 0
assert len(results.matches) == 0

def test_empty_results_with_usage(self):
aggregator = QueryResultsAggregator(top_k=3)

aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"})
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"})
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"})

results = aggregator.get_results()
assert results is not None
assert results.usage.read_units == 15
assert len(results.matches) == 0

def test_exactly_one_result(self):
aggregator = QueryResultsAggregator(top_k=3)
results1 = {
"matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}],
"usage": {"readUnits": 5},
"namespace": "ns2",
}
aggregator.add_results(results1)

results2 = {
"matches": [{"id": "1", "score": 0.1}],
"usage": {"readUnits": 5},
"namespace": "ns1",
}
aggregator.add_results(results2)
results = aggregator.get_results()
assert results.usage.read_units == 10
assert len(results.matches) == 3
assert results.matches[0].id == "2"
assert results.matches[0].namespace == "ns2"
assert results.matches[0].score == 0.01
assert results.matches[1].id == "1"
assert results.matches[1].namespace == "ns1"
assert results.matches[1].score == 0.1
assert results.matches[2].id == "3"
assert results.matches[2].namespace == "ns2"
assert results.matches[2].score == 0.2

def test_two_result_sets_with_single_result_errors(self):
with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError):
aggregator = QueryResultsAggregator(top_k=3)
results1 = {
"matches": [{"id": "1", "score": 0.1}],
"usage": {"readUnits": 5},
"namespace": "ns1",
}
aggregator.add_results(results1)
results2 = {
"matches": [{"id": "2", "score": 0.01}],
"usage": {"readUnits": 5},
"namespace": "ns2",
}
aggregator.add_results(results2)

def test_single_result_after_index_type_known_no_error(self):
aggregator = QueryResultsAggregator(top_k=3)

results3 = {
"matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}],
"usage": {"readUnits": 5},
"namespace": "ns3",
}
aggregator.add_results(results3)

results1 = {
"matches": [{"id": "1", "score": 0.1}],
"usage": {"readUnits": 5},
"namespace": "ns1",
}
aggregator.add_results(results1)
results2 = {"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}
aggregator.add_results(results2)

results = aggregator.get_results()
assert results.usage.read_units == 15
assert len(results.matches) == 3
assert results.matches[0].id == "2"
assert results.matches[0].namespace == "ns3"
assert results.matches[0].score == 0.01
assert results.matches[1].id == "1"
assert results.matches[1].namespace == "ns1"
assert results.matches[1].score == 0.1
assert results.matches[2].id == "3"
assert results.matches[2].namespace == "ns3"
assert results.matches[2].score == 0.2

def test_all_empty_results(self):
aggregator = QueryResultsAggregator(top_k=10)

aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"})
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"})
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"})

results = aggregator.get_results()

assert results.usage.read_units == 15
assert len(results.matches) == 0

def test_some_empty_results(self):
aggregator = QueryResultsAggregator(top_k=10)
results2 = {
"matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}],
"usage": {"readUnits": 5},
"namespace": "ns0",
}
aggregator.add_results(results2)

aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"})
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"})
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"})

results = aggregator.get_results()

assert results.usage.read_units == 20
assert len(results.matches) == 2

0 comments on commit 4345ece

Please sign in to comment.