Skip to content

Commit 52cfe75

Browse files
justin-cechmanekrbs333abrookinstylerhutcherson
committed
Add support for full text queries and hybrid search queries (#303)
Co-authored-by: Robert Shelton <[email protected]> Co-authored-by: Andrew Brookins <[email protected]> Co-authored-by: Tyler Hutcherson <[email protected]> Co-authored-by: Robert Shelton <[email protected]>
1 parent 5f4c85a commit 52cfe75

11 files changed

+1331
-329
lines changed

poetry.lock

+242-313
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ tenacity = ">=8.2.2"
3333
tabulate = "^0.9.0"
3434
ml-dtypes = "^0.4.0"
3535
python-ulid = "^3.0.0"
36+
nltk = { version = "^3.8.1", optional = true }
3637
jsonpath-ng = "^1.5.0"
37-
3838
openai = { version = "^1.13.0", optional = true }
3939
sentence-transformers = { version = "^3.4.0", optional = true }
4040
scipy = [
@@ -58,6 +58,7 @@ mistralai = ["mistralai"]
5858
voyageai = ["voyageai"]
5959
ranx = ["ranx"]
6060
bedrock = ["boto3"]
61+
nltk = ["nltk"]
6162

6263
[tool.poetry.group.dev.dependencies]
6364
black = "^25.1.0"

redisvl/index/index.py

+78-10
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Union,
1919
)
2020

21+
from redisvl.redis.utils import convert_bytes, make_dict
2122
from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper
2223

2324
if TYPE_CHECKING:
@@ -39,7 +40,14 @@
3940
SchemaValidationError,
4041
)
4142
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
42-
from redisvl.query import BaseQuery, BaseVectorQuery, CountQuery, FilterQuery
43+
from redisvl.query import (
44+
AggregationQuery,
45+
BaseQuery,
46+
BaseVectorQuery,
47+
CountQuery,
48+
FilterQuery,
49+
HybridQuery,
50+
)
4351
from redisvl.query.filter import FilterExpression
4452
from redisvl.redis.connection import (
4553
RedisConnectionFactory,
@@ -138,6 +146,34 @@ def _process(doc: "Document") -> Dict[str, Any]:
138146
return [_process(doc) for doc in results.docs]
139147

140148

149+
def process_aggregate_results(
150+
results: "AggregateResult", query: AggregationQuery, storage_type: StorageType
151+
) -> List[Dict[str, Any]]:
152+
"""Convert an aggregate reslt object into a list of document dictionaries.
153+
154+
This function processes results from Redis, handling different storage
155+
types and query types. For JSON storage with empty return fields, it
156+
unpacks the JSON object while retaining the document ID. The 'payload'
157+
field is also removed from all resulting documents for consistency.
158+
159+
Args:
160+
results (AggregarteResult): The aggregart results from Redis.
161+
query (AggregationQuery): The aggregation query object used for the aggregation.
162+
storage_type (StorageType): The storage type of the search
163+
index (json or hash).
164+
165+
Returns:
166+
List[Dict[str, Any]]: A list of processed document dictionaries.
167+
"""
168+
169+
def _process(row):
170+
result = make_dict(convert_bytes(row))
171+
result.pop("__score", None)
172+
return result
173+
174+
return [_process(r) for r in results.rows]
175+
176+
141177
class BaseSearchIndex:
142178
"""Base search engine class"""
143179

@@ -650,6 +686,17 @@ def fetch(self, id: str) -> Optional[Dict[str, Any]]:
650686
return convert_bytes(obj[0])
651687
return None
652688

689+
def _aggregate(self, aggregation_query: AggregationQuery) -> List[Dict[str, Any]]:
690+
"""Execute an aggretation query and processes the results."""
691+
results = self.aggregate(
692+
aggregation_query, query_params=aggregation_query.params # type: ignore[attr-defined]
693+
)
694+
return process_aggregate_results(
695+
results,
696+
query=aggregation_query,
697+
storage_type=self.schema.index.storage_type,
698+
)
699+
653700
def aggregate(self, *args, **kwargs) -> "AggregateResult":
654701
"""Perform an aggregation operation against the index.
655702
@@ -772,14 +819,14 @@ def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
772819
results = self.search(query.query, query_params=query.params)
773820
return process_results(results, query=query, schema=self.schema)
774821

775-
def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
822+
def query(self, query: Union[BaseQuery, AggregationQuery]) -> List[Dict[str, Any]]:
776823
"""Execute a query on the index.
777824
778-
This method takes a BaseQuery object directly, runs the search, and
825+
This method takes a BaseQuery or AggregationQuery object directly, and
779826
handles post-processing of the search.
780827
781828
Args:
782-
query (BaseQuery): The query to run.
829+
query (Union[BaseQuery, AggregateQuery]): The query to run.
783830
784831
Returns:
785832
List[Result]: A list of search results.
@@ -797,7 +844,10 @@ def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
797844
results = index.query(query)
798845
799846
"""
800-
return self._query(query)
847+
if isinstance(query, AggregationQuery):
848+
return self._aggregate(query)
849+
else:
850+
return self._query(query)
801851

802852
def paginate(self, query: BaseQuery, page_size: int = 30) -> Generator:
803853
"""Execute a given query against the index and return results in
@@ -1303,6 +1353,19 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
13031353
return convert_bytes(obj[0])
13041354
return None
13051355

1356+
async def _aggregate(
1357+
self, aggregation_query: AggregationQuery
1358+
) -> List[Dict[str, Any]]:
1359+
"""Execute an aggretation query and processes the results."""
1360+
results = await self.aggregate(
1361+
aggregation_query, query_params=aggregation_query.params # type: ignore[attr-defined]
1362+
)
1363+
return process_aggregate_results(
1364+
results,
1365+
query=aggregation_query,
1366+
storage_type=self.schema.index.storage_type,
1367+
)
1368+
13061369
async def aggregate(self, *args, **kwargs) -> "AggregateResult":
13071370
"""Perform an aggregation operation against the index.
13081371
@@ -1426,14 +1489,16 @@ async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
14261489
results = await self.search(query.query, query_params=query.params)
14271490
return process_results(results, query=query, schema=self.schema)
14281491

1429-
async def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
1492+
async def query(
1493+
self, query: Union[BaseQuery, AggregationQuery]
1494+
) -> List[Dict[str, Any]]:
14301495
"""Asynchronously execute a query on the index.
14311496
1432-
This method takes a BaseQuery object directly, runs the search, and
1433-
handles post-processing of the search.
1497+
This method takes a BaseQuery or AggregationQuery object directly, runs
1498+
the search, and handles post-processing of the search.
14341499
14351500
Args:
1436-
query (BaseQuery): The query to run.
1501+
query (Union[BaseQuery, AggregateQuery]): The query to run.
14371502
14381503
Returns:
14391504
List[Result]: A list of search results.
@@ -1450,7 +1515,10 @@ async def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
14501515
14511516
results = await index.query(query)
14521517
"""
1453-
return await self._query(query)
1518+
if isinstance(query, AggregationQuery):
1519+
return await self._aggregate(query)
1520+
else:
1521+
return await self._query(query)
14541522

14551523
async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerator:
14561524
"""Execute a given query against the index and return results in

redisvl/query/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from redisvl.query.aggregate import AggregationQuery, HybridQuery
12
from redisvl.query.query import (
23
BaseQuery,
34
BaseVectorQuery,
45
CountQuery,
56
FilterQuery,
67
RangeQuery,
8+
TextQuery,
79
VectorQuery,
810
VectorRangeQuery,
911
)
@@ -16,4 +18,7 @@
1618
"RangeQuery",
1719
"VectorRangeQuery",
1820
"CountQuery",
21+
"TextQuery",
22+
"AggregationQuery",
23+
"HybridQuery",
1924
]

0 commit comments

Comments
 (0)