Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions nemo_retriever/src/nemo_retriever/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ def _execute_queries_graph(
text_col = str(embed_params.text_column)
df = pd.DataFrame({text_col: query_texts})

# Hybrid retrieval relies on these ordered query strings staying aligned
# with the embedded rows produced from ``df``. If this query graph grows
# distributed/shuffled stages, carry row-local query text or IDs instead.
graph = self._get_graph(embed_extra=embed_extra)
if not callable(getattr(graph, "resolve_for_local_execution", None)):
raise TypeError("graph must provide resolve_for_local_execution() (e.g. pipeline_graph.Graph)")
Expand Down
49 changes: 44 additions & 5 deletions nemo_retriever/src/nemo_retriever/vdb/lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import time

from collections.abc import Iterable, Sequence
from datetime import timedelta
from typing import Any, Final, FrozenSet

Expand Down Expand Up @@ -529,7 +530,7 @@ def run(self, records):
logger.info("Skipping LanceDB index creation for table %r because build_index=False.", self.table_name)
return records

def retrieval(self, vectors, **kwargs):
def retrieval(self, vectors: Iterable[Sequence[float]], **kwargs: Any) -> list[list[dict[str, Any]]]:
"""Search LanceDB with precomputed query vectors.

Keyword arguments
Expand All @@ -546,10 +547,12 @@ def retrieval(self, vectors, **kwargs):
``table.search`` (e.g. ``query_type``, ``fts_columns``). Do not
pass ``vector_column_name`` here; use the top-level
``vector_column_name`` retrieval argument instead.
query_texts:
Raw query strings aligned with ``vectors``. Required for
``hybrid=True`` and ignored for dense-only retrieval.
"""
hybrid = kwargs.pop("hybrid", self.hybrid)
if hybrid:
raise NotImplementedError("LanceDB hybrid retrieval with precomputed vectors is not implemented yet.")
query_texts = kwargs.pop("query_texts", None)
table_path = kwargs.pop("table_path", self.uri)
table_name = kwargs.pop("table_name", self.table_name)

Expand All @@ -567,6 +570,23 @@ def retrieval(self, vectors, **kwargs):
else:
search_kwargs = dict(search_kwargs_raw)

if hybrid:
if query_texts is None:
raise ValueError(
"LanceDB hybrid retrieval requires query_texts. Pass query_texts=your_queries "
"alongside vectors when calling retrieval() with hybrid=True."
)
query_type = search_kwargs.get("query_type")
if query_type is not None:
query_type_value = getattr(query_type, "value", query_type)
if str(query_type_value).lower() != "hybrid":
raise ValueError(
"LanceDB hybrid retrieval requires search_kwargs['query_type']='hybrid'; "
f"got {query_type!r}."
)
search_kwargs["query_type"] = "hybrid"
search_kwargs.setdefault("fts_columns", "text")

where_clause = kwargs.pop("where", None)
_filter_fallback = kwargs.pop("_filter", None)
if where_clause is None:
Expand All @@ -576,9 +596,28 @@ def retrieval(self, vectors, **kwargs):

table = lancedb.connect(uri=table_path).open_table(table_name)

if hybrid:
vectors_for_search = list(vectors)
query_texts_list = [query_texts] if isinstance(query_texts, str) else list(query_texts)
if len(query_texts_list) != len(vectors_for_search):
raise ValueError(
"LanceDB hybrid retrieval requires query_texts length to match vectors length; "
f"got query_texts={len(query_texts_list)} vectors={len(vectors_for_search)}."
)
else:
vectors_for_search = vectors
query_texts_list = []

search_results = []
for vector in vectors:
query = table.search([vector], vector_column_name=vector_column_name, **search_kwargs)
for idx, vector in enumerate(vectors_for_search):
if hybrid:
query = (
table.search(vector_column_name=vector_column_name, **search_kwargs)
.vector(vector)
.text(str(query_texts_list[idx]))
)
else:
query = table.search([vector], vector_column_name=vector_column_name, **search_kwargs)
if where_clause is not None:
query = query.where(where_clause)
query = query.limit(top_k).refine_factor(refine_factor).nprobes(n_probe)
Expand Down
3 changes: 3 additions & 0 deletions nemo_retriever/src/nemo_retriever/vdb/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __init__(
) -> None:
merged = dict(vdb_kwargs or {})
clean_kwargs, _sidecar = split_sidecar_from_vdb_kwargs(merged)
clean_kwargs.pop("query_texts", None)
super().__init__(vdb=vdb, vdb_op=vdb_op, vdb_kwargs=clean_kwargs, explode_for_rerank=explode_for_rerank)
self._vdb_kwargs = clean_kwargs
self._retrieval_vdb_kwargs = clean_kwargs
Expand All @@ -162,6 +163,8 @@ def process(self, data: Any, **kwargs: Any) -> list[list[dict[str, Any]]]:
from nemo_retriever.retriever_graph_utils import filter_retrieval_kwargs

retrieval_kwargs = {**self._retrieval_vdb_kwargs, **filter_retrieval_kwargs(kwargs)}
if retrieval_kwargs.get("hybrid") and "query_texts" in kwargs:
retrieval_kwargs["query_texts"] = kwargs["query_texts"]
return normalize_retrieval_results(self._vdb.retrieval(data, **retrieval_kwargs))

def postprocess(self, data: Any, **kwargs: Any) -> Any:
Expand Down
85 changes: 83 additions & 2 deletions nemo_retriever/tests/test_lancedb_retrieval_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from nemo_retriever.vdb.lancedb import LanceDB


def _tiny_table(uri: str) -> None:
def _tiny_table(uri: str, *, create_fts_index: bool = False) -> None:
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
Expand All @@ -41,7 +41,9 @@ def _tiny_table(uri: str) -> None:
},
]
db = lancedb.connect(uri)
db.create_table("t", rows, schema=schema, mode="overwrite")
table = db.create_table("t", rows, schema=schema, mode="overwrite")
if create_fts_index:
table.create_fts_index("text", replace=True)


def test_retrieval_where_filters_rows() -> None:
Expand Down Expand Up @@ -101,3 +103,82 @@ def test_retrieval_search_kwargs_must_be_dict() -> None:
op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False)
with pytest.raises(TypeError, match="search_kwargs"):
op.retrieval([[1.0, 0.0]], top_k=5, table_path=d, table_name="t", search_kwargs="bad")


def test_hybrid_retrieval_uses_query_texts() -> None:
    """Hybrid retrieval with query text ``"alpha"`` ranks the ``"alpha"`` row first."""
    # The context manager removes the on-disk LanceDB table when the test ends.
    with tempfile.TemporaryDirectory() as d:
        _tiny_table(d, create_fts_index=True)
        op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False)

        results = op.retrieval(
            [[1.0, 0.0]],
            top_k=2,
            table_path=d,
            table_name="t",
            hybrid=True,
            query_texts=["alpha"],
        )

        # One query vector in, so one (non-empty) hit list out.
        assert results[0]
        assert results[0][0]["text"] == "alpha"


def test_hybrid_retrieval_requires_query_texts() -> None:
    """Calling ``retrieval`` with ``hybrid=True`` but no ``query_texts`` raises ``ValueError``."""
    # The context manager removes the on-disk LanceDB table when the test ends.
    with tempfile.TemporaryDirectory() as d:
        _tiny_table(d, create_fts_index=True)
        op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False)

        with pytest.raises(ValueError, match="requires query_texts"):
            op.retrieval([[1.0, 0.0]], top_k=2, table_path=d, table_name="t", hybrid=True)


def test_hybrid_retrieval_requires_query_texts_aligned_with_vectors() -> None:
    """A ``query_texts`` list longer than ``vectors`` is rejected with ``ValueError``."""
    # The context manager removes the on-disk LanceDB table when the test ends.
    with tempfile.TemporaryDirectory() as d:
        _tiny_table(d, create_fts_index=True)
        op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False)

        # One vector but two query strings: the lengths disagree.
        with pytest.raises(ValueError, match="length to match vectors length"):
            op.retrieval(
                [[1.0, 0.0]],
                top_k=2,
                table_path=d,
                table_name="t",
                hybrid=True,
                query_texts=["alpha", "beta"],
            )


def test_hybrid_retrieval_where_filters_rows() -> None:
    """A ``where`` clause restricts hybrid retrieval results to the matching row."""
    # The context manager removes the on-disk LanceDB table when the test ends.
    with tempfile.TemporaryDirectory() as d:
        _tiny_table(d, create_fts_index=True)
        op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False)

        filtered = op.retrieval(
            [[1.0, 0.0]],
            top_k=10,
            table_path=d,
            table_name="t",
            hybrid=True,
            query_texts=["beta"],
            where="text = 'beta'",
        )

        # top_k is larger than the filtered result, so only the filter limits the hits.
        assert len(filtered[0]) == 1
        assert filtered[0][0]["text"] == "beta"


def test_hybrid_retrieval_rejects_non_hybrid_query_type() -> None:
    """``hybrid=True`` combined with ``search_kwargs['query_type']='vector'`` raises ``ValueError``."""
    # The context manager removes the on-disk LanceDB table when the test ends.
    with tempfile.TemporaryDirectory() as d:
        _tiny_table(d, create_fts_index=True)
        op = LanceDB(uri=d, table_name="t", overwrite=False, vector_dim=2, validate_vector_length=False)

        with pytest.raises(ValueError, match="query_type"):
            op.retrieval(
                [[1.0, 0.0]],
                top_k=2,
                table_path=d,
                table_name="t",
                hybrid=True,
                query_texts=["alpha"],
                search_kwargs={"query_type": "vector"},
            )
32 changes: 32 additions & 0 deletions nemo_retriever/tests/test_nv_ingest_vdb_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,38 @@ def test_retrieve_operator_delegates_vectors_to_retrieval() -> None:
assert vdb.retrieval_calls == [([[0.1, 0.2]], {"collection_name": "docs", "model_name": "embedder", "top_k": 3})]


def test_retrieve_operator_forwards_runtime_query_texts() -> None:
    """Runtime ``query_texts`` reach the VDB; the constructor-time value does not."""
    fake_vdb = FakeVDB()
    constructor_kwargs = {
        "collection_name": "docs",
        "model_name": "embedder",
        "hybrid": True,
        "query_texts": ["stale"],
    }
    retrieve_op = RetrieveVdbOperator(vdb=fake_vdb, vdb_kwargs=constructor_kwargs)

    retrieve_op.process([[0.1, 0.2]], top_k=3, query_texts=["current"])

    # Only the per-call query texts should appear in the forwarded kwargs.
    expected_kwargs = {
        "collection_name": "docs",
        "model_name": "embedder",
        "hybrid": True,
        "top_k": 3,
        "query_texts": ["current"],
    }
    assert fake_vdb.retrieval_calls == [([[0.1, 0.2]], expected_kwargs)]


def test_retrieve_operator_does_not_forward_query_texts_for_dense_retrieval() -> None:
    """Without ``hybrid`` in the operator kwargs, runtime ``query_texts`` are not forwarded."""
    fake_vdb = FakeVDB()
    retrieve_op = RetrieveVdbOperator(
        vdb=fake_vdb,
        vdb_kwargs={"collection_name": "docs", "model_name": "embedder"},
    )

    retrieve_op.process([[0.1, 0.2]], top_k=3, query_texts=["current"])

    # The forwarded kwargs carry no "query_texts" entry.
    expected_kwargs = {"collection_name": "docs", "model_name": "embedder", "top_k": 3}
    assert fake_vdb.retrieval_calls == [([[0.1, 0.2]], expected_kwargs)]


def test_constructor_requires_exactly_one_vdb_source() -> None:
with pytest.raises(ValueError, match="Either vdb or vdb_op is required"):
IngestVdbOperator()
Expand Down
Loading