Skip to content

Commit 12726f3

Browse files
committed
docs: Qdrant as vector DB example
Signed-off-by: Anush008 <[email protected]>
1 parent da57aa2 commit 12726f3

File tree

1 file changed

+113
-0
lines changed

1 file changed

+113
-0
lines changed

Diff for: examples/using_qdrant_as_vectorDB.py

+113
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,113 @@
1+
import os
2+
import asyncio
3+
import uuid
4+
import numpy as np
5+
from nano_graphrag import GraphRAG, QueryParam
6+
from nano_graphrag._utils import logger
7+
from nano_graphrag.base import BaseVectorStorage
8+
from dataclasses import dataclass
9+
10+
try:
11+
from qdrant_client import QdrantClient
12+
from qdrant_client.models import VectorParams, Distance, PointStruct, SearchParams
13+
except ImportError as original_error:
14+
raise ImportError(
15+
"Qdrant client is not installed. Install it using: pip install qdrant-client\n"
16+
) from original_error
17+
18+
19+
@dataclass
class QdrantStorage(BaseVectorStorage):
    """Vector storage for nano-graphrag backed by a local Qdrant collection.

    One Qdrant collection is used per ``namespace``. Vectors are produced by
    ``embedding_func`` and compared with cosine distance. The key under which
    a record was upserted is kept in the payload as ``"id"`` and is what
    ``query`` reports back.
    """

    def __post_init__(self):
        # Use a local file-based Qdrant storage
        # Useful for prototyping and CI.
        # For production, refer to:
        # https://qdrant.tech/documentation/guides/installation/
        self._client_file_path = os.path.join(
            self.global_config["working_dir"], "qdrant_storage"
        )

        self._client = QdrantClient(path=self._client_file_path)

        self._max_batch_size = self.global_config["embedding_batch_num"]

        # Create the collection lazily so re-opening an existing working
        # directory reuses the data already stored there.
        if not self._client.collection_exists(collection_name=self.namespace):
            self._client.create_collection(
                collection_name=self.namespace,
                vectors_config=VectorParams(
                    size=self.embedding_func.embedding_dim, distance=Distance.COSINE
                ),
            )

    @staticmethod
    def _point_id(key) -> str:
        """Map a record key to a stable Qdrant point id.

        Qdrant only accepts unsigned integers or UUIDs as point ids, so the
        key is hashed into a name-based UUID. The same key always yields the
        same id, which lets a repeated upsert overwrite instead of duplicate.
        """
        return str(uuid.uuid5(uuid.NAMESPACE_OID, str(key)))

    async def upsert(self, data: dict[str, dict]):
        """Embed and store records, overwriting records that share a key.

        Args:
            data: Mapping of record key to a dict that must contain a
                ``"content"`` string. Only fields listed in ``meta_fields``
                are copied into the stored payload, next to the key itself.

        Returns:
            The result object of the Qdrant upsert call, or ``None`` when
            ``data`` is empty and nothing was written.
        """
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")

        # np.concatenate() raises on an empty sequence, so bail out early.
        if not data:
            logger.warning(f"No vectors to insert into {self.namespace}")
            return None

        payloads = [
            {
                "id": key,
                **{
                    field: value
                    for field, value in record.items()
                    if field in self.meta_fields
                },
            }
            for key, record in data.items()
        ]

        contents = [record["content"] for record in data.values()]
        batches = [
            contents[start : start + self._max_batch_size]
            for start in range(0, len(contents), self._max_batch_size)
        ]

        # Embed all batches concurrently; gather() preserves batch order, so
        # row i of the concatenated array belongs to payloads[i].
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)

        points = [
            PointStruct(
                id=self._point_id(payload["id"]),
                vector=embeddings[index].tolist(),
                payload=payload,
            )
            for index, payload in enumerate(payloads)
        ]

        results = self._client.upsert(collection_name=self.namespace, points=points)
        return results

    async def query(self, query, top_k=5):
        """Return the ``top_k`` stored records most similar to ``query``.

        Args:
            query: Text to embed and search for.
            top_k: Maximum number of hits to return.

        Returns:
            A list of dicts, one per hit, holding the stored payload fields,
            ``"id"`` (the key the record was upserted under) and ``"score"``.
        """
        embedding = await self.embedding_func([query])

        hits = self._client.query_points(
            collection_name=self.namespace,
            query=embedding[0].tolist(),
            limit=top_k,
        ).points

        records = []
        for hit in hits:
            payload = hit.payload or {}
            records.append(
                {
                    **payload,
                    # Report the caller's original key rather than the
                    # internal Qdrant point id; fall back to the point id
                    # only if the payload carries no key.
                    "id": payload.get("id", str(hit.id)),
                    "score": hit.score,
                }
            )
        return records
90+
91+
92+
def insert():
    """Index two placeholder documents using Qdrant as the vector store."""
    graph_rag = GraphRAG(
        working_dir="./nano_graphrag_cache_qdrant_TEST",
        enable_llm_cache=True,
        vector_db_storage_cls=QdrantStorage,
    )
    # Replace the placeholders with the documents you want to index.
    documents = ["YOUR TEXT DATA HERE"] * 2
    graph_rag.insert(documents)
100+
101+
102+
def query():
    """Run a local-mode query against the Qdrant-backed index and print the answer."""
    graph_rag = GraphRAG(
        working_dir="./nano_graphrag_cache_qdrant_TEST",
        enable_llm_cache=True,
        vector_db_storage_cls=QdrantStorage,
    )
    local_mode = QueryParam(mode="local")
    answer = graph_rag.query("YOUR QUERY HERE", param=local_mode)
    print(answer)
109+
110+
111+
if __name__ == "__main__":
    # Build the index first so the query that follows has data to search.
    for step in (insert, query):
        step()

0 commit comments

Comments
 (0)