Skip to content

Commit 4bcfbb1

Browse files
committed
Add SchemaFromExistingGraphExtractor component
Parses the result from get_structured_schema and returns a GraphSchema object
1 parent 18963b2 commit 4bcfbb1

File tree

2 files changed

+117
-2
lines changed

2 files changed

+117
-2
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""This example demonstrates how to use the SchemaFromExistingGraphExtractor component
2+
to automatically extract a schema from an existing Neo4j database.
3+
"""
4+
5+
import asyncio
6+
7+
import neo4j
8+
9+
from neo4j_graphrag.experimental.components.schema import (
10+
SchemaFromExistingGraphExtractor,
11+
GraphSchema,
12+
)
13+
14+
15+
URI = "neo4j+s://demo.neo4jlabs.com"
16+
AUTH = ("recommendations", "recommendations")
17+
DATABASE = "recommendations"
18+
INDEX = "moviePlotsEmbedding"
19+
20+
21+
async def main() -> None:
22+
"""Run the example."""
23+
24+
with neo4j.GraphDatabase.driver(
25+
URI,
26+
auth=AUTH,
27+
) as driver:
28+
extractor = SchemaFromExistingGraphExtractor(driver)
29+
schema: GraphSchema = await extractor.run()
30+
# schema.store_as_json("my_schema.json")
31+
print(schema)
32+
33+
34+
if __name__ == "__main__":
35+
asyncio.run(main())

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from __future__ import annotations
1616

1717
import json
18+
19+
import neo4j
1820
import yaml
1921
import logging
2022
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence
@@ -42,6 +44,7 @@
4244
)
4345
from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
4446
from neo4j_graphrag.llm import LLMInterface
47+
from neo4j_graphrag.schema import get_structured_schema
4548

4649

4750
class PropertyType(BaseModel):
@@ -254,7 +257,12 @@ def from_yaml(cls, file_path: Union[str, Path]) -> Self:
254257
raise SchemaValidationError(f"Schema validation failed: {e}")
255258

256259

257-
class SchemaBuilder(Component):
260+
class BaseSchemaBuilder(Component):
261+
async def run(self, *args: Any, **kwargs: Any) -> GraphSchema:
262+
raise NotImplementedError()
263+
264+
265+
class SchemaBuilder(BaseSchemaBuilder):
258266
"""
259267
A builder class for constructing GraphSchema objects from given entities,
260268
relations, and their interrelationships defined in a potential schema.
@@ -363,7 +371,7 @@ async def run(
363371
return self.create_schema_model(node_types, relationship_types, patterns)
364372

365373

366-
class SchemaFromTextExtractor(Component):
374+
class SchemaFromTextExtractor(BaseSchemaBuilder):
367375
"""
368376
A component for constructing GraphSchema objects from the output of an LLM after
369377
automatic schema extraction from text.
@@ -446,3 +454,75 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
446454
"patterns": extracted_patterns,
447455
}
448456
)
457+
458+
459+
class SchemaFromExistingGraphExtractor(BaseSchemaBuilder):
460+
"""A class to build a GraphSchema object from an existing graph."""
461+
462+
def __init__(self, driver: neo4j.Driver) -> None:
463+
self.driver = driver
464+
465+
async def run(self, **kwargs: Any) -> GraphSchema:
466+
structured_schema = get_structured_schema(self.driver)
467+
node_labels = set(structured_schema["node_props"].keys())
468+
node_types = [
469+
{
470+
"label": key,
471+
"properties": [
472+
{
473+
"name": p["property"],
474+
"type": p["type"],
475+
}
476+
for p in properties
477+
],
478+
}
479+
for key, properties in structured_schema["node_props"].items()
480+
]
481+
rel_labels = set(structured_schema["rel_props"].keys())
482+
relationship_types = [
483+
{
484+
"label": key,
485+
"properties": [
486+
{
487+
"name": p["property"],
488+
"type": p["type"],
489+
}
490+
for p in properties
491+
],
492+
}
493+
for key, properties in structured_schema["rel_props"].items()
494+
]
495+
patterns = [
496+
(s["start"], s["type"], s["end"])
497+
for s in structured_schema["relationships"]
498+
]
499+
# deal with nodes and relationships without properties
500+
for source, rel, target in patterns:
501+
if source not in node_labels:
502+
node_labels.add(source)
503+
node_types.append(
504+
{
505+
"label": source,
506+
}
507+
)
508+
if target not in node_labels:
509+
node_labels.add(target)
510+
node_types.append(
511+
{
512+
"label": target,
513+
}
514+
)
515+
if rel not in rel_labels:
516+
rel_labels.add(rel)
517+
relationship_types.append(
518+
{
519+
"label": rel,
520+
}
521+
)
522+
return GraphSchema.model_validate(
523+
{
524+
"node_types": node_types,
525+
"relationship_types": relationship_types,
526+
"patterns": patterns,
527+
}
528+
)

0 commit comments

Comments
 (0)