-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathquery_db.py
100 lines (83 loc) · 3.13 KB
/
query_db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from pinecone import Pinecone
from pinecone_text.sparse import BM25Encoder
import os
import streamlit as st
# Key retrieval parameters. No Pinecone client is created at import time;
# each query function below builds its own client from st.secrets.
# Base result count: most queries return TOP_K hits, while the reranked
# search over-fetches TOP_K * 2 candidates.
TOP_K = 5
# JSON file with fitted BM25Encoder parameters, loaded by query_bm25().
bm25_file_name = "bm25_birds.json"
def query_integrated_inference(query, index_name, namespace="bird-search"):
    """Search an integrated-inference index with raw query text.

    The request sends the query as plain text under ``inputs``, so the
    embedding is produced on the Pinecone side. Only the ``bird`` and
    ``chunk_text`` fields are requested for each returned record.

    Args:
        query: Free-text search string.
        index_name: Name of the Pinecone index to search.
        namespace: Namespace within the index to search.

    Returns:
        The ``result`` -> ``hits`` list from the search response
        (at most TOP_K hits).
    """
    client = Pinecone(api_key=st.secrets["pinecone_api_key"])
    search_spec = {"inputs": {"text": query}, "top_k": TOP_K}
    wanted_fields = ["bird", "chunk_text"]
    response = client.Index(index_name).search_records(
        namespace=namespace,
        query=search_spec,
        fields=wanted_fields,
    )
    return response["result"]["hits"]
def query_rerank_integrated_inference(query, index_name, namespace="bird-search"):
    """Search an integrated-inference index and rerank the candidates.

    Fetches ``TOP_K * 2`` candidates for the raw query text, then asks
    Pinecone to rerank them with ``cohere-rerank-3.5`` using each record's
    ``chunk_text`` field. No ``top_n`` is passed to the reranker, so
    presumably every candidate comes back in reranked order (confirm
    against Pinecone's rerank defaults).

    Args:
        query: Free-text search string.
        index_name: Name of the Pinecone index to search.
        namespace: Namespace within the index to search.

    Returns:
        The ``result`` -> ``hits`` list from the search response.
    """
    client = Pinecone(api_key=st.secrets["pinecone_api_key"])
    target = client.Index(index_name)
    # Over-fetch so the reranker has a larger candidate pool to reorder.
    candidate_query = {"top_k": TOP_K * 2, "inputs": {"text": query}}
    rerank_spec = {"model": "cohere-rerank-3.5", "rank_fields": ["chunk_text"]}
    response = target.search(
        namespace=namespace,
        query=candidate_query,
        rerank=rerank_spec,
    )
    return response["result"]["hits"]
def query_bm25(query, index_name, namespace="bird-search"):
    """Run a sparse BM25 query and return hits shaped like integrated-inference hits.

    The query text is encoded locally with the BM25Encoder parameters stored
    in ``bm25_file_name``. It is then sent to ``index.query`` as a sparse
    vector, with metadata included in the matches.

    Args:
        query: Free-text search string.
        index_name: Name of the Pinecone (sparse) index to query.
        namespace: Namespace within the index to query.

    Returns:
        A list of at most TOP_K dicts with keys:
        ``_id`` (record id, the same key integrated-inference hits use),
        ``id`` (same value, kept for existing callers),
        ``fields`` (the match metadata; ``{}`` when none is returned) and
        ``_score`` (the match score).
    """
    pc = Pinecone(api_key=st.secrets["pinecone_api_key"])
    index = pc.Index(index_name)
    bm25 = BM25Encoder().load(path=bm25_file_name)
    encoded_query = bm25.encode_queries(query)
    results = index.query(
        namespace=namespace,
        sparse_vector={
            "values": encoded_query["values"],
            "indices": encoded_query["indices"],
        },
        top_k=TOP_K,
        include_metadata=True,
    )
    # index.query returns matches keyed id/metadata/score, whereas
    # integrated-inference search returns hits keyed _id/fields/_score.
    # Emit the integrated-inference keys (e.g. code that dedups on '_id'
    # expects them) and keep 'id' for callers that already read it.
    return [
        {
            "_id": match["id"],
            "id": match["id"],
            "fields": match["metadata"] or {},
            "_score": match["score"],
        }
        for match in results["matches"]
    ]
def conduct_cascading_retrieval(
    query,
    sparse_index_name="sparse-bird-search",
    dense_index_name="dense-bird-search",
    namespace="bird-search",
):
    """Merge reranked hits from a dense and a sparse index into one ranking.

    Both indexes are searched with ``query_rerank_integrated_inference``,
    dense first. Hits are de-duplicated on ``_id``; when an id appears in
    both lists, the dense copy is kept because it comes first. The survivors
    are ordered by ``_score`` from highest to lowest.

    Returns:
        At most TOP_K hit objects.
    """
    dense_hits = query_rerank_integrated_inference(query, dense_index_name, namespace)
    sparse_hits = query_rerank_integrated_inference(query, sparse_index_name, namespace)

    seen = set()
    unique_hits = []
    for hit in dense_hits + sparse_hits:
        hit_id = hit["_id"]
        if hit_id in seen:
            continue
        seen.add(hit_id)
        unique_hits.append(hit)

    # Stable sort, so hits with equal scores keep their merged order.
    unique_hits.sort(key=lambda h: h["_score"], reverse=True)
    return unique_hits[:TOP_K]