Commit 9ad71cf

feat: add Neo4j graph backend (#61)
* refactor: make _storage a folder
* feat: add neo4j backend
* fix: remove test coverage for neo4j
* refactor: dspy extraction
* docs: update neo4j
* fix: neo4j return clusters in node_data
* tests: fix test wrong with clusters node data
* improve coverage of llm
1 parent b33b2b8 · commit 9ad71cf

26 files changed: +1111 −360 lines

Diff for: .coveragerc (+4 −1)

@@ -5,4 +5,7 @@ exclude_lines =
 
     # Don't complain if tests don't hit defensive assertion code:
    raise NotImplementedError
-    logger.
+    logger.
+omit =
+    # Don't have a nice github action for neo4j now, so skip this file:
+    nano_graphrag/_storage/gdb_neo4j.py

Diff for: .github/workflows/test.yml (+2)

@@ -42,6 +42,8 @@ jobs:
       run: |
         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
     - name: Build and Test
+      env:
+        NANO_GRAPHRAG_TEST_IGNORE_NEO4J: true
       run: |
         python -m pytest -o log_cli=true -o log_cli_level="INFO" --cov=nano_graphrag --cov-report=xml -v ./
     - name: Check codecov file
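
The new env var lets CI run the full suite without a Neo4j service. A minimal sketch of how a test module could honor it (hypothetical test-side code; the commit's actual test changes are not shown on this page):

```python
import os

import pytest

# Skip every test in this module when CI sets the flag, since the
# GitHub Actions runner has no Neo4j service available.
if os.environ.get("NANO_GRAPHRAG_TEST_IGNORE_NEO4J"):
    pytest.skip("Neo4j tests disabled via env var", allow_module_level=True)
```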

Diff for: .gitignore (+1 −2)

@@ -1,8 +1,7 @@
 # Created by https://www.toptal.com/developers/gitignore/api/python
 # Edit at https://www.toptal.com/developers/gitignore?templates=python
 test_cache.json
-run_test.py
-run_test_zh.py
+run_test*.py
 nano_graphrag_cache*/
 *.txt
 examples/benchmarks/fixtures/

Diff for: docs/use_neo4j_for_graphrag.md (new file, +27)

1. Install [Neo4j](https://neo4j.com/docs/operations-manual/current/installation/)
2. Install the Neo4j GDS (Graph Data Science) [plugin](https://neo4j.com/docs/graph-data-science/current/installation/neo4j-server/)
3. Start the Neo4j server
4. Get the `NEO4J_URL`, `NEO4J_USER`, and `NEO4J_PASSWORD`
   - By default, `NEO4J_URL` is `neo4j://localhost:7687`, `NEO4J_USER` is `neo4j`, and `NEO4J_PASSWORD` is `neo4j`

Pass your Neo4j instance to `GraphRAG`:

```python
import os

from nano_graphrag import GraphRAG
from nano_graphrag._storage import Neo4jStorage

neo4j_config = {
    "neo4j_url": os.environ.get("NEO4J_URL", "neo4j://localhost:7687"),
    "neo4j_auth": (
        os.environ.get("NEO4J_USER", "neo4j"),
        os.environ.get("NEO4J_PASSWORD", "neo4j"),
    ),
}
GraphRAG(
    graph_storage_cls=Neo4jStorage,
    addon_params=neo4j_config,
)
```
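
A quick end-to-end sketch building on the doc above (assumptions: `insert` and `query` are the usual `GraphRAG` entry points, and a local Neo4j with default credentials is running):

```python
import os

from nano_graphrag import GraphRAG
from nano_graphrag._storage import Neo4jStorage

rag = GraphRAG(
    graph_storage_cls=Neo4jStorage,
    addon_params={
        "neo4j_url": os.environ.get("NEO4J_URL", "neo4j://localhost:7687"),
        "neo4j_auth": (
            os.environ.get("NEO4J_USER", "neo4j"),
            os.environ.get("NEO4J_PASSWORD", "neo4j"),
        ),
    },
)
# Entities and relations extracted from the text land in Neo4j.
rag.insert("Some corpus text to index into the Neo4j-backed graph.")
print(rag.query("What entities does the corpus mention?"))
```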

Diff for: examples/no_openai_key_at_all.py (+1)

@@ -34,6 +34,7 @@ async def ollama_model_if_cache(
 ) -> str:
     # remove kwargs that are not supported by ollama
     kwargs.pop("max_tokens", None)
+    kwargs.pop("response_format", None)
 
     ollama_client = ollama.AsyncClient()
     messages = []

Diff for: examples/using_ollama_as_llm.py (+1)

@@ -18,6 +18,7 @@ async def ollama_model_if_cache(
 ) -> str:
     # remove kwargs that are not supported by ollama
     kwargs.pop("max_tokens", None)
+    kwargs.pop("response_format", None)
 
     ollama_client = ollama.AsyncClient()
     messages = []

Diff for: examples/using_ollama_as_llm_and_embedding.py (+10 −7)

@@ -20,11 +20,13 @@
 EMBEDDING_MODEL_DIM = 768
 EMBEDDING_MODEL_MAX_TOKENS = 8192
 
+
 async def ollama_model_if_cache(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     # remove kwargs that are not supported by ollama
     kwargs.pop("max_tokens", None)
+    kwargs.pop("response_format", None)
 
     ollama_client = ollama.AsyncClient()
     messages = []
@@ -98,20 +100,21 @@ def insert():
     # rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True)
     # rag.insert(FAKE_TEXT[half_len:])
 
+
 # We're using Ollama to generate embeddings for the BGE model
 @wrap_embedding_func_with_attrs(
-    embedding_dim= EMBEDDING_MODEL_DIM,
-    max_token_size= EMBEDDING_MODEL_MAX_TOKENS,
+    embedding_dim=EMBEDDING_MODEL_DIM,
+    max_token_size=EMBEDDING_MODEL_MAX_TOKENS,
 )
-
-async def ollama_embedding(texts :list[str]) -> np.ndarray:
+async def ollama_embedding(texts: list[str]) -> np.ndarray:
     embed_text = []
     for text in texts:
-        data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
-        embed_text.append(data["embedding"])
-
+        data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
+        embed_text.append(data["embedding"])
+
     return embed_text
 
+
 if __name__ == "__main__":
     insert()
     query()
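
One wrinkle this hunk does not address: `ollama_embedding` is annotated `-> np.ndarray` but returns a plain Python list. A corrected sketch (the conversion is my suggestion, not part of the commit; `bge-base-en` is a placeholder, the example file defines its own `EMBEDDING_MODEL`):

```python
import numpy as np
import ollama

EMBEDDING_MODEL = "bge-base-en"  # placeholder model tag

async def ollama_embedding(texts: list[str]) -> np.ndarray:
    embed_text = []
    for text in texts:
        # ollama.embeddings returns {"embedding": [...]} for one prompt
        data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text)
        embed_text.append(data["embedding"])
    # np.array(...) makes the return value match the annotation:
    # shape (len(texts), EMBEDDING_MODEL_DIM)
    return np.array(embed_text)
```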

Diff for: nano_graphrag/_llm.py (+21 −8)

@@ -13,6 +13,23 @@
 from ._utils import compute_args_hash, wrap_embedding_func_with_attrs
 from .base import BaseKVStorage
 
+global_openai_async_client = None
+global_azure_openai_async_client = None
+
+
+def get_openai_async_client_instance():
+    global global_openai_async_client
+    if global_openai_async_client is None:
+        global_openai_async_client = AsyncOpenAI()
+    return global_openai_async_client
+
+
+def get_azure_openai_async_client_instance():
+    global global_azure_openai_async_client
+    if global_azure_openai_async_client is None:
+        global_azure_openai_async_client = AsyncAzureOpenAI()
+    return global_azure_openai_async_client
+
 
 @retry(
     stop=stop_after_attempt(5),
@@ -22,7 +39,7 @@
 async def openai_complete_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
-    openai_async_client = AsyncOpenAI()
+    openai_async_client = get_openai_async_client_instance()
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
@@ -78,7 +95,7 @@ async def gpt_4o_mini_complete(
     retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
 )
 async def openai_embedding(texts: list[str]) -> np.ndarray:
-    openai_async_client = AsyncOpenAI()
+    openai_async_client = get_openai_async_client_instance()
     response = await openai_async_client.embeddings.create(
         model="text-embedding-3-small", input=texts, encoding_format="float"
     )
@@ -93,7 +110,7 @@ async def openai_embedding(texts: list[str]) -> np.ndarray:
 async def azure_openai_complete_if_cache(
     deployment_name, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
-    azure_openai_client = AsyncAzureOpenAI()
+    azure_openai_client = get_azure_openai_async_client_instance()
     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
@@ -154,11 +171,7 @@ async def azure_gpt_4o_mini_complete(
     retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
 )
 async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
-    azure_openai_client = AsyncAzureOpenAI(
-        api_key=os.environ.get("API_KEY_EMB"),
-        api_version=os.environ.get("API_VERSION_EMB"),
-        azure_endpoint=os.environ.get("AZURE_ENDPOINT_EMB"),
-    )
+    azure_openai_client = get_azure_openai_async_client_instance()
     response = await azure_openai_client.embeddings.create(
         model="text-embedding-3-small", input=texts, encoding_format="float"
     )
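
A note on the last hunk: with the no-argument `AsyncAzureOpenAI()` singleton, Azure settings come from the SDK's environment variables rather than the old `API_KEY_EMB` / `API_VERSION_EMB` / `AZURE_ENDPOINT_EMB` trio. A sketch of the new configuration surface (the variable names are the openai SDK's documented defaults to the best of my knowledge; verify against your SDK version):

```python
import os

# Standard env vars read by a bare AsyncAzureOpenAI() constructor:
os.environ["AZURE_OPENAI_API_KEY"] = "<your-key>"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<resource>.openai.azure.com"
os.environ["OPENAI_API_VERSION"] = "2024-02-01"  # example version string

from nano_graphrag._llm import get_azure_openai_async_client_instance

# Repeated calls return the same lazily created client, so all requests
# share one connection pool instead of building a new client per call.
client = get_azure_openai_async_client_instance()
assert client is get_azure_openai_async_client_instance()
```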

Diff for: nano_graphrag/_storage/__init__.py (new file, +5)

+from .gdb_networkx import NetworkXStorage
+from .gdb_neo4j import Neo4jStorage
+from .vdb_hnswlib import HNSWVectorStorage
+from .vdb_nanovectordb import NanoVectorDBStorage
+from .kv_json import JsonKVStorage
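
With `_storage` now a package, every backend is importable from one place. A usage sketch (assumption: `vector_db_storage_cls` is the `GraphRAG` field for the vector backend, mirroring `graph_storage_cls`):

```python
from nano_graphrag import GraphRAG
from nano_graphrag._storage import HNSWVectorStorage, Neo4jStorage

# Mix and match backends through the package's single import surface.
rag = GraphRAG(
    graph_storage_cls=Neo4jStorage,           # graph data in Neo4j
    vector_db_storage_cls=HNSWVectorStorage,  # vectors in hnswlib
)
```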
