@@ -269,7 +269,7 @@ def test_still_correct_with_early_return_generated_dotproduct(self):
269
269
270
270
class TestQueryResultsAggregatorOutputUX :
271
271
def test_can_interact_with_attributes (self ):
272
- aggregator = QueryResultsAggregator (top_k = 1 )
272
+ aggregator = QueryResultsAggregator (top_k = 2 )
273
273
results1 = {
274
274
"matches" : [
275
275
{
@@ -414,6 +414,8 @@ class TestQueryAggregatorEdgeCases:
414
414
def test_topK_too_small (self ):
415
415
with pytest .raises (QueryResultsAggregatorInvalidTopKError ):
416
416
QueryResultsAggregator (top_k = 0 )
417
+ with pytest .raises (QueryResultsAggregatorInvalidTopKError ):
418
+ QueryResultsAggregator (top_k = 1 )
417
419
418
420
def test_matches_too_small (self ):
419
421
aggregator = QueryResultsAggregator (top_k = 3 )
@@ -431,3 +433,121 @@ def test_empty_results(self):
431
433
assert results is not None
432
434
assert results .usage .read_units == 0
433
435
assert len (results .matches ) == 0
436
+
437
+ def test_empty_results_with_usage (self ):
438
+ aggregator = QueryResultsAggregator (top_k = 3 )
439
+
440
+ aggregator .add_results ({"matches" : [], "usage" : {"readUnits" : 5 }, "namespace" : "ns1" })
441
+ aggregator .add_results ({"matches" : [], "usage" : {"readUnits" : 5 }, "namespace" : "ns2" })
442
+ aggregator .add_results ({"matches" : [], "usage" : {"readUnits" : 5 }, "namespace" : "ns3" })
443
+
444
+ results = aggregator .get_results ()
445
+ assert results is not None
446
+ assert results .usage .read_units == 15
447
+ assert len (results .matches ) == 0
448
+
449
+ def test_exactly_one_result (self ):
450
+ aggregator = QueryResultsAggregator (top_k = 3 )
451
+ results1 = {
452
+ "matches" : [{"id" : "2" , "score" : 0.01 }, {"id" : "3" , "score" : 0.2 }],
453
+ "usage" : {"readUnits" : 5 },
454
+ "namespace" : "ns2" ,
455
+ }
456
+ aggregator .add_results (results1 )
457
+
458
+ results2 = {
459
+ "matches" : [{"id" : "1" , "score" : 0.1 }],
460
+ "usage" : {"readUnits" : 5 },
461
+ "namespace" : "ns1" ,
462
+ }
463
+ aggregator .add_results (results2 )
464
+ results = aggregator .get_results ()
465
+ assert results .usage .read_units == 10
466
+ assert len (results .matches ) == 3
467
+ assert results .matches [0 ].id == "2"
468
+ assert results .matches [0 ].namespace == "ns2"
469
+ assert results .matches [0 ].score == 0.01
470
+ assert results .matches [1 ].id == "1"
471
+ assert results .matches [1 ].namespace == "ns1"
472
+ assert results .matches [1 ].score == 0.1
473
+ assert results .matches [2 ].id == "3"
474
+ assert results .matches [2 ].namespace == "ns2"
475
+ assert results .matches [2 ].score == 0.2
476
+
477
+ def test_two_result_sets_with_single_result_errors (self ):
478
+ with pytest .raises (QueryResultsAggregregatorNotEnoughResultsError ):
479
+ aggregator = QueryResultsAggregator (top_k = 3 )
480
+ results1 = {
481
+ "matches" : [{"id" : "1" , "score" : 0.1 }],
482
+ "usage" : {"readUnits" : 5 },
483
+ "namespace" : "ns1" ,
484
+ }
485
+ aggregator .add_results (results1 )
486
+ results2 = {
487
+ "matches" : [{"id" : "2" , "score" : 0.01 }],
488
+ "usage" : {"readUnits" : 5 },
489
+ "namespace" : "ns2" ,
490
+ }
491
+ aggregator .add_results (results2 )
492
+
493
+ def test_single_result_after_index_type_known_no_error (self ):
494
+ aggregator = QueryResultsAggregator (top_k = 3 )
495
+
496
+ results3 = {
497
+ "matches" : [{"id" : "2" , "score" : 0.01 }, {"id" : "3" , "score" : 0.2 }],
498
+ "usage" : {"readUnits" : 5 },
499
+ "namespace" : "ns3" ,
500
+ }
501
+ aggregator .add_results (results3 )
502
+
503
+ results1 = {
504
+ "matches" : [{"id" : "1" , "score" : 0.1 }],
505
+ "usage" : {"readUnits" : 5 },
506
+ "namespace" : "ns1" ,
507
+ }
508
+ aggregator .add_results (results1 )
509
+ results2 = {"matches" : [], "usage" : {"readUnits" : 5 }, "namespace" : "ns2" }
510
+ aggregator .add_results (results2 )
511
+
512
+ results = aggregator .get_results ()
513
+ assert results .usage .read_units == 15
514
+ assert len (results .matches ) == 3
515
+ assert results .matches [0 ].id == "2"
516
+ assert results .matches [0 ].namespace == "ns3"
517
+ assert results .matches [0 ].score == 0.01
518
+ assert results .matches [1 ].id == "1"
519
+ assert results .matches [1 ].namespace == "ns1"
520
+ assert results .matches [1 ].score == 0.1
521
+ assert results .matches [2 ].id == "3"
522
+ assert results .matches [2 ].namespace == "ns3"
523
+ assert results .matches [2 ].score == 0.2
524
+
525
+ def test_all_empty_results (self ):
526
+ aggregator = QueryResultsAggregator (top_k = 10 )
527
+
528
+ aggregator .add_results ({"matches" : [], "usage" : {"readUnits" : 5 }, "namespace" : "ns1" })
529
+ aggregator .add_results ({"matches" : [], "usage" : {"readUnits" : 5 }, "namespace" : "ns2" })
530
+ aggregator .add_results ({"matches" : [], "usage" : {"readUnits" : 5 }, "namespace" : "ns3" })
531
+
532
+ results = aggregator .get_results ()
533
+
534
+ assert results .usage .read_units == 15
535
+ assert len (results .matches ) == 0
536
+
537
+ def test_some_empty_results (self ):
538
+ aggregator = QueryResultsAggregator (top_k = 10 )
539
+ results2 = {
540
+ "matches" : [{"id" : "2" , "score" : 0.01 }, {"id" : "3" , "score" : 0.2 }],
541
+ "usage" : {"readUnits" : 5 },
542
+ "namespace" : "ns0" ,
543
+ }
544
+ aggregator .add_results (results2 )
545
+
546
+ aggregator .add_results ({"matches" : [], "usage" : {"readUnits" : 5 }, "namespace" : "ns1" })
547
+ aggregator .add_results ({"matches" : [], "usage" : {"readUnits" : 5 }, "namespace" : "ns2" })
548
+ aggregator .add_results ({"matches" : [], "usage" : {"readUnits" : 5 }, "namespace" : "ns3" })
549
+
550
+ results = aggregator .get_results ()
551
+
552
+ assert results .usage .read_units == 20
553
+ assert len (results .matches ) == 2
0 commit comments