Skip to content

Commit 4345ece

Browse files
committed
Do not raise when result set is empty
1 parent fbf4165 commit 4345ece

File tree

2 files changed

+132
-7
lines changed

2 files changed

+132
-7
lines changed

pinecone/grpc/query_results_aggregator.py

+11 -6
Original file line number | Diff line number | Diff line change
@@ -89,20 +89,22 @@ def __init__(self, namespace: str):
8989

9090

9191
class QueryResultsAggregregatorNotEnoughResultsError(Exception):
92-
def __init__(self, top_k: int, num_results: int):
92+
def __init__(self, num_results: int):
9393
super().__init__(
94-
f"Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores. Expected at least {top_k} results but got {num_results}."
94+
"Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
9595
)
9696

9797

9898
class QueryResultsAggregatorInvalidTopKError(Exception):
9999
def __init__(self, top_k: int):
100-
super().__init__(f"Invalid top_k value {top_k}. top_k must be a positive integer.")
100+
super().__init__(
101+
f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2."
102+
)
101103

102104

103105
class QueryResultsAggregator:
104106
def __init__(self, top_k: int):
105-
if top_k < 1:
107+
if top_k < 2:
106108
raise QueryResultsAggregatorInvalidTopKError(top_k)
107109
self.top_k = top_k
108110
self.usage_read_units = 0
@@ -155,11 +157,14 @@ def add_results(self, results: Dict[str, Any]):
155157
self.usage_read_units += results.get("usage", {}).get("readUnits", 0)
156158

157159
if len(matches) == 0:
158-
raise QueryResultsAggregationEmptyResultsError(ns)
160+
return
159161

160162
if self.is_dotproduct is None:
161163
if len(matches) == 1:
162-
raise QueryResultsAggregregatorNotEnoughResultsError(self.top_k, len(matches))
164+
# This condition should match the second time we add results containing
165+
# only one match. We need at least two matches in a single response in order
166+
# to infer the similarity metric
167+
raise QueryResultsAggregregatorNotEnoughResultsError(len(matches))
163168
self.is_dotproduct = self._is_dotproduct_index(matches)
164169

165170
if self.is_dotproduct:

tests/unit_grpc/test_query_results_aggregator.py

+121 -1
Original file line number | Diff line number | Diff line change
@@ -269,7 +269,7 @@ def test_still_correct_with_early_return_generated_dotproduct(self):
269269

270270
class TestQueryResultsAggregatorOutputUX:
271271
def test_can_interact_with_attributes(self):
272-
aggregator = QueryResultsAggregator(top_k=1)
272+
aggregator = QueryResultsAggregator(top_k=2)
273273
results1 = {
274274
"matches": [
275275
{
@@ -414,6 +414,8 @@ class TestQueryAggregatorEdgeCases:
414414
def test_topK_too_small(self):
415415
with pytest.raises(QueryResultsAggregatorInvalidTopKError):
416416
QueryResultsAggregator(top_k=0)
417+
with pytest.raises(QueryResultsAggregatorInvalidTopKError):
418+
QueryResultsAggregator(top_k=1)
417419

418420
def test_matches_too_small(self):
419421
aggregator = QueryResultsAggregator(top_k=3)
@@ -431,3 +433,121 @@ def test_empty_results(self):
431433
assert results is not None
432434
assert results.usage.read_units == 0
433435
assert len(results.matches) == 0
436+
437+
def test_empty_results_with_usage(self):
438+
aggregator = QueryResultsAggregator(top_k=3)
439+
440+
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"})
441+
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"})
442+
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"})
443+
444+
results = aggregator.get_results()
445+
assert results is not None
446+
assert results.usage.read_units == 15
447+
assert len(results.matches) == 0
448+
449+
def test_exactly_one_result(self):
450+
aggregator = QueryResultsAggregator(top_k=3)
451+
results1 = {
452+
"matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}],
453+
"usage": {"readUnits": 5},
454+
"namespace": "ns2",
455+
}
456+
aggregator.add_results(results1)
457+
458+
results2 = {
459+
"matches": [{"id": "1", "score": 0.1}],
460+
"usage": {"readUnits": 5},
461+
"namespace": "ns1",
462+
}
463+
aggregator.add_results(results2)
464+
results = aggregator.get_results()
465+
assert results.usage.read_units == 10
466+
assert len(results.matches) == 3
467+
assert results.matches[0].id == "2"
468+
assert results.matches[0].namespace == "ns2"
469+
assert results.matches[0].score == 0.01
470+
assert results.matches[1].id == "1"
471+
assert results.matches[1].namespace == "ns1"
472+
assert results.matches[1].score == 0.1
473+
assert results.matches[2].id == "3"
474+
assert results.matches[2].namespace == "ns2"
475+
assert results.matches[2].score == 0.2
476+
477+
def test_two_result_sets_with_single_result_errors(self):
478+
with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError):
479+
aggregator = QueryResultsAggregator(top_k=3)
480+
results1 = {
481+
"matches": [{"id": "1", "score": 0.1}],
482+
"usage": {"readUnits": 5},
483+
"namespace": "ns1",
484+
}
485+
aggregator.add_results(results1)
486+
results2 = {
487+
"matches": [{"id": "2", "score": 0.01}],
488+
"usage": {"readUnits": 5},
489+
"namespace": "ns2",
490+
}
491+
aggregator.add_results(results2)
492+
493+
def test_single_result_after_index_type_known_no_error(self):
494+
aggregator = QueryResultsAggregator(top_k=3)
495+
496+
results3 = {
497+
"matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}],
498+
"usage": {"readUnits": 5},
499+
"namespace": "ns3",
500+
}
501+
aggregator.add_results(results3)
502+
503+
results1 = {
504+
"matches": [{"id": "1", "score": 0.1}],
505+
"usage": {"readUnits": 5},
506+
"namespace": "ns1",
507+
}
508+
aggregator.add_results(results1)
509+
results2 = {"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}
510+
aggregator.add_results(results2)
511+
512+
results = aggregator.get_results()
513+
assert results.usage.read_units == 15
514+
assert len(results.matches) == 3
515+
assert results.matches[0].id == "2"
516+
assert results.matches[0].namespace == "ns3"
517+
assert results.matches[0].score == 0.01
518+
assert results.matches[1].id == "1"
519+
assert results.matches[1].namespace == "ns1"
520+
assert results.matches[1].score == 0.1
521+
assert results.matches[2].id == "3"
522+
assert results.matches[2].namespace == "ns3"
523+
assert results.matches[2].score == 0.2
524+
525+
def test_all_empty_results(self):
526+
aggregator = QueryResultsAggregator(top_k=10)
527+
528+
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"})
529+
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"})
530+
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"})
531+
532+
results = aggregator.get_results()
533+
534+
assert results.usage.read_units == 15
535+
assert len(results.matches) == 0
536+
537+
def test_some_empty_results(self):
538+
aggregator = QueryResultsAggregator(top_k=10)
539+
results2 = {
540+
"matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}],
541+
"usage": {"readUnits": 5},
542+
"namespace": "ns0",
543+
}
544+
aggregator.add_results(results2)
545+
546+
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"})
547+
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"})
548+
aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"})
549+
550+
results = aggregator.get_results()
551+
552+
assert results.usage.read_units == 20
553+
assert len(results.matches) == 2

0 commit comments

Comments (0)