-
Notifications
You must be signed in to change notification settings - Fork 473
/
Copy pathsimplemind_example.py
112 lines (80 loc) · 2.45 KB
/
simplemind_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "marimo",
# "simplemind==0.1.3",
# ]
# ///
import marimo
# Version string stored with the notebook file
# (presumably the marimo release that generated it — confirm).
__generated_with = "0.9.14"
# Notebook application object; every cell below registers itself on it
# through the @app.cell decorator. width="full" is passed to marimo.App.
app = marimo.App(width="full")
@app.cell(hide_code=True)
def __(mo):
    # Notebook title, kept in a cell-private name so the markdown call stays short.
    _heading = r"""## Using [simplemind](https://github.com/kennethreitz/simplemind) with `mo.ui.chat()`"""
    mo.md(_heading)
    return
@app.cell(hide_code=True)
def __():
    # Standard library first, then the third-party packages the notebook uses.
    import os

    import marimo as mo
    import simplemind as sm

    # Hand the modules to the cells that name them as parameters.
    return mo, os, sm
@app.cell(hide_code=True)
def __(__file__, mo, os):
    # Treat a missing, empty, or whitespace-only key as "not set": a variable
    # that is exported but blank is unusable, so the warning must still appear.
    has_set_env = bool(os.environ.get("OPENAI_API_KEY", "").strip())
    # Evaluates to a "warn" callout with re-run instructions when the key is
    # unusable, and to an empty string otherwise.
    mo.md(f"""
Missing OpenAI API key. Re-run this notebook with the following command:
```bash
export OPENAI_API_KEY='sk-'
marimo edit {__file__}
```
""").callout("warn") if not has_set_env else ""
    return (has_set_env,)
@app.cell
def __(mo):
    # Shared log store: mo.state is seeded with an empty list and unpacked
    # into a getter and a setter. Other cells call get_logs() to read the
    # list and set_logs(...) to replace it with an extended copy.
    # NOTE(review): allow_self_loops=True is forwarded to marimo's state API;
    # check the marimo docs for its exact re-run semantics.
    get_logs, set_logs = mo.state([], allow_self_loops=True)
    return get_logs, set_logs
@app.cell
def __(set_logs, sm):
    def add_log(value):
        """Append *value* to the shared log list through the state setter."""

        def _with_new_entry(existing):
            # Build a new list rather than mutating the stored one.
            return [*existing, value]

        return set_logs(_with_new_entry)

    class LoggingPlugin(sm.BasePlugin):
        """Conversation plugin that records each hook invocation via add_log."""

        def initialize_hook(self, conversation):
            """Log that a conversation is being initialized."""
            add_log("Initializing conversation")

        def add_message_hook(self, conversation, message):
            """Log the text of a message being added to the conversation."""
            entry = f"Adding message to conversation: {message.text}"
            add_log(entry)

        def pre_send_hook(self, conversation):
            """Log the message count just before the conversation is sent."""
            entry = f"Sending conversation with {len(conversation.messages)} messages"
            add_log(entry)

        def post_send_hook(self, conversation, response):
            """Log the text of the response that came back."""
            entry = f"Received response: {response.text}"
            add_log(entry)

        def cleanup_hook(self, conversation):
            """Log the message count when the conversation is cleaned up."""
            entry = f"Cleaning up conversation with {len(conversation.messages)} messages"
            add_log(entry)

    return LoggingPlugin, add_log
@app.cell
def __(LoggingPlugin, mo, sm):
    # One conversation object is shared by every chat turn, with the logging
    # plugin attached so each hook call is recorded.
    conversation = sm.create_conversation(llm_model="gpt-4o", llm_provider="openai")
    conversation.add_plugin(LoggingPlugin())

    def on_message(messages):
        """Forward the newest chat message to the conversation and return the reply text."""
        newest = messages[-1]
        conversation.add_message("user", newest.content)
        reply = conversation.send()
        return reply.text

    chat = mo.ui.chat(on_message)
    return chat, conversation, on_message
@app.cell
def __(chat, get_logs, mo):
    # Newest log entry first.
    logs = get_logs()[::-1]
    # Chat on the left, read-only log table on the right, equal widths.
    _log_table = mo.ui.table(logs, selection=None)
    mo.hstack([chat, _log_table], widths="equal")
    return (logs,)
@app.cell
def __(chat):
    # The chat element's current value is this cell's only expression
    # (presumably rendered as the cell output so the history can be inspected — confirm).
    chat.value
    return
# Run the notebook app when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()