mlx_chat.py
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "marimo",
#     "mlx-lm==0.19.0",
#     "huggingface-hub==0.25.1",
# ]
# ///

import marimo

__generated_with = "0.9.4"
app = marimo.App(width="medium")


@app.cell
def __():
    from mlx_lm import load, generate
    from pathlib import Path

    import marimo as mo
    from huggingface_hub import snapshot_download

    return Path, generate, load, mo, snapshot_download


@app.cell(hide_code=True)
def __(mo):
    mo.md(
        r"""
        # Using MLX with Marimo

        ## Chat Example

        This example shows how to use [`mo.ui.chat`](https://docs.marimo.io/api/inputs/chat.html#marimo.ui.chat) to make a chatbot backed by Apple's MLX, using the `mlx_lm` library and marimo.

        [`mlx_lm`](https://github.com/ml-explore/mlx-examples/tree/main/llms) is a library for running large language models on Apple Silicon.

        [`mlx`](https://github.com/ml-explore/mlx) is a framework for running machine learning models on Apple Silicon.

        Convert your own models to MLX, or find community-converted models at various quantizations in the [mlx-community](https://huggingface.co/mlx-community) organization on Hugging Face.

        ### Things you can do to improve this example:

        - [`prompt caching`](https://github.com/ml-explore/mlx-examples/blob/main/llms/README.md#long-prompts-and-generations)
        - completions / notebook mode (see the sketch cells below)
        - assistant pre-fill (see the sketch cells below)
        """
    )
    return
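

# The two cells below are hedged sketches of items from the improvement list
# above ("assistant pre-fill" and "completions / notebook mode"). They reuse
# only the `generate`, `model`, and `tokenizer` names defined elsewhere in this
# notebook; helper names such as `generate_with_prefill` and `complete_text`
# are illustrative, not part of mlx_lm or marimo.


@app.cell
def __(generate, model, tokenizer):
    def generate_with_prefill(chat_messages, prefill, max_tokens=256):
        # Assistant pre-fill: build the chat prompt, then append the start of
        # the assistant's reply so the model continues from it instead of
        # starting its answer from scratch. Assumes the tokenizer has a chat
        # template (the same assumption checked in `mlx_chat_model` below).
        prompt = tokenizer.apply_chat_template(
            chat_messages, tokenize=False, add_generation_prompt=True
        )
        prompt += prefill
        completion = generate(
            model, tokenizer, prompt=prompt, max_tokens=max_tokens
        )
        return prefill + completion

    return (generate_with_prefill,)


@app.cell
def __(generate, model, tokenizer):
    def complete_text(text, max_tokens=128):
        # Completions / notebook mode: skip the chat template entirely and let
        # the model continue the raw text it is given.
        return generate(model, tokenizer, prompt=text, max_tokens=max_tokens)

    return (complete_text,)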


@app.cell
def __(Path, snapshot_download):
    def get_model_path(path_or_hf_repo: str) -> Path:
        """
        Ensures the model is available locally. If the path does not exist
        locally, it is downloaded from the Hugging Face Hub.

        Args:
            path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.

        Returns:
            Path: The path to the model.
        """
        model_path = Path(path_or_hf_repo)
        if model_path.exists():
            return model_path
        else:
            try:
                # If it doesn't exist locally, download it from Hugging Face
                return Path(
                    snapshot_download(
                        repo_id=path_or_hf_repo,
                        allow_patterns=[
                            "*.json",
                            "*.safetensors",
                            "*.py",
                            "tokenizer.model",
                            "*.tiktoken",
                            "*.txt",
                        ],
                    )
                )
            except Exception as e:
                raise ValueError(
                    f"Error downloading model from Hugging Face: {str(e)}"
                )

    return (get_model_path,)


@app.cell
def __(mo):
    MODEL_ID = mo.ui.text(
        label="Hugging Face Model Repo or Local Path",
        value="mlx-community/Llama-3.2-3B-Instruct-bf16",
        placeholder="Enter a Hugging Face repo_id/model_id or a local path",
        full_width=True,
    )
    load_model_button = mo.ui.run_button(label="Load Model")

    mo.hstack([MODEL_ID, load_model_button])
    return MODEL_ID, load_model_button


@app.cell
def __(MODEL_ID, get_model_path, load, load_model_button, mo):
    mo.stop(not load_model_button.value, "Click 'Load Model' to proceed")

    try:
        mo.output.append(
            "⏳ Fetching model... This may take a while if downloading from Hugging Face."
        )
        model_path = get_model_path(MODEL_ID.value)
        mo.output.append(f"📁 Model path: {model_path}")

        mo.output.append("🔄 Loading model into memory...")
        model, tokenizer = load(model_path)
        mo.output.append("✅ Model loaded successfully!")
    except Exception as e:
        mo.output.append(f"❌ Error loading model: {str(e)}")
        raise
    return model, model_path, tokenizer


@app.cell(hide_code=True)
def __(mo):
    # Create a text area for the system message
    system_message = mo.ui.text_area(
        value="You are a helpful AI assistant.",
        label="System Message",
        full_width=True,
        rows=3,
    )
    system_message  # display the system message
    return (system_message,)


@app.cell(hide_code=True)
def __(mo):
    temp_slider = mo.ui.slider(
        start=0.0, stop=2.0, step=0.1, value=0.7, label="Temperature Slider"
    )
    max_tokens = mo.ui.number(value=512, label="Max Tokens Per Turn")

    temp_slider, max_tokens  # display the inputs
    return max_tokens, temp_slider


@app.cell
def __(
    generate,
    max_tokens,
    mo,
    model,
    system_message,
    temp_slider,
    tokenizer,
):
    def mlx_chat_model(messages, config):
        # Include the system message as the first message
        chat_messages = [{"role": "system", "content": system_message.value}]

        # Add the rest of the messages
        chat_messages.extend(
            [{"role": msg.role, "content": msg.content} for msg in messages]
        )

        # Use the tokenizer's chat template if available
        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
            prompt = tokenizer.apply_chat_template(
                chat_messages, tokenize=False, add_generation_prompt=True
            )
        else:
            # Fallback to simple concatenation if no chat template
            prompt = "\n".join(
                f"{msg['role']}: {msg['content']}" for msg in chat_messages
            )
            prompt += "\nassistant:"

        # Generate the response
        response = generate(
            model,
            tokenizer,
            prompt=prompt,
            max_tokens=int(max_tokens.value),  # Use the max_tokens input
            temp=float(temp_slider.value),  # Use the temperature slider
        )
        return response.strip()

    # Create the chat interface
    chatbot = mo.ui.chat(
        mlx_chat_model,
        prompts=[
            "Hello",
            "How are you?",
            "I'm doing great, how about you?",
        ],
    )

    # Display the chatbot
    chatbot
    return chatbot, mlx_chat_model


@app.cell(hide_code=True)
def __(mo):
    mo.md("""Access the chatbot's historical messages with `chatbot.value`.""")
    return


@app.cell
def __(chatbot):
    # Display the chat history
    chatbot.value
    return
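

# Hedged sketch: `chatbot.value` holds the chat history as a list of message
# objects with `role` and `content` attributes (the same shape `mlx_chat_model`
# receives above), so it can be filtered like any other list. The name
# `assistant_replies` is illustrative, not part of the marimo API.


@app.cell
def __(chatbot):
    assistant_replies = [
        msg.content for msg in (chatbot.value or []) if msg.role == "assistant"
    ]
    assistant_replies
    return (assistant_replies,)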


if __name__ == "__main__":
    app.run()