-
Notifications
You must be signed in to change notification settings - Fork 179
Expand file tree
/
Copy pathpydanticai_graph.py
More file actions
142 lines (111 loc) · 4.36 KB
/
pydanticai_graph.py
File metadata and controls
142 lines (111 loc) · 4.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from __future__ import annotations as _annotations
import asyncio
import os
from dataclasses import dataclass, field
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from dotenv import load_dotenv
from groq import BaseModel
from openai import AsyncOpenAI
from pydantic_ai import Agent, format_as_xml
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.openai import OpenAIProvider
from pydantic_graph import (
BaseNode,
End,
Graph,
GraphRunContext,
)
# Set up an OpenAI-compatible async client and chat model. API_HOST selects the
# backend: "azure" (Azure OpenAI with Entra ID auth), "github" (GitHub Models,
# the default), "ollama" (a local Ollama server); any other value uses OpenAI.
load_dotenv(override=True)
API_HOST = os.getenv("API_HOST", "github")

# Only the Azure branch creates a credential; main() closes it when it is set.
async_credential = None
if API_HOST == "azure":
    async_credential = DefaultAzureCredential()
    # Bearer-token provider for Entra ID auth, passed to the client in place of a static key.
    token_provider = get_bearer_token_provider(async_credential, "https://cognitiveservices.azure.com/.default")
    client = AsyncOpenAI(
        # Strip a trailing slash (endpoints are commonly copied as
        # "https://<name>.openai.azure.com/") so the URL doesn't become ".../​/openai/v1".
        base_url=os.environ["AZURE_OPENAI_ENDPOINT"].rstrip("/") + "/openai/v1",
        api_key=token_provider,
    )
    model = OpenAIChatModel(os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"], provider=OpenAIProvider(openai_client=client))
elif API_HOST == "github":
    client = AsyncOpenAI(api_key=os.environ["GITHUB_TOKEN"], base_url="https://models.inference.ai.azure.com")
    model = OpenAIChatModel(os.getenv("GITHUB_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
elif API_HOST == "ollama":
    # The OpenAI client requires some api_key value; "none" is a placeholder.
    client = AsyncOpenAI(base_url=os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434/v1"), api_key="none")
    model = OpenAIChatModel(os.environ["OLLAMA_MODEL"], provider=OpenAIProvider(openai_client=client))
else:
    client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
    model = OpenAIChatModel(os.environ.get("OPENAI_MODEL", "gpt-4o"), provider=OpenAIProvider(openai_client=client))
# ---------------------------------------------------------------------------
# Agent definitions
# ---------------------------------------------------------------------------

# Free-form agent whose plain-text output is used as the quiz question.
ask_agent = Agent(model=model, output_type=str)
# Use pydantic's BaseModel explicitly. The `from groq import BaseModel` at the top
# of the file resolves to the Groq SDK's own model base class (an unrelated
# dependency that, in current SDK releases, is likely configured with
# extra="allow"), not plain pydantic. pydantic itself is a required dependency
# of pydantic_ai.
from pydantic import BaseModel as _PydanticBaseModel  # noqa: E402


class EvaluationResult(_PydanticBaseModel, use_attribute_docstrings=True):
    """Structured verdict produced by evaluate_agent.

    With use_attribute_docstrings=True, pydantic uses the string under each
    field as that field's description in the generated JSON schema. That schema
    is presumably what pydantic_ai shows the model, so those strings act as
    instructions and must be kept.
    """

    correct: bool
    """Whether the answer is correct."""
    comment: str
    """Comment on the answer, reprimand the user if the answer is wrong."""
# Judge agent: receives a question/answer pair and returns an EvaluationResult.
evaluate_agent = Agent(
    model=model,
    system_prompt="Given a question and answer, evaluate if the answer is correct.",
    output_type=EvaluationResult,
)
# ---------------------------------------------------------------------------
# Graph state and nodes
# ---------------------------------------------------------------------------


@dataclass
class QuestionState:
    """Mutable state object that main() passes to the graph run and every node reads/writes."""

    # Question currently awaiting an answer: set by Ask, reset to None by Reprimand.
    question: str | None = field(default=None)
    # Conversation history handed to ask_agent on each call.
    ask_agent_messages: list[ModelMessage] = field(default_factory=list)
    # Conversation history handed to evaluate_agent on each call.
    evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)
@dataclass
class Ask(BaseNode[QuestionState]):
    """Graph node: have ask_agent generate a new question, then move on to Answer."""

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer:
        result = await ask_agent.run(
            "Ask a simple question with a single correct answer.",
            message_history=ctx.state.ask_agent_messages,
        )
        # all_messages() already contains the message_history passed above, so
        # appending it would duplicate the prior conversation on every loop.
        # Append only the messages produced by this run.
        ctx.state.ask_agent_messages += result.new_messages()
        ctx.state.question = result.output
        return Answer(result.output)
@dataclass
class Answer(BaseNode[QuestionState]):
    """Graph node: show the question on the console and collect the user's reply."""

    # Question text generated by Ask.
    question: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
        prompt = f"{self.question}: "
        # Blocking console read; the graph waits here for the user's reply.
        return Evaluate(input(prompt))
@dataclass
class Evaluate(BaseNode[QuestionState, None, str]):
    """Graph node: grade the user's answer with evaluate_agent.

    Ends the graph run with the evaluator's comment when the answer is correct;
    otherwise passes the comment to Reprimand.
    """

    # The user's reply collected by Answer.
    answer: str

    async def run(
        self,
        ctx: GraphRunContext[QuestionState],
    ) -> End[str] | Reprimand:
        # Ask always sets the question before Answer/Evaluate run; the assert
        # also narrows the type for the checker.
        assert ctx.state.question is not None
        result = await evaluate_agent.run(
            format_as_xml({"question": ctx.state.question, "answer": self.answer}),
            message_history=ctx.state.evaluate_agent_messages,
        )
        # Append only this run's messages: all_messages() would re-include the
        # history passed above and duplicate it on every evaluation.
        ctx.state.evaluate_agent_messages += result.new_messages()
        if result.output.correct:
            return End(result.output.comment)
        return Reprimand(result.output.comment)
@dataclass
class Reprimand(BaseNode[QuestionState]):
    """Graph node: print the evaluator's feedback on a wrong answer, then loop back to Ask."""

    # Feedback text produced by evaluate_agent.
    comment: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
        print(f"Comment: {self.comment}")
        # Drop the stale question; the next Ask node sets a fresh one.
        state = ctx.state
        state.question = None
        return Ask()
# Register the four node classes together with the shared state type.
question_graph = Graph(
    nodes=(Ask, Answer, Evaluate, Reprimand),
    state_type=QuestionState,
)
async def main():
    """Run the question graph starting at Ask and print its final output.

    The OpenAI client and, when present, the Azure credential are closed in a
    finally block, so they are released even if the graph run raises.
    """
    state = QuestionState()
    node = Ask()
    try:
        end = await question_graph.run(node, state=state)
        print("END:", end.output)
    finally:
        await client.close()
        if async_credential:
            await async_credential.close()


if __name__ == "__main__":
    asyncio.run(main())