18
18
Union ,
19
19
)
20
20
21
+ from redisvl .redis .utils import convert_bytes , make_dict
21
22
from redisvl .utils .utils import deprecated_argument , deprecated_function , sync_wrapper
22
23
23
24
if TYPE_CHECKING :
39
40
SchemaValidationError ,
40
41
)
41
42
from redisvl .index .storage import BaseStorage , HashStorage , JsonStorage
42
- from redisvl .query import BaseQuery , BaseVectorQuery , CountQuery , FilterQuery
43
+ from redisvl .query import (
44
+ AggregationQuery ,
45
+ BaseQuery ,
46
+ BaseVectorQuery ,
47
+ CountQuery ,
48
+ FilterQuery ,
49
+ HybridQuery ,
50
+ )
43
51
from redisvl .query .filter import FilterExpression
44
52
from redisvl .redis .connection import (
45
53
RedisConnectionFactory ,
@@ -138,6 +146,34 @@ def _process(doc: "Document") -> Dict[str, Any]:
138
146
return [_process (doc ) for doc in results .docs ]
139
147
140
148
149
+ def process_aggregate_results (
150
+ results : "AggregateResult" , query : AggregationQuery , storage_type : StorageType
151
+ ) -> List [Dict [str , Any ]]:
152
+ """Convert an aggregate reslt object into a list of document dictionaries.
153
+
154
+ This function processes results from Redis, handling different storage
155
+ types and query types. For JSON storage with empty return fields, it
156
+ unpacks the JSON object while retaining the document ID. The 'payload'
157
+ field is also removed from all resulting documents for consistency.
158
+
159
+ Args:
160
+ results (AggregarteResult): The aggregart results from Redis.
161
+ query (AggregationQuery): The aggregation query object used for the aggregation.
162
+ storage_type (StorageType): The storage type of the search
163
+ index (json or hash).
164
+
165
+ Returns:
166
+ List[Dict[str, Any]]: A list of processed document dictionaries.
167
+ """
168
+
169
+ def _process (row ):
170
+ result = make_dict (convert_bytes (row ))
171
+ result .pop ("__score" , None )
172
+ return result
173
+
174
+ return [_process (r ) for r in results .rows ]
175
+
176
+
141
177
class BaseSearchIndex :
142
178
"""Base search engine class"""
143
179
@@ -650,6 +686,17 @@ def fetch(self, id: str) -> Optional[Dict[str, Any]]:
650
686
return convert_bytes (obj [0 ])
651
687
return None
652
688
689
+ def _aggregate (self , aggregation_query : AggregationQuery ) -> List [Dict [str , Any ]]:
690
+ """Execute an aggretation query and processes the results."""
691
+ results = self .aggregate (
692
+ aggregation_query , query_params = aggregation_query .params # type: ignore[attr-defined]
693
+ )
694
+ return process_aggregate_results (
695
+ results ,
696
+ query = aggregation_query ,
697
+ storage_type = self .schema .index .storage_type ,
698
+ )
699
+
653
700
def aggregate (self , * args , ** kwargs ) -> "AggregateResult" :
654
701
"""Perform an aggregation operation against the index.
655
702
@@ -772,14 +819,14 @@ def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
772
819
results = self .search (query .query , query_params = query .params )
773
820
return process_results (results , query = query , schema = self .schema )
774
821
775
- def query (self , query : BaseQuery ) -> List [Dict [str , Any ]]:
822
+ def query (self , query : Union [ BaseQuery , AggregationQuery ] ) -> List [Dict [str , Any ]]:
776
823
"""Execute a query on the index.
777
824
778
- This method takes a BaseQuery object directly, runs the search , and
825
+ This method takes a BaseQuery or AggregationQuery object directly , and
779
826
handles post-processing of the search.
780
827
781
828
Args:
782
- query (BaseQuery): The query to run.
829
+ query (Union[ BaseQuery, AggregateQuery] ): The query to run.
783
830
784
831
Returns:
785
832
List[Result]: A list of search results.
@@ -797,7 +844,10 @@ def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
797
844
results = index.query(query)
798
845
799
846
"""
800
- return self ._query (query )
847
+ if isinstance (query , AggregationQuery ):
848
+ return self ._aggregate (query )
849
+ else :
850
+ return self ._query (query )
801
851
802
852
def paginate (self , query : BaseQuery , page_size : int = 30 ) -> Generator :
803
853
"""Execute a given query against the index and return results in
@@ -1303,6 +1353,19 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
1303
1353
return convert_bytes (obj [0 ])
1304
1354
return None
1305
1355
1356
+ async def _aggregate (
1357
+ self , aggregation_query : AggregationQuery
1358
+ ) -> List [Dict [str , Any ]]:
1359
+ """Execute an aggretation query and processes the results."""
1360
+ results = await self .aggregate (
1361
+ aggregation_query , query_params = aggregation_query .params # type: ignore[attr-defined]
1362
+ )
1363
+ return process_aggregate_results (
1364
+ results ,
1365
+ query = aggregation_query ,
1366
+ storage_type = self .schema .index .storage_type ,
1367
+ )
1368
+
1306
1369
async def aggregate (self , * args , ** kwargs ) -> "AggregateResult" :
1307
1370
"""Perform an aggregation operation against the index.
1308
1371
@@ -1426,14 +1489,16 @@ async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
1426
1489
results = await self .search (query .query , query_params = query .params )
1427
1490
return process_results (results , query = query , schema = self .schema )
1428
1491
1429
- async def query (self , query : BaseQuery ) -> List [Dict [str , Any ]]:
1492
+ async def query (
1493
+ self , query : Union [BaseQuery , AggregationQuery ]
1494
+ ) -> List [Dict [str , Any ]]:
1430
1495
"""Asynchronously execute a query on the index.
1431
1496
1432
- This method takes a BaseQuery object directly, runs the search, and
1433
- handles post-processing of the search.
1497
+ This method takes a BaseQuery or AggregationQuery object directly, runs
1498
+ the search, and handles post-processing of the search.
1434
1499
1435
1500
Args:
1436
- query (BaseQuery): The query to run.
1501
+ query (Union[ BaseQuery, AggregateQuery] ): The query to run.
1437
1502
1438
1503
Returns:
1439
1504
List[Result]: A list of search results.
@@ -1450,7 +1515,10 @@ async def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
1450
1515
1451
1516
results = await index.query(query)
1452
1517
"""
1453
- return await self ._query (query )
1518
+ if isinstance (query , AggregationQuery ):
1519
+ return await self ._aggregate (query )
1520
+ else :
1521
+ return await self ._query (query )
1454
1522
1455
1523
async def paginate (self , query : BaseQuery , page_size : int = 30 ) -> AsyncGenerator :
1456
1524
"""Execute a given query against the index and return results in
0 commit comments