|
15 | 15 | from __future__ import annotations
|
16 | 16 |
|
17 | 17 | import json
|
| 18 | + |
| 19 | +import neo4j |
18 | 20 | import yaml
|
19 | 21 | import logging
|
20 | 22 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence
|
|
42 | 44 | )
|
43 | 45 | from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
|
44 | 46 | from neo4j_graphrag.llm import LLMInterface
|
| 47 | +from neo4j_graphrag.schema import get_structured_schema |
45 | 48 |
|
46 | 49 |
|
47 | 50 | class PropertyType(BaseModel):
|
@@ -254,7 +257,12 @@ def from_yaml(cls, file_path: Union[str, Path]) -> Self:
|
254 | 257 | raise SchemaValidationError(f"Schema validation failed: {e}")
|
255 | 258 |
|
256 | 259 |
|
257 |
| -class SchemaBuilder(Component): |
| 260 | +class BaseSchemaBuilder(Component): |
| 261 | + async def run(self, *args: Any, **kwargs: Any) -> GraphSchema: |
| 262 | + raise NotImplementedError() |
| 263 | + |
| 264 | + |
| 265 | +class SchemaBuilder(BaseSchemaBuilder): |
258 | 266 | """
|
259 | 267 | A builder class for constructing GraphSchema objects from given entities,
|
260 | 268 | relations, and their interrelationships defined in a potential schema.
|
@@ -363,7 +371,7 @@ async def run(
|
363 | 371 | return self.create_schema_model(node_types, relationship_types, patterns)
|
364 | 372 |
|
365 | 373 |
|
366 |
| -class SchemaFromTextExtractor(Component): |
| 374 | +class SchemaFromTextExtractor(BaseSchemaBuilder): |
367 | 375 | """
|
368 | 376 | A component for constructing GraphSchema objects from the output of an LLM after
|
369 | 377 | automatic schema extraction from text.
|
@@ -446,3 +454,75 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
|
446 | 454 | "patterns": extracted_patterns,
|
447 | 455 | }
|
448 | 456 | )
|
| 457 | + |
| 458 | + |
| 459 | +class SchemaFromExistingGraphExtractor(BaseSchemaBuilder): |
| 460 | + """A class to build a GraphSchema object from an existing graph.""" |
| 461 | + |
| 462 | + def __init__(self, driver: neo4j.Driver) -> None: |
| 463 | + self.driver = driver |
| 464 | + |
| 465 | + async def run(self, **kwargs: Any) -> GraphSchema: |
| 466 | + structured_schema = get_structured_schema(self.driver) |
| 467 | + node_labels = set(structured_schema["node_props"].keys()) |
| 468 | + node_types = [ |
| 469 | + { |
| 470 | + "label": key, |
| 471 | + "properties": [ |
| 472 | + { |
| 473 | + "name": p["property"], |
| 474 | + "type": p["type"], |
| 475 | + } |
| 476 | + for p in properties |
| 477 | + ], |
| 478 | + } |
| 479 | + for key, properties in structured_schema["node_props"].items() |
| 480 | + ] |
| 481 | + rel_labels = set(structured_schema["rel_props"].keys()) |
| 482 | + relationship_types = [ |
| 483 | + { |
| 484 | + "label": key, |
| 485 | + "properties": [ |
| 486 | + { |
| 487 | + "name": p["property"], |
| 488 | + "type": p["type"], |
| 489 | + } |
| 490 | + for p in properties |
| 491 | + ], |
| 492 | + } |
| 493 | + for key, properties in structured_schema["rel_props"].items() |
| 494 | + ] |
| 495 | + patterns = [ |
| 496 | + (s["start"], s["type"], s["end"]) |
| 497 | + for s in structured_schema["relationships"] |
| 498 | + ] |
| 499 | + # deal with nodes and relationships without properties |
| 500 | + for source, rel, target in patterns: |
| 501 | + if source not in node_labels: |
| 502 | + node_labels.add(source) |
| 503 | + node_types.append( |
| 504 | + { |
| 505 | + "label": source, |
| 506 | + } |
| 507 | + ) |
| 508 | + if target not in node_labels: |
| 509 | + node_labels.add(target) |
| 510 | + node_types.append( |
| 511 | + { |
| 512 | + "label": target, |
| 513 | + } |
| 514 | + ) |
| 515 | + if rel not in rel_labels: |
| 516 | + rel_labels.add(rel) |
| 517 | + relationship_types.append( |
| 518 | + { |
| 519 | + "label": rel, |
| 520 | + } |
| 521 | + ) |
| 522 | + return GraphSchema.model_validate( |
| 523 | + { |
| 524 | + "node_types": node_types, |
| 525 | + "relationship_types": relationship_types, |
| 526 | + "patterns": patterns, |
| 527 | + } |
| 528 | + ) |
0 commit comments