forked from KnowledgeXLab/LeanRAG
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild_graph.py
More file actions
250 lines (214 loc) · 8.61 KB
/
build_graph.py
File metadata and controls
250 lines (214 loc) · 8.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import argparse
from concurrent.futures import ProcessPoolExecutor, as_completed
import json
import os
import logging
import numpy as np
from openai import OpenAI
from tools.utils import truncate_text
from tqdm import tqdm
from _cluster_utils import HierarchicalClustering
from database_utils import (
build_vector_search,
create_db_table_sqlite,
insert_data_to_memgraph,
persist_hierarchy_to_sqlite,
get_embedding,
get_db_path,
)
import sqlite3
from dotenv import load_dotenv
# Module-level logger, named after this module per stdlib convention.
logger=logging.getLogger(__name__)
# Pull environment variables from a local .env file before reading config.
load_dotenv()
# Load configuration from environment variables
# LLM used elsewhere in the pipeline (not directly in this file).
MODEL = os.getenv('MODEL_LLM', 'grok-4-fast-reasoning')
# Embedding endpoint configuration (Together.ai-compatible OpenAI client).
EMBEDDING_MODEL = os.getenv('TOGETHER_MODEL', 'intfloat/multilingual-e5-large-instruct')
EMBEDDING_URL = os.getenv('TOGETHER_EMBEDDING_URL', 'https://api.together.xyz/embedding')
EMBEDDING_API_KEY = os.getenv('TOGETHER_API_KEY')
# Per-text token budget for embedding requests; texts are truncated to this.
EMBEDDING_MAX_TOKENS = int(os.getenv('TOGETHER_EMBED_MAX_TOKENS', '480'))
# NOTE(review): these two accumulators are never updated in this file —
# presumably mutated elsewhere or vestigial; confirm before removing.
TOTAL_TOKEN_COST = 0
TOTAL_API_CALL_COST = 0
# Fail fast at import time: embeddings are mandatory for the hierarchy build.
if not EMBEDDING_API_KEY:
    raise EnvironmentError("TOGETHER_API_KEY not set; cannot generate embeddings for hierarchy build")
def get_common_rag_res(_working_dir: str):
    """Load raw entities and relations from the shared SQLite store.

    Args:
        _working_dir: Unused; retained for interface compatibility with
            existing callers (the DB path comes from ``get_db_path()``).

    Returns:
        Tuple ``(e_dic, r_dic)``:
        - ``e_dic``: entity_name -> dict with ``entity_name``,
          ``description``, ``source_id``, ``degree`` (initialized to 0).
          Duplicate rows for the same entity are merged by appending
          unseen descriptions/source ids with ``|`` separators.
        - ``r_dic``: (src, tgt) -> dict with ``src_tgt``, ``tgt_src``,
          ``description``, ``weight`` (defaulting to 1).
    """
    db_path = get_db_path()
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        entity_rows = cur.execute(
            """
            SELECT entity_name, description, source_id
            FROM entities
            WHERE COALESCE(level, 0) = 0
            """
        ).fetchall()
        e_dic: dict[str, dict] = {}
        for entity_name, description, source_id in entity_rows:
            if not entity_name:
                continue
            entry = e_dic.setdefault(
                entity_name,
                {
                    'entity_name': entity_name,
                    'description': description or '',
                    'source_id': source_id or '',
                    'degree': 0,
                },
            )
            # Merge duplicate rows. Previously a non-empty description/source
            # arriving after an empty first row was silently dropped; now it
            # fills the empty field instead.
            if description and not entry['description']:
                entry['description'] = description
            elif entry['description'] and description and description not in entry['description']:
                entry['description'] += f"|Additional: {description}"
            if source_id and not entry['source_id']:
                entry['source_id'] = source_id
            elif entry['source_id'] and source_id and source_id not in entry['source_id']:
                entry['source_id'] += f"|{source_id}"
        relation_rows = cur.execute(
            """
            SELECT src_tgt, tgt_src, description, weight
            FROM relations
            WHERE src_tgt IS NOT NULL AND tgt_src IS NOT NULL
            """
        ).fetchall()
    finally:
        # Always release the connection, even if a query raises
        # (it was leaked on error before).
        conn.close()
    r_dic = {}
    for src_tgt, tgt_src, description, weight in relation_rows:
        src = str(src_tgt)
        tgt = str(tgt_src)
        r_dic[(src, tgt)] = {
            'src_tgt': src,
            'tgt_src': tgt,
            'description': description or '',
            'weight': weight or 1,
        }
    return e_dic, r_dic
def _prepare_embedding_texts(texts: list[str]) -> list[str]:
    """Normalize texts for embedding: token-truncate, then hard-cap by chars.

    The character cap (3 chars per token budget) is a coarse safety net on
    top of the token-aware truncation.
    """
    char_cap = EMBEDDING_MAX_TOKENS * 3
    return [
        truncate_text(raw or '', max_tokens=EMBEDDING_MAX_TOKENS)[:char_cap]
        for raw in texts
    ]
def embedding(texts) -> np.ndarray:
    """Embed *texts* and return the vectors as a single NumPy array.

    Thin wrapper over the shared ``get_embedding`` helper.

    NOTE: duplicate of ``embedding`` in query_graph and ``get_embedding``
    in database_utils; kept for backward compatibility with callers here.
    (The previous inline OpenAI-client implementation was dead, commented-out
    code and has been removed.)
    """
    return np.array(get_embedding(texts))
def compute_level_counts(levels):
    """Return a list of dictionaries describing node counts per hierarchy level.

    A list layer contributes its length; a dict layer counts as one node;
    any other layer type is omitted from the result (its index is still
    consumed by enumeration).
    """
    summary = []
    for level_idx, layer in enumerate(levels):
        if isinstance(layer, dict):
            node_count = 1
        elif isinstance(layer, list):
            node_count = len(layer)
        else:
            continue
        summary.append({"level": level_idx, "nodes": node_count})
    return summary
def render_level_chart(level_counts):
    """Render a simple ASCII bar chart from level counts.

    Bars scale to 40 characters at the largest level; every level gets at
    least one '#'. Returns a placeholder string when there are no levels.
    """
    if not level_counts:
        return "No hierarchical levels detected."
    # Guard against an all-zero chart: fall back to 1 to avoid division by zero.
    peak = max(item["nodes"] for item in level_counts) or 1
    rows = ["Level Distribution:"]
    for item in level_counts:
        width = max(1, int(item["nodes"] / peak * 40))
        rows.append(
            f"Level {item['level']:02d} | {item['nodes']:3d} nodes | {'#' * width}"
        )
    return "\n".join(rows)
def embedding_init(entities: list[dict]) -> list[dict]:
    """Attach an embedding vector to each entity dict, in place.

    Embeds every entity's ``description`` in one batched request against the
    configured embedding endpoint.

    Args:
        entities: Entity dicts; a missing ``description`` embeds as ''.

    Returns:
        The same list, each dict gaining a ``vector`` (np.ndarray) key.
    """
    texts = _prepare_embedding_texts([e.get('description', '') for e in entities])
    client = OpenAI(
        api_key=EMBEDDING_API_KEY,
        base_url=EMBEDDING_URL,
    )
    # Named `response`, not `embedding`, so the module-level `embedding`
    # function is no longer shadowed inside this body.
    response = client.embeddings.create(
        input=texts,
        model=EMBEDDING_MODEL,
    )
    vectors = [d.embedding for d in response.data]
    for entity, vec in zip(entities, vectors):
        entity['vector'] = np.array(vec)
    return entities
def embedding_data(entity_results, max_workers: int = 8):
    """Embed all entity descriptions and write vectors back into *entity_results*.

    Entities are embedded in batches of 64. With ``max_workers <= 1`` batches
    run sequentially; otherwise they fan out over a process pool.

    Args:
        entity_results: Mapping entity_name -> entity dict (mutated in place).
        max_workers: Size of the process pool; <= 1 disables parallelism.

    Returns:
        The same *entity_results* mapping, each entity gaining a 'vector'.
    """
    entities = list(entity_results.values())
    batch_size = 64
    # Range-step slicing replaces the previous manual batch-count arithmetic.
    batches = [entities[i:i + batch_size] for i in range(0, len(entities), batch_size)]
    entity_with_embeddings = []
    if max_workers <= 1:
        for batch in batches:
            entity_with_embeddings.extend(embedding_init(batch))
    else:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(embedding_init, batch) for batch in batches]
            for future in tqdm(as_completed(futures), total=len(futures)):
                entity_with_embeddings.extend(future.result())
    # Copy vectors back onto the canonical records keyed by entity name.
    # (Required in the parallel path: worker processes return copies.)
    for item in entity_with_embeddings:
        entity_results[item['entity_name']]['vector'] = item['vector']
    return entity_results
def hierarchical_clustering(global_config):
    """Build the hierarchical entity graph end-to-end.

    Pipeline: load raw entities/relations from SQLite, embed entities, run
    hierarchical clustering, write level-distribution summaries (txt + json),
    build the vector index, persist the hierarchy to SQLite, strip vectors,
    and load everything into Memgraph.

    Args:
        global_config: dict with at least 'working_dir' and 'max_workers';
            optional 'clear_existing'. 'embedding_dim' is written into it
            here as a side effect.

    Returns:
        dict of Memgraph insertion stats merged with SQLite persistence counts.
    """
    WORKING_DIR =global_config['working_dir']
    entity_results,relation_results=get_common_rag_res(global_config['working_dir'])
    # Embedding fan-out uses half the configured workers (at least 1).
    embed_workers = max(1, global_config.get('max_workers', 8) // 2)
    all_entities=embedding_data(entity_results, max_workers=embed_workers)
    if entity_results:
        # Record the embedding dimensionality from any one vector so the
        # clustering step can size its structures.
        sample_vector = next(iter(entity_results.values()))['vector']
        global_config['embedding_dim'] = np.asarray(sample_vector).size
    hierarchical_cluster = HierarchicalClustering()
    # NOTE(review): perform_clustering appears to reshape all_entities into a
    # per-level structure (lists of dicts, plus a non-list root) — confirm
    # against _cluster_utils before relying on layer types below.
    all_entities,generate_relations,community =hierarchical_cluster.perform_clustering(global_config=global_config,entities=all_entities,relations=relation_results,\
        WORKING_DIR=WORKING_DIR,max_workers=global_config['max_workers'])
    # Write a human-readable and a machine-readable summary of level sizes.
    level_counts = compute_level_counts(all_entities)
    level_summary_text = render_level_chart(level_counts)
    summary_path = os.path.join(global_config['working_dir'], 'level_summary.txt')
    with open(summary_path, 'w', encoding='utf-8') as summary_file:
        summary_file.write(level_summary_text + "\n")
    with open(os.path.join(global_config['working_dir'], 'level_summary.json'), 'w', encoding='utf-8') as summary_json:
        json.dump(level_counts, summary_json, indent=2)
    logger.info("\n%s", level_summary_text)
    recreate_vectors = global_config.get('clear_existing', False)
    try:
        # Embed the final (root) layer's description before indexing.
        # NOTE(review): assumes all_entities[-1] is a single dict, not a list
        # of entities like the other layers — verify with perform_clustering.
        all_entities[-1]['vector'] = embedding(all_entities[-1]['description'])
        build_vector_search(all_entities, f"{WORKING_DIR}", recreate=recreate_vectors)
    except Exception as e:
        # Best-effort: a failed vector index does not abort persistence.
        # NOTE(review): prints instead of using `logger`; consider
        # logger.exception here.
        print(f"Error in build_vector_search: {e}")
    persistence_stats = persist_hierarchy_to_sqlite(
        all_entities,
        relation_results,
        generate_relations,
        community,
        replace_existing=True,
    )
    # Strip the (large) embedding vectors before graph insertion; they were
    # already persisted by the vector index / SQLite steps above.
    for layer in all_entities:
        if type(layer) != list:
            # Non-list layer: treated as a single node dict.
            # NOTE(review): isinstance() would be the idiomatic check.
            if "vector" in layer.keys():
                del layer["vector"]
            continue
        for item in layer:
            if "vector" in item.keys():
                del item["vector"]
        # A singleton layer's only node is re-parented to the synthetic root.
        # (`item` here is that sole element, left over from the loop above.)
        if len(layer) == 1:
            item['parent'] = 'root'
    create_db_table_sqlite(global_config['working_dir'])
    graph_stats = insert_data_to_memgraph(
        global_config['working_dir'],
        clear_existing=global_config.get('clear_existing', False)
    )
    # Merge SQLite persistence counts into the Memgraph stats for the caller.
    graph_stats.update({
        'sqlite_entities': persistence_stats.get('entities', 0),
        'sqlite_relations': persistence_stats.get('relations', 0),
        'sqlite_communities': persistence_stats.get('communities', 0),
    })
    return graph_stats