
Commit 9e18d61

Update web demo
1 parent 34a25f9 commit 9e18d61

File tree

4 files changed: +48 −54 lines

README.md

+10 −12
@@ -137,7 +137,11 @@ git clone https://github.com/THUDM/ChatGLM2-6B
 cd ChatGLM2-6B
 ```
 
-Then use pip to install the dependencies: `pip install -r requirements.txt`. The `transformers` library version `4.30.2` is recommended, and a `torch` version above 2.0 is recommended for the best inference performance.
+Then use pip to install the dependencies:
+```
+pip install -r requirements.txt
+```
+The `transformers` library version `4.30.2` is recommended, and `torch` 2.0 or above is recommended for the best inference performance.
 
 ### Code Usage
 
@@ -188,23 +192,17 @@ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm2-6b
 
 ![web-demo](resources/web-demo.gif)
 
-First install Gradio with `pip install gradio`, then run [web_demo.py](web_demo.py) from the repository:
-
+You can launch a Streamlit-based web demo with the following command:
 ```shell
-python web_demo.py
+streamlit run web_demo2.py
 ```
 
 The program runs a web server and prints its address. Open the printed address in a browser to use the demo.
-> By default the demo starts with `share=False` and does not generate a public link. If public access is needed, it can be started with `share=True` instead.
->
 
-Thanks to [@AdamBear](https://github.com/AdamBear) for implementing the Streamlit-based web demo `web_demo2.py`. Using it first requires installing these extra dependencies:
-```shell
-pip install streamlit streamlit-chat
-```
-Then run it with:
+
+[web_demo.py](./web_demo.py) provides an older, Gradio-based web demo, which can be run with:
 ```shell
-streamlit run web_demo2.py
+python web_demo.py
 ```
 In our testing, the Streamlit-based web demo runs more smoothly when the input prompt is long.
 
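The revised installation note above pins down recommended versions (`transformers` at `4.30.2`, `torch` 2.0 or above). A quick way to confirm what actually landed in the environment is a check like the following; this is an illustrative sketch using only the standard-library `importlib.metadata`, not something shipped in the repository:

```python
# Hypothetical post-install check for the versions recommended in the README.
from importlib.metadata import version, PackageNotFoundError

for pkg, recommended in [("transformers", "== 4.30.2"), ("torch", ">= 2.0")]:
    try:
        print(f"{pkg}: installed {version(pkg)}, recommended {recommended}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed; run `pip install -r requirements.txt`")
```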

requirements.txt

+2 −1
@@ -6,4 +6,5 @@ gradio
 mdtex2html
 sentencepiece
 accelerate
-sse-starlette
+sse-starlette
+streamlit>=1.24.0
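The new `streamlit>=1.24.0` pin lines up with the rewritten demo below: `st.chat_message`, which `web_demo2.py` now builds its UI on, first shipped in Streamlit 1.24.0 (to the best of my knowledge), replacing the third-party `streamlit-chat` package. A minimal, self-contained sketch of that API:

```python
# Minimal illustration of the st.chat_message API the new demo relies on.
# Run with: streamlit run this_file.py
import streamlit as st

with st.chat_message(name="user", avatar="user"):
    st.markdown("Hello!")  # a user turn rendered as a chat bubble
with st.chat_message(name="assistant", avatar="assistant"):
    st.markdown("Hi, how can I help?")  # an assistant turn
```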

resources/web-demo.gif

466 KB

web_demo2.py

+36 −41
@@ -1,6 +1,5 @@
 from transformers import AutoModel, AutoTokenizer
 import streamlit as st
-from streamlit_chat import message
 
 
 st.set_page_config(
@@ -21,40 +20,9 @@ def get_model():
     return tokenizer, model
 
 
-MAX_TURNS = 20
-MAX_BOXES = MAX_TURNS * 2
+tokenizer, model = get_model()
 
-
-def predict(input, max_length, top_p, temperature, history=None):
-    tokenizer, model = get_model()
-    if history is None:
-        history = []
-
-    with container:
-        if len(history) > 0:
-            if len(history)>MAX_BOXES:
-                history = history[-MAX_TURNS:]
-            for i, (query, response) in enumerate(history):
-                message(query, avatar_style="big-smile", key=str(i) + "_user")
-                message(response, avatar_style="bottts", key=str(i))
-
-        message(input, avatar_style="big-smile", key=str(len(history)) + "_user")
-        st.write("AI正在回复:")
-        with st.empty():
-            for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
-                                                        temperature=temperature):
-                query, response = history[-1]
-                st.write(response)
-
-    return history
-
-
-container = st.container()
-
-# create a prompt text for the text generation
-prompt_text = st.text_area(label="用户命令输入",
-                           height = 100,
-                           placeholder="请在这儿输入您的命令")
+st.title("ChatGLM2-6B")
 
 max_length = st.sidebar.slider(
     'max_length', 0, 32768, 8192, step=1
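Review note: `get_model()` is now invoked once at module scope rather than inside the removed `predict()`. Since a Streamlit script re-runs top to bottom on every interaction, this only stays cheap if the loader is memoized. A minimal sketch of the usual pattern, assuming `st.cache_resource` (available in the pinned `streamlit>=1.24.0`) and the standard ChatGLM2-6B loading calls; the decorator and body here are an illustration, not necessarily the exact code above this hunk:

```python
# Sketch of a cached loader behind get_model(); assumes st.cache_resource
# and the usual ChatGLM2-6B loading calls (not shown in this hunk).
import streamlit as st
from transformers import AutoModel, AutoTokenizer

@st.cache_resource  # memoize across Streamlit re-runs: load the model once
def get_model():
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
    return tokenizer, model.eval()
```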
@@ -63,13 +31,40 @@ def predict(input, max_length, top_p, temperature, history=None):
     'top_p', 0.0, 1.0, 0.8, step=0.01
 )
 temperature = st.sidebar.slider(
-    'temperature', 0.0, 1.0, 0.95, step=0.01
+    'temperature', 0.0, 1.0, 0.8, step=0.01
 )
 
-if 'state' not in st.session_state:
-    st.session_state['state'] = []
+if 'history' not in st.session_state:
+    st.session_state.history = []
+
+if 'past_key_values' not in st.session_state:
+    st.session_state.past_key_values = None
 
-if st.button("发送", key="predict"):
-    with st.spinner("AI正在思考,请稍等........"):
-        # text generation
-        st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])
+for i, (query, response) in enumerate(st.session_state.history):
+    with st.chat_message(name="user", avatar="user"):
+        st.markdown(query)
+    with st.chat_message(name="assistant", avatar="assistant"):
+        st.markdown(response)
+with st.chat_message(name="user", avatar="user"):
+    input_placeholder = st.empty()
+with st.chat_message(name="assistant", avatar="assistant"):
+    message_placeholder = st.empty()
+
+prompt_text = st.text_area(label="用户命令输入",
+                           height=100,
+                           placeholder="请在这儿输入您的命令")
+
+button = st.button("发送", key="predict")
+
+if button:
+    input_placeholder.markdown(prompt_text)
+    history, past_key_values = st.session_state.history, st.session_state.past_key_values
+    for response, history, past_key_values in model.stream_chat(tokenizer, prompt_text, history,
+                                                                 past_key_values=past_key_values,
+                                                                 max_length=max_length, top_p=top_p,
+                                                                 temperature=temperature,
+                                                                 return_past_key_values=True):
+        message_placeholder.markdown(response)
+
+    st.session_state.history = history
+    st.session_state.past_key_values = past_key_values
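For readers skimming the diff: the new event loop keys off ChatGLM2-6B's `stream_chat` generator, which, when called with `return_past_key_values=True` as above, yields `(response, history, past_key_values)` triples with `response` growing as tokens stream in. A minimal non-Streamlit consumer of the same contract, with model loading elided (assumes `model` and `tokenizer` are set up as in the demo, and reuses the demo's default sampling parameters):

```python
# Hypothetical command-line consumer of the stream_chat contract used above.
# Assumes `model` and `tokenizer` were loaded via AutoModel/AutoTokenizer
# with trust_remote_code=True, as in web_demo2.py.
history, past_key_values = [], None
for response, history, past_key_values in model.stream_chat(
        tokenizer, "你好", history,
        past_key_values=past_key_values,
        max_length=8192, top_p=0.8, temperature=0.8,
        return_past_key_values=True):
    pass  # each iteration yields the partial response generated so far

print(response)  # final reply; history and past_key_values roll into the next turn
```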
