diff --git a/pephub/_version.py b/pephub/_version.py index 8782a8b4..32e2f393 100644 --- a/pephub/_version.py +++ b/pephub/_version.py @@ -1 +1 @@ -__version__ = "0.15.4" +__version__ = "0.15.5" diff --git a/pephub/const.py b/pephub/const.py index cf8e9c84..53fd674f 100644 --- a/pephub/const.py +++ b/pephub/const.py @@ -64,7 +64,8 @@ # https://arxiv.org/abs/2210.07316 # figure 4 # great speed to accuracy tradeoff -DEFAULT_HF_MODEL = "BAAI/bge-small-en-v1.5" +DENSE_ENCODER_MODEL = "sentence-transformers/all-MiniLM-L6-v2" +SPARSE_ENCODER_MODEL = "naver/splade-v3" EIDO_TEMPLATES_DIRNAME = "templates/eido" EIDO_TEMPLATES_PATH = os.path.join( diff --git a/pephub/dependencies.py b/pephub/dependencies.py index 6c47a92e..7b2644ba 100644 --- a/pephub/dependencies.py +++ b/pephub/dependencies.py @@ -20,10 +20,11 @@ from pepdbagent.models import AnnotationModel, Namespace, ListOfNamespaceInfo from pydantic import BaseModel from qdrant_client import QdrantClient +from sentence_transformers import SparseEncoder from qdrant_client.http.exceptions import ResponseHandlingException from .const import ( - DEFAULT_HF_MODEL, + DENSE_ENCODER_MODEL, DEFAULT_POSTGRES_DB, DEFAULT_POSTGRES_HOST, DEFAULT_POSTGRES_PASSWORD, @@ -32,12 +33,14 @@ DEFAULT_QDRANT_HOST, DEFAULT_QDRANT_PORT, JWT_SECRET, + SPARSE_ENCODER_MODEL, + PKG_NAME, ) from .helpers import jwt_encode_user_data from .routers.models import ForkRequest from .developer_keys import dev_key_handler -_LOGGER_PEPHUB = logging.getLogger("uvicorn.access") +_LOGGER_PEPHUB = logging.getLogger(PKG_NAME) load_dotenv() @@ -98,12 +101,21 @@ def jwt_encode_user_data(user_data: dict, exp: datetime = None) -> str: ) # sentence_transformer model -_LOGGER_PEPHUB.info(f"HF MODEL IN USE: {os.getenv('HF_MODEL', DEFAULT_HF_MODEL)}") +_LOGGER_PEPHUB.info(f"HF MODEL IN USE: {os.getenv('HF_MODEL', DENSE_ENCODER_MODEL)}") embedding_model = Embedding( - model_name=os.getenv("HF_MODEL", DEFAULT_HF_MODEL), max_length=512 + model_name=os.getenv("HF_MODEL", 
DENSE_ENCODER_MODEL), max_length=512 ) # embedding_model = None +token = os.environ.get("HF_TOKEN", None) +hf_model_sparse = os.environ.get("HF_MODEL_SPARSE", SPARSE_ENCODER_MODEL) +if token is None: + sparse_model = None + _LOGGER_PEPHUB.warning("No HF_TOKEN provided, sparse model disabled.") +else: + sparse_model = SparseEncoder(hf_model_sparse, token=token) + _LOGGER_PEPHUB.info(f"Sparse model in use: {hf_model_sparse}") + ## Qdrant connection def parse_boolean_env_var(env_var: str) -> bool: @@ -389,6 +401,13 @@ def get_sentence_transformer() -> Embedding: return embedding_model +def get_sparse_model() -> Union[SparseEncoder, None]: + """ + Return sparse encoder model + """ + return sparse_model + + def get_namespace_info( namespace: str, agent: PEPDatabaseAgent = Depends(get_db), diff --git a/pephub/routers/api/v1/search.py b/pephub/routers/api/v1/search.py index fb648a62..f1e672d9 100644 --- a/pephub/routers/api/v1/search.py +++ b/pephub/routers/api/v1/search.py @@ -7,6 +7,17 @@ from pepdbagent import PEPDatabaseAgent from pepdbagent.models import NamespaceList from qdrant_client import QdrantClient +from qdrant_client.models import ( + SparseVector, + Prefetch, + FusionQuery, + Fusion, + SearchParams, + FieldCondition, + MatchValue, + Filter, +) +from sentence_transformers import SparseEncoder from ....const import DEFAULT_QDRANT_COLLECTION_NAME from ....dependencies import ( @@ -14,6 +25,7 @@ get_namespace_access_list, get_qdrant, get_sentence_transformer, + get_sparse_model, ) from ...models import SearchQuery, SearchReturnModel from qdrant_client.models import ScoredPoint @@ -42,6 +54,7 @@ async def search_for_pep( query: SearchQuery, qdrant: QdrantClient = Depends(get_qdrant), model: Embedding = Depends(get_sentence_transformer), + model_sparce: SparseEncoder = Depends(get_sparse_model), agent: PEPDatabaseAgent = Depends(get_db), namespace_access: List[str] = Depends(get_namespace_access_list), ) -> SearchReturnModel: @@ -59,14 +72,55 @@ async def
search_for_pep( ).results if qdrant is not None: - query_vec = list(model.embed(query.query))[0] + dense_query = list(list(model.embed(query.query))[0]) + + if model_sparce: + sparse_result = model_sparce.encode(query.query).coalesce() + sparse_embeddings = SparseVector( + indices=sparse_result.indices().tolist()[0], + values=sparse_result.values().tolist(), + ) + else: + sparse_embeddings = None + + should_statement = [ + FieldCondition( + key="name", + match=MatchValue(value=query.query), + ) + ] + + if sparse_embeddings: + hybrid_query = [ + # Dense retrieval: semantic understanding + Prefetch(query=dense_query, using="dense", limit=100), + # Sparse retrieval: exact technical term matching + Prefetch(query=sparse_embeddings, using="sparse", limit=100), + # Exact match retrieval: precise filtering + Prefetch(filter=Filter(must=should_statement), limit=10), + ] + else: + hybrid_query = [ + # Dense retrieval: semantic understanding + Prefetch(query=dense_query, using="dense", limit=100), + # Exact match retrieval: precise filtering + Prefetch(filter=Filter(must=should_statement), limit=10), + ] vector_results = qdrant.query_points( - collection_name=(query.collection_name or DEFAULT_QDRANT_COLLECTION_NAME), - query=query_vec, + collection_name=DEFAULT_QDRANT_COLLECTION_NAME, limit=limit, offset=offset, - score_threshold=score_threshold, + prefetch=hybrid_query, + query=FusionQuery(fusion=Fusion.RRF), + with_payload=True, + with_vectors=False, + search_params=SearchParams( + exact=True, + ), + # query_filter=( + # models.Filter(must=should_statement) if should_statement else None + # ), ).points return SearchReturnModel( diff --git a/pephub/routers/models.py b/pephub/routers/models.py index defeb78f..69ad2435 100644 --- a/pephub/routers/models.py +++ b/pephub/routers/models.py @@ -20,7 +20,6 @@ class ProjectOptional(UpdateItems): class SearchQuery(BaseModel): query: str - collection_name: Optional[str] = None limit: Optional[int] = 100 offset: Optional[int] = 0 
score_threshold: Optional[float] = DEFAULT_QDRANT_SCORE_THRESHOLD