Add configurable prompts #35

Open · wants to merge 2 commits into main
16 changes: 14 additions & 2 deletions graphrag_sdk/chat_session.py
@@ -25,7 +25,9 @@ class ChatSession:
    >>> chat_session.send_message("What is the capital of France?")
    """

-    def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology, graph: Graph):
+    def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology, graph: Graph,
+                 cypher_system_instruction: str = None, qa_system_instruction: str = None,
+                 cypher_gen_prompt: str = None, qa_prompt: str = None):
        """
        Initializes a new ChatSession object.

@@ -44,13 +46,21 @@ def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology,
        self.model_config = model_config
        self.graph = graph
        self.ontology = ontology
+        if cypher_system_instruction is None:
+            cypher_system_instruction = CYPHER_GEN_SYSTEM.replace("#ONTOLOGY", str(ontology.to_json()))
+        else:
+            cypher_system_instruction = cypher_system_instruction + "\nOntology:\n" + str(ontology.to_json())
Comment on lines +49 to +52
🛠️ Refactor suggestion

Standardize ontology injection method.

The code uses two different approaches to inject the ontology:

  1. Using replace("#ONTOLOGY", str(ontology.to_json())) for the default case
  2. Using string concatenation with "\nOntology:\n" for custom instructions

This inconsistency could lead to formatting differences and maintenance issues.

Consider standardizing the approach:

-        if cypher_system_instruction is None:
-            cypher_system_instruction = CYPHER_GEN_SYSTEM.replace("#ONTOLOGY", str(ontology.to_json()))
-        else:
-            cypher_system_instruction = cypher_system_instruction + "\nOntology:\n" + str(ontology.to_json())
+        base_instruction = cypher_system_instruction or CYPHER_GEN_SYSTEM
+        ontology_json = str(ontology.to_json())
+        cypher_system_instruction = base_instruction.replace("#ONTOLOGY", ontology_json)

Committable suggestion skipped: line range outside the PR's diff.


+        self.cypher_prompt = cypher_gen_prompt
+        self.qa_prompt = qa_prompt

        self.cypher_chat_session = (
            model_config.cypher_generation.with_system_instruction(
                CYPHER_GEN_SYSTEM.replace("#ONTOLOGY", str(ontology.to_json()))
            ).start_chat()
        )
        self.qa_chat_session = model_config.qa.with_system_instruction(
-            GRAPH_QA_SYSTEM
+            qa_system_instruction or GRAPH_QA_SYSTEM
        ).start_chat()
        self.last_answer = None

@@ -69,6 +79,7 @@ def send_message(self, message: str):
            chat_session=self.cypher_chat_session,
            ontology=self.ontology,
            last_answer=self.last_answer,
+            cypher_prompt=self.cypher_prompt,
        )

        (context, cypher) = cypher_step.run(message)
@@ -78,6 +89,7 @@

        qa_step = QAStep(
            chat_session=self.qa_chat_session,
+            qa_prompt=self.qa_prompt,
        )

        answer = qa_step.run(message, cypher, context)
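To make the new constructor parameters concrete, here is a minimal usage sketch (an illustration only: it assumes model_config, ontology, and graph objects already exist, and the instruction strings are invented for the example):

# Sketch, not from the PR. Per the __init__ diff above, a custom
# cypher_system_instruction gets "\nOntology:\n<ontology JSON>" appended,
# while the default CYPHER_GEN_SYSTEM has its "#ONTOLOGY" placeholder replaced.
session = ChatSession(
    model_config,
    ontology,
    graph,
    cypher_system_instruction="Translate questions into OpenCypher over the graph described below.",
    qa_system_instruction="Answer only from the provided graph context.",
)
session.send_message("What is the capital of France?")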
8 changes: 5 additions & 3 deletions graphrag_sdk/kg.py
@@ -134,9 +134,11 @@ def delete(self) -> None:
        for key in self.__dict__.keys():
            setattr(self, key, None)

-    def chat_session(self) -> ChatSession:
-        return ChatSession(self._model_config, self.ontology, self.graph)
-
+    def chat_session(self, cypher_system_instruction: str = None, qa_system_instruction: str = None,
+                     cypher_gen_prompt: str = None, qa_prompt: str = None) -> ChatSession:
+        chat_session = ChatSession(self._model_config, self.ontology, self.graph, cypher_system_instruction,
+                                   qa_system_instruction, cypher_gen_prompt, qa_prompt)
+        return chat_session
    def add_node(self, entity: str, attributes: dict):
        """
        Add a node to the knowledge graph, checking if it matches the ontology
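For the public entry point, a hedged end-to-end sketch of the new keyword arguments (assumes kg is an already-constructed KnowledgeGraph; the template strings are invented for the example):

# Sketch only: override the QA side while keeping default Cypher generation.
session = kg.chat_session(
    qa_system_instruction="Answer tersely, using only the supplied graph context.",
    qa_prompt="Context: {context}\nCypher: {cypher}\nQuestion: {question}\nAnswer:",
)
answer = session.send_message("What is the capital of France?")

The placeholder names in qa_prompt match the keywords QAStep.run() passes to .format() (see qa_step.py below).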
23 changes: 14 additions & 9 deletions graphrag_sdk/steps/graph_query_step.py
@@ -33,28 +33,33 @@ def __init__(
        chat_session: GenerativeModelChatSession,
        config: dict = None,
        last_answer: str = None,
+        cypher_prompt: str = None,
    ) -> None:
        self.ontology = ontology
        self.config = config or {}
        self.graph = graph
        self.chat_session = chat_session
        self.last_answer = last_answer
+        self.cypher_prompt = cypher_prompt

    def run(self, question: str, retries: int = 5):
        error = False

        cypher = ""
        while error is not None and retries > 0:
            try:
-                cypher_prompt = (
-                    (CYPHER_GEN_PROMPT.format(question=question)
-                    if self.last_answer is None
-                    else CYPHER_GEN_PROMPT_WITH_HISTORY.format(question=question, last_answer=self.last_answer))
-                    if error is False
-                    else CYPHER_GEN_PROMPT_WITH_ERROR.format(
-                        question=question, error=error
-                    )
-                )
+                if self.cypher_prompt is not None:
+                    cypher_prompt = self.cypher_prompt
+                else:
+                    cypher_prompt = (
+                        (CYPHER_GEN_PROMPT.format(question=question)
+                        if self.last_answer is None
+                        else CYPHER_GEN_PROMPT_WITH_HISTORY.format(question=question, last_answer=self.last_answer))
+                        if error is False
+                        else CYPHER_GEN_PROMPT_WITH_ERROR.format(
+                            question=question, error=error
+                        )
+                    )
Comment on lines +51 to +62
🛠️ Refactor suggestion

Refactor nested ternary expressions for better readability.

The current implementation uses nested ternary expressions which can be hard to read and maintain. Consider extracting the logic into a separate method for better clarity.

Here's a suggested refactor:

-                if self.cypher_prompt is not None:
-                    cypher_prompt = self.cypher_prompt
-                else:
-                    cypher_prompt = (
-                        (CYPHER_GEN_PROMPT.format(question=question) 
-                        if self.last_answer is None
-                        else CYPHER_GEN_PROMPT_WITH_HISTORY.format(question=question, last_answer=self.last_answer))
-                        if error is False
-                        else CYPHER_GEN_PROMPT_WITH_ERROR.format(
-                            question=question, error=error
-                        )
-                    )   
+                cypher_prompt = self._get_cypher_prompt(question, error)
+
+    def _get_cypher_prompt(self, question: str, error: Exception | bool) -> str:
+        if self.cypher_prompt is not None:
+            return self.cypher_prompt
+        
+        if error:
+            return CYPHER_GEN_PROMPT_WITH_ERROR.format(
+                question=question, 
+                error=error
+            )
+        
+        if self.last_answer is None:
+            return CYPHER_GEN_PROMPT.format(question=question)
+        
+        return CYPHER_GEN_PROMPT_WITH_HISTORY.format(
+            question=question,
+            last_answer=self.last_answer
+        )

Committable suggestion skipped: line range outside the PR's diff.

logger.debug(f"Cypher Prompt: {cypher_prompt}")
cypher_statement_response = self.chat_session.send_message(
cypher_prompt,
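One caller-facing consequence of the branch above, noted as an inference from the diff rather than a claim from the PR: a supplied cypher_prompt is sent verbatim on every attempt, so neither the question nor retry error feedback is interpolated into it. A hypothetical caller-side sketch, reusing the kg instance assumed earlier:

# Hypothetical: a custom cypher_prompt bypasses CYPHER_GEN_PROMPT and
# CYPHER_GEN_PROMPT_WITH_ERROR, so the caller must embed the question itself.
question = "What is the capital of France?"
session = kg.chat_session(
    cypher_gen_prompt=(
        "Generate one OpenCypher query that answers the question below.\n"
        f"Question: {question}"
    ),
)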
10 changes: 6 additions & 4 deletions graphrag_sdk/steps/qa_step.py
@@ -1,7 +1,7 @@
from graphrag_sdk.steps.Step import Step
from graphrag_sdk.models import GenerativeModelChatSession

-from graphrag_sdk.fixtures.prompts import GRAPH_QA_SYSTEM, GRAPH_QA_PROMPT
+from graphrag_sdk.fixtures.prompts import GRAPH_QA_PROMPT
import logging

logger = logging.getLogger(__name__)
@@ -17,17 +17,19 @@ def __init__(
        self,
        chat_session: GenerativeModelChatSession,
        config: dict = None,
+        qa_prompt: str = None,
    ) -> None:
        self.config = config or {}
        self.chat_session = chat_session
+        self.qa_prompt = qa_prompt

    def run(self, question: str, cypher: str, context: str):

-        qa_prompt = GRAPH_QA_PROMPT.format(
+        graph_qa_prompt = self.qa_prompt or GRAPH_QA_PROMPT
+        qa_prompt = graph_qa_prompt.format(
            context=context, cypher=cypher, question=question
        )

-        # logger.debug(f"QA Prompt: {qa_prompt}")
+        logger.debug(f"QA Prompt: {qa_prompt}")
        qa_response = self.chat_session.send_message(qa_prompt)

        return qa_response.text
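A note on the template contract implied by run(): whichever template is active is formatted with the context, cypher, and question keywords, so a custom qa_prompt must carry those placeholders. A minimal sketch (template text invented for illustration):

# Hypothetical custom QA template; {context}, {cypher}, and {question} are
# required because QAStep.run() calls .format() with exactly those keywords.
CUSTOM_QA_PROMPT = (
    "Using only the graph context below, answer the question.\n"
    "Context: {context}\n"
    "Cypher used: {cypher}\n"
    "Question: {question}\n"
    "Answer:"
)

Any literal braces in such a template would need doubling ({{ and }}), since str.format() is applied to the whole string.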