Skip to content

Commit 9dc9719

Browse files
authored
Fixes for graph pruning (#359)
* Fixes * Fix existing tests * Add UT to make sure we're not pruning the lexical graph * Add e2e test * Make deprecated parameters more visible in the doc * Fix CI * Lock file, because * Rename variable * Update doc * Update doc
1 parent dffd484 commit 9dc9719

File tree

13 files changed

+389
-252
lines changed

13 files changed

+389
-252
lines changed
49.4 KB
Loading

docs/source/user_guide_kg_builder.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ A Knowledge Graph (KG) construction pipeline requires a few components (some of
2424
- **Schema builder**: provide a schema to ground the LLM extracted node and relationship types and obtain an easily navigable KG. Schema can be provided manually or extracted automatically using LLMs.
2525
- **Lexical graph builder**: build the lexical graph (Document, Chunk and their relationships) (optional).
2626
- **Entity and relation extractor**: extract relevant entities and relations from the text.
27+
- **Graph pruner**: clean the graph based on schema, if provided.
2728
- **Knowledge Graph writer**: save the identified entities and relations.
2829
- **Entity resolver**: merge similar entities into a single node.
2930

poetry.lock

Lines changed: 124 additions & 128 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/neo4j_graphrag/experimental/components/graph_pruning.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Neo4jGraph,
2929
Neo4jNode,
3030
Neo4jRelationship,
31+
LexicalGraphConfig,
3132
)
3233
from neo4j_graphrag.experimental.pipeline import Component, DataModel
3334

@@ -135,9 +136,14 @@ async def run(
135136
self,
136137
graph: Neo4jGraph,
137138
schema: Optional[GraphSchema] = None,
139+
lexical_graph_config: Optional[LexicalGraphConfig] = None,
138140
) -> GraphPruningResult:
141+
if lexical_graph_config is None:
142+
lexical_graph_config = LexicalGraphConfig()
139143
if schema is not None:
140-
new_graph, pruning_stats = self._clean_graph(graph, schema)
144+
new_graph, pruning_stats = self._clean_graph(
145+
graph, schema, lexical_graph_config
146+
)
141147
else:
142148
new_graph = graph
143149
pruning_stats = PruningStats()
@@ -150,6 +156,7 @@ def _clean_graph(
150156
self,
151157
graph: Neo4jGraph,
152158
schema: GraphSchema,
159+
lexical_graph_config: LexicalGraphConfig,
153160
) -> tuple[Neo4jGraph, PruningStats]:
154161
"""
155162
Verify that the graph conforms to the provided schema.
@@ -162,6 +169,7 @@ def _clean_graph(
162169
filtered_nodes = self._enforce_nodes(
163170
graph.nodes,
164171
schema,
172+
lexical_graph_config,
165173
pruning_stats,
166174
)
167175
if not filtered_nodes:
@@ -174,6 +182,7 @@ def _clean_graph(
174182
graph.relationships,
175183
filtered_nodes,
176184
schema,
185+
lexical_graph_config,
177186
pruning_stats,
178187
)
179188

@@ -214,20 +223,24 @@ def _validate_node(
214223

215224
def _enforce_nodes(
216225
self,
217-
extracted_nodes: list[Neo4jNode],
226+
nodes: list[Neo4jNode],
218227
schema: GraphSchema,
228+
lexical_graph_config: LexicalGraphConfig,
219229
pruning_stats: PruningStats,
220230
) -> list[Neo4jNode]:
221231
"""
222-
Filter extracted nodes to be conformant to the schema.
232+
Filter nodes to be conformant to the schema.
223233
224234
Keep only those whose label is in schema
225235
(unless schema has additional_node_types=True, default value)
226236
For each valid node, validate properties. If a node is left without
227237
properties, prune it.
228238
"""
229239
valid_nodes = []
230-
for node in extracted_nodes:
240+
for node in nodes:
241+
if node.label in lexical_graph_config.lexical_graph_node_labels:
242+
valid_nodes.append(node)
243+
continue
231244
schema_entity = schema.node_type_from_label(node.label)
232245
new_node = self._validate_node(
233246
node,
@@ -316,13 +329,14 @@ def _validate_relationship(
316329

317330
def _enforce_relationships(
318331
self,
319-
extracted_relationships: list[Neo4jRelationship],
332+
relationships: list[Neo4jRelationship],
320333
filtered_nodes: list[Neo4jNode],
321334
schema: GraphSchema,
335+
lexical_graph_config: LexicalGraphConfig,
322336
pruning_stats: PruningStats,
323337
) -> list[Neo4jRelationship]:
324338
"""
325-
Filter extracted nodes to be conformant to the schema.
339+
Filter relationships to be conformant to the schema.
326340
327341
Keep only those whose types are in schema, start/end node conform to schema,
328342
and start/end nodes are in filtered nodes (i.e., kept after node enforcement).
@@ -333,7 +347,10 @@ def _enforce_relationships(
333347

334348
valid_rels = []
335349
valid_nodes = {node.id: node.label for node in filtered_nodes}
336-
for rel in extracted_relationships:
350+
for rel in relationships:
351+
if rel.type in lexical_graph_config.lexical_graph_relationship_types:
352+
valid_rels.append(rel)
353+
continue
337354
schema_relation = schema.relationship_type_from_label(rel.type)
338355
new_rel = self._validate_relationship(
339356
rel,

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ def create_schema_model(
334334
node_types: Sequence[NodeType],
335335
relationship_types: Optional[Sequence[RelationshipType]] = None,
336336
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
337+
**kwargs: Any,
337338
) -> GraphSchema:
338339
"""
339340
Creates a GraphSchema object from Lists of Entity and Relation objects
@@ -343,6 +344,7 @@ def create_schema_model(
343344
node_types (Sequence[NodeType]): List or tuple of NodeType objects.
344345
relationship_types (Optional[Sequence[RelationshipType]]): List or tuple of RelationshipType objects.
345346
patterns (Optional[Sequence[Tuple[str, str, str]]]): List or tuples of triplets: (source_entity_label, relation_label, target_entity_label).
347+
kwargs: other arguments passed to GraphSchema validator.
346348
347349
Returns:
348350
GraphSchema: A configured schema object.
@@ -353,17 +355,19 @@ def create_schema_model(
353355
node_types=node_types,
354356
relationship_types=relationship_types or (),
355357
patterns=patterns or (),
358+
**kwargs,
356359
)
357360
)
358-
except (ValidationError, SchemaValidationError) as e:
359-
raise SchemaValidationError(e) from e
361+
except ValidationError as e:
362+
raise SchemaValidationError() from e
360363

361364
@validate_call
362365
async def run(
363366
self,
364367
node_types: Sequence[NodeType],
365368
relationship_types: Optional[Sequence[RelationshipType]] = None,
366369
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
370+
**kwargs: Any,
367371
) -> GraphSchema:
368372
"""
369373
Asynchronously constructs and returns a GraphSchema object.
@@ -376,7 +380,12 @@ async def run(
376380
Returns:
377381
GraphSchema: A configured schema object, constructed asynchronously.
378382
"""
379-
return self.create_schema_model(node_types, relationship_types, patterns)
383+
return self.create_schema_model(
384+
node_types,
385+
relationship_types,
386+
patterns,
387+
**kwargs,
388+
)
380389

381390

382391
class SchemaFromTextExtractor(Component):

src/neo4j_graphrag/experimental/components/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,14 @@ class LexicalGraphConfig(BaseModel):
174174
def lexical_graph_node_labels(self) -> tuple[str, ...]:
175175
return self.document_node_label, self.chunk_node_label
176176

177+
@property
178+
def lexical_graph_relationship_types(self) -> tuple[str, ...]:
179+
return (
180+
self.chunk_to_document_relationship_type,
181+
self.next_chunk_relationship_type,
182+
self.node_to_chunk_relationship_type,
183+
)
184+
177185

178186
class GraphResult(DataModel):
179187
graph: Neo4jGraph

src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py

Lines changed: 17 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
Optional,
2222
Sequence,
2323
Union,
24-
Tuple,
2524
)
2625
import logging
2726
import warnings
@@ -45,8 +44,6 @@
4544
from neo4j_graphrag.experimental.components.schema import (
4645
SchemaBuilder,
4746
GraphSchema,
48-
NodeType,
49-
RelationshipType,
5047
SchemaFromTextExtractor,
5148
)
5249
from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
@@ -184,66 +181,33 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]:
184181
return SchemaFromTextExtractor(llm=self.get_default_llm())
185182
return SchemaBuilder()
186183

187-
def _process_schema_with_precedence(
188-
self,
189-
) -> Tuple[
190-
Tuple[NodeType, ...],
191-
Tuple[RelationshipType, ...] | None,
192-
Optional[Tuple[Tuple[str, str, str], ...]] | None,
193-
]:
184+
def _process_schema_with_precedence(self) -> dict[str, Any]:
194185
"""
195186
Process schema inputs according to precedence rules:
196187
1. If schema is provided as GraphSchema object, use it
197188
2. If schema is provided as dictionary, extract from it
198189
3. Otherwise, use individual schema components
199190
200191
Returns:
201-
Tuple of (node_types, relationship_types, patterns)
192+
A dict representing the schema
202193
"""
203194
if self.schema_ is not None:
204-
# schema takes precedence over individual components
205-
node_types = self.schema_.node_types
195+
return self.schema_.model_dump()
206196

207-
# handle case where relations could be None
208-
if self.schema_.relationship_types is not None:
209-
relationship_types = self.schema_.relationship_types
210-
else:
211-
relationship_types = None
212-
213-
patterns = self.schema_.patterns
214-
else:
215-
# use individual components
216-
node_types = tuple(
217-
[NodeType.model_validate(e) for e in self.entities]
218-
if self.entities
219-
else []
220-
)
221-
relationship_types = (
222-
tuple([RelationshipType.model_validate(r) for r in self.relations])
223-
if self.relations is not None
224-
else None
225-
)
226-
patterns = (
227-
tuple(self.potential_schema) if self.potential_schema else tuple()
228-
)
229-
230-
return node_types, relationship_types, patterns
197+
return dict(
198+
node_types=self.entities,
199+
relationship_types=self.relations,
200+
patterns=self.potential_schema,
201+
)
231202

232203
def _get_run_params_for_schema(self) -> dict[str, Any]:
233204
if not self.has_user_provided_schema():
234205
# for automatic extraction, the text parameter is needed (will flow through the pipeline connections)
235206
return {}
236207
else:
237208
# process schema components according to precedence rules
238-
node_types, relationship_types, patterns = (
239-
self._process_schema_with_precedence()
240-
)
241-
242-
return {
243-
"node_types": node_types,
244-
"relationship_types": relationship_types,
245-
"patterns": patterns,
246-
}
209+
schema_dict = self._process_schema_with_precedence()
210+
return schema_dict
247211

248212
def _get_extractor(self) -> EntityRelationExtractor:
249213
return LLMEntityRelationExtractor(
@@ -368,7 +332,13 @@ def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
368332
run_params = {}
369333
if self.lexical_graph_config:
370334
run_params["extractor"] = {
371-
"lexical_graph_config": self.lexical_graph_config
335+
"lexical_graph_config": self.lexical_graph_config,
336+
}
337+
run_params["writer"] = {
338+
"lexical_graph_config": self.lexical_graph_config,
339+
}
340+
run_params["pruner"] = {
341+
"lexical_graph_config": self.lexical_graph_config,
372342
}
373343
text = user_input.get("text")
374344
file_path = user_input.get("file_path")

src/neo4j_graphrag/experimental/pipeline/kg_builder.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,29 @@ class SimpleKGPipeline:
5656
llm (LLMInterface): An instance of an LLM to use for entity and relation extraction.
5757
driver (neo4j.Driver): A Neo4j driver instance for database connection.
5858
embedder (Embedder): An instance of an embedder used to generate chunk embeddings from text chunks.
59-
schema (Optional[Union[GraphSchema, dict[str, list]]]): A schema configuration defining entities,
60-
relations, and potential schema relationships.
61-
This is the recommended way to provide schema information.
59+
schema (Optional[Union[GraphSchema, dict[str, list]]]): A schema configuration defining node types,
60+
relationship types, and graph patterns.
6261
entities (Optional[List[Union[str, dict[str, str], NodeType]]]): DEPRECATED. A list of either:
6362
6463
- str: entity labels
6564
- dict: following the NodeType schema, ie with label, description and properties keys
6665
66+
.. deprecated:: 1.7.1
67+
Use schema instead
68+
6769
relations (Optional[List[Union[str, dict[str, str], RelationshipType]]]): DEPRECATED. A list of either:
6870
6971
- str: relation label
7072
- dict: following the RelationshipType schema, ie with label, description and properties keys
7173
74+
.. deprecated:: 1.7.1
75+
Use schema instead
76+
7277
potential_schema (Optional[List[tuple]]): DEPRECATED. A list of potential schema relationships.
78+
79+
.. deprecated:: 1.7.1
80+
Use schema instead
81+
7382
from_pdf (bool): Determines whether to include the PdfLoader in the pipeline.
7483
If True, expects `file_path` input in `run` methods.
7584
If False, expects `text` input in `run` methods.

0 commit comments

Comments
 (0)