Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pephub/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.15.4"
__version__ = "0.15.5"
3 changes: 2 additions & 1 deletion pephub/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@
# https://arxiv.org/abs/2210.07316
# figure 4
# great speed to accuracy tradeoff
DEFAULT_HF_MODEL = "BAAI/bge-small-en-v1.5"
DENSE_ENCODER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
SPARSE_ENCODER_MODEL = "naver/splade-v3"

EIDO_TEMPLATES_DIRNAME = "templates/eido"
EIDO_TEMPLATES_PATH = os.path.join(
Expand Down
27 changes: 23 additions & 4 deletions pephub/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from pepdbagent.models import AnnotationModel, Namespace, ListOfNamespaceInfo
from pydantic import BaseModel
from qdrant_client import QdrantClient
from sentence_transformers import SparseEncoder
from qdrant_client.http.exceptions import ResponseHandlingException

from .const import (
DEFAULT_HF_MODEL,
DENSE_ENCODER_MODEL,
DEFAULT_POSTGRES_DB,
DEFAULT_POSTGRES_HOST,
DEFAULT_POSTGRES_PASSWORD,
Expand All @@ -32,12 +33,14 @@
DEFAULT_QDRANT_HOST,
DEFAULT_QDRANT_PORT,
JWT_SECRET,
SPARSE_ENCODER_MODEL,
PKG_NAME,
)
from .helpers import jwt_encode_user_data
from .routers.models import ForkRequest
from .developer_keys import dev_key_handler

_LOGGER_PEPHUB = logging.getLogger("uvicorn.access")
_LOGGER_PEPHUB = logging.getLogger(PKG_NAME)

load_dotenv()

Expand Down Expand Up @@ -98,12 +101,21 @@ def jwt_encode_user_data(user_data: dict, exp: datetime = None) -> str:
)

# sentence_transformer model
_LOGGER_PEPHUB.info(f"HF MODEL IN USE: {os.getenv('HF_MODEL', DEFAULT_HF_MODEL)}")
_LOGGER_PEPHUB.info(f"HF MODEL IN USE: {os.getenv('HF_MODEL', DENSE_ENCODER_MODEL)}")
embedding_model = Embedding(
model_name=os.getenv("HF_MODEL", DEFAULT_HF_MODEL), max_length=512
model_name=os.getenv("HF_MODEL", DENSE_ENCODER_MODEL), max_length=512
)
# embedding_model = None

token = os.environ.get("HF_TOKEN", None)
hf_model_sparse = os.environ.get("HF_MODEL_SPARSE", SPARSE_ENCODER_MODEL)
if token is None:
sparse_model = None
_LOGGER_PEPHUB.warning("No HF_TOKEN provided, sparce model disabled.")
else:
sparse_model = SparseEncoder(hf_model_sparse, token=token)
_LOGGER_PEPHUB.info(f"Sparce model in use: {hf_model_sparse}")


## Qdrant connection
def parse_boolean_env_var(env_var: str) -> bool:
Expand Down Expand Up @@ -389,6 +401,13 @@ def get_sentence_transformer() -> Embedding:
return embedding_model


def get_sparse_model() -> Union[SparseEncoder, None]:
"""
Return sparce encoder model
"""
return sparse_model


def get_namespace_info(
namespace: str,
agent: PEPDatabaseAgent = Depends(get_db),
Expand Down
62 changes: 58 additions & 4 deletions pephub/routers/api/v1/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,25 @@
from pepdbagent import PEPDatabaseAgent
from pepdbagent.models import NamespaceList
from qdrant_client import QdrantClient
from qdrant_client.models import (
SparseVector,
Prefetch,
FusionQuery,
Fusion,
SearchParams,
FieldCondition,
MatchValue,
Filter,
)
from sentence_transformers import SparseEncoder

from ....const import DEFAULT_QDRANT_COLLECTION_NAME
from ....dependencies import (
get_db,
get_namespace_access_list,
get_qdrant,
get_sentence_transformer,
get_sparse_model,
)
from ...models import SearchQuery, SearchReturnModel
from qdrant_client.models import ScoredPoint
Expand Down Expand Up @@ -42,6 +54,7 @@ async def search_for_pep(
query: SearchQuery,
qdrant: QdrantClient = Depends(get_qdrant),
model: Embedding = Depends(get_sentence_transformer),
model_sparce: SparseEncoder = Depends(get_sparse_model),
agent: PEPDatabaseAgent = Depends(get_db),
namespace_access: List[str] = Depends(get_namespace_access_list),
) -> SearchReturnModel:
Expand All @@ -59,14 +72,55 @@ async def search_for_pep(
).results

if qdrant is not None:
query_vec = list(model.embed(query.query))[0]
dense_query = list(list(model.embed(query.query))[0])

if model_sparce:
sparse_result = model_sparce.encode(query.query).coalesce()
sparse_embeddings = SparseVector(
indices=sparse_result.indices().tolist()[0],
values=sparse_result.values().tolist(),
)
else:
sparse_embeddings = None

should_statement = [
FieldCondition(
key="name",
match=MatchValue(value=query.query),
)
]

if sparse_embeddings:
hybrid_query = [
# Dense retrieval: semantic understanding
Prefetch(query=dense_query, using="dense", limit=100),
# Sparse retrieval: exact technical term matching
Prefetch(query=sparse_embeddings, using="sparse", limit=100),
# Exact match retrieval: precise filtering
Prefetch(filter=Filter(must=should_statement), limit=10),
]
else:
hybrid_query = [
# Dense retrieval: semantic understanding
Prefetch(query=dense_query, using="dense", limit=100),
# Exact match retrieval: precise filtering
Prefetch(filter=Filter(must=should_statement), limit=10),
]

vector_results = qdrant.query_points(
collection_name=(query.collection_name or DEFAULT_QDRANT_COLLECTION_NAME),
query=query_vec,
collection_name=DEFAULT_QDRANT_COLLECTION_NAME,
limit=limit,
offset=offset,
score_threshold=score_threshold,
prefetch=hybrid_query,
query=FusionQuery(fusion=Fusion.RRF),
with_payload=True,
with_vectors=False,
search_params=SearchParams(
exact=True,
),
# query_filter=(
# models.Filter(must=should_statement) if should_statement else None
# ),
).points

return SearchReturnModel(
Expand Down
1 change: 0 additions & 1 deletion pephub/routers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class ProjectOptional(UpdateItems):

class SearchQuery(BaseModel):
query: str
collection_name: Optional[str] = None
limit: Optional[int] = 100
offset: Optional[int] = 0
score_threshold: Optional[float] = DEFAULT_QDRANT_SCORE_THRESHOLD
Expand Down
Loading