-
Notifications
You must be signed in to change notification settings - Fork 274
/
Copy pathusing_faiss_as_vextorDB.py
97 lines (77 loc) · 3.37 KB
/
using_faiss_as_vextorDB.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import asyncio
import numpy as np
from nano_graphrag.graphrag import GraphRAG, QueryParam
from nano_graphrag._utils import logger
from nano_graphrag.base import BaseVectorStorage
from dataclasses import dataclass
import faiss
import pickle
import logging
import xxhash
# Quiet chatty third-party loggers so demo output stays readable.
for _noisy in ("msal", "azure", "httpx"):
    logging.getLogger(_noisy).setLevel(logging.WARNING)

# All cache/index artifacts produced by this demo live here.
WORKING_DIR = "./nano_graphrag_cache_faiss_TEST"
@dataclass
class FAISSStorage(BaseVectorStorage):
    """FAISS-backed implementation of BaseVectorStorage.

    Persists an inner-product FAISS index plus a pickled dict of per-vector
    metadata under the configured working directory, keyed by namespace.
    Vector ids are 32-bit xxhash digests of the string keys.
    """

    def __post_init__(self):
        # File locations derive from working_dir + namespace so multiple
        # namespaces can coexist in one cache directory.
        self._index_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_faiss.index"
        )
        self._metadata_file_name = os.path.join(
            self.global_config["working_dir"], f"{self.namespace}_metadata.pkl"
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        # Resume from disk only when BOTH files exist; a lone index without
        # its metadata (or vice versa) would leave the two stores inconsistent.
        if os.path.exists(self._index_file_name) and os.path.exists(self._metadata_file_name):
            self._index = faiss.read_index(self._index_file_name)
            # NOTE(review): pickle.load on a cache file is fine for a local
            # demo, but unsafe if the cache dir is ever untrusted.
            with open(self._metadata_file_name, 'rb') as f:
                self._metadata = pickle.load(f)
        else:
            # IndexFlatIP scores by inner product (cosine similarity when the
            # embeddings are normalized — assumed here, TODO confirm with the
            # configured embedding_func). IndexIDMap lets us supply our own
            # int64 ids instead of sequential positions.
            self._index = faiss.IndexIDMap(faiss.IndexFlatIP(self.embedding_func.embedding_dim))
            self._metadata = {}

    async def upsert(self, data: dict[str, dict]):
        """Embed and insert *data* into the index.

        Args:
            data: mapping of string key -> dict with at least a "content"
                field; fields listed in self.meta_fields are kept as metadata.

        Returns:
            Number of vectors inserted.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not data:
            # np.concatenate raises ValueError on an empty list of batches;
            # nothing to embed, so return early.
            return 0
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        # Embed batches concurrently; order of results matches batch order.
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)
        ids = []
        for k, v in data.items():
            # 32-bit hash of the key doubles as the FAISS vector id.
            # NOTE(review): xxh32 gives only 2**32 distinct ids — collision
            # risk grows with corpus size; xxh64 would be safer.
            vec_id = xxhash.xxh32_intdigest(k.encode())
            metadata = {k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}
            metadata['id'] = k  # preserve the original string key
            self._metadata[vec_id] = metadata
            ids.append(vec_id)
        ids = np.array(ids, dtype=np.int64)
        self._index.add_with_ids(embeddings, ids)
        return len(data)

    async def query(self, query, top_k=5):
        """Return up to *top_k* metadata dicts nearest to *query*.

        Each result carries the stored meta fields plus a "distance" key
        (1 - inner product, i.e. cosine distance for normalized vectors).
        """
        embedding = await self.embedding_func([query])
        distances, indices = self._index.search(embedding, top_k)
        results = []
        # Single query vector -> inspect row 0 of both result matrices.
        for distance, vec_id in zip(distances[0], indices[0]):
            # FAISS pads with -1 when fewer than top_k vectors exist;
            # also skip ids with no metadata (shouldn't happen, but safe).
            if vec_id != -1 and vec_id in self._metadata:
                metadata = self._metadata[vec_id]
                results.append({**metadata, "distance": 1 - distance})  # Convert to cosine distance
        return results

    async def index_done_callback(self):
        """Flush the index and metadata to disk after an indexing pass."""
        faiss.write_index(self._index, self._index_file_name)
        with open(self._metadata_file_name, 'wb') as f:
            pickle.dump(self._metadata, f)
if __name__ == "__main__":
    # Build a GraphRAG instance that stores embeddings in the FAISS-backed
    # storage class defined above instead of the default vector backend.
    rag = GraphRAG(
        working_dir=WORKING_DIR,
        enable_llm_cache=True,
        vector_db_storage_cls=FAISSStorage,
    )
    with open(r"tests/mock_data.txt", encoding='utf-8') as fh:
        corpus = fh.read()
    # Index only the first 30k characters to keep the demo cheap.
    rag.insert(corpus[:30000])
    # Perform global graphrag search
    print(rag.query("What are the top themes in this story?"))