@@ -1,6 +1,5 @@
 from transformers import AutoModel, AutoTokenizer
 import streamlit as st
-from streamlit_chat import message
 
 
 st.set_page_config(
@@ -21,40 +20,9 @@ def get_model():
     return tokenizer, model
 
 
-MAX_TURNS = 20
-MAX_BOXES = MAX_TURNS * 2
+tokenizer, model = get_model()
 
-
-def predict(input, max_length, top_p, temperature, history=None):
-    tokenizer, model = get_model()
-    if history is None:
-        history = []
-
-    with container:
-        if len(history) > 0:
-            if len(history) > MAX_BOXES:
-                history = history[-MAX_TURNS:]
-            for i, (query, response) in enumerate(history):
-                message(query, avatar_style="big-smile", key=str(i) + "_user")
-                message(response, avatar_style="bottts", key=str(i))
-
-        message(input, avatar_style="big-smile", key=str(len(history)) + "_user")
-        st.write("AI正在回复:")
-        with st.empty():
-            for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
-                                                        temperature=temperature):
-                query, response = history[-1]
-                st.write(response)
-
-    return history
-
-
-container = st.container()
-
-# create a prompt text for the text generation
-prompt_text = st.text_area(label="用户命令输入",
-                           height=100,
-                           placeholder="请在这儿输入您的命令")
+st.title("ChatGLM2-6B")
 
 max_length = st.sidebar.slider(
     'max_length', 0, 32768, 8192, step=1
@@ -63,13 +31,40 @@ def predict(input, max_length, top_p, temperature, history=None):
     'top_p', 0.0, 1.0, 0.8, step=0.01
 )
 temperature = st.sidebar.slider(
-    'temperature', 0.0, 1.0, 0.95, step=0.01
+    'temperature', 0.0, 1.0, 0.8, step=0.01
 )
 
-if 'state' not in st.session_state:
-    st.session_state['state'] = []
+if 'history' not in st.session_state:
+    st.session_state.history = []
+
+if 'past_key_values' not in st.session_state:
+    st.session_state.past_key_values = None
 
-if st.button("发送", key="predict"):
-    with st.spinner("AI正在思考,请稍等........"):
-        # text generation
-        st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])
+for i, (query, response) in enumerate(st.session_state.history):
+    with st.chat_message(name="user", avatar="user"):
+        st.markdown(query)
+    with st.chat_message(name="assistant", avatar="assistant"):
+        st.markdown(response)
+with st.chat_message(name="user", avatar="user"):
+    input_placeholder = st.empty()
+with st.chat_message(name="assistant", avatar="assistant"):
+    message_placeholder = st.empty()
+
+prompt_text = st.text_area(label="用户命令输入",
+                           height=100,
+                           placeholder="请在这儿输入您的命令")
+
+button = st.button("发送", key="predict")
+
+if button:
+    input_placeholder.markdown(prompt_text)
+    history, past_key_values = st.session_state.history, st.session_state.past_key_values
+    for response, history, past_key_values in model.stream_chat(tokenizer, prompt_text, history,
+                                                                past_key_values=past_key_values,
+                                                                max_length=max_length, top_p=top_p,
+                                                                temperature=temperature,
+                                                                return_past_key_values=True):
+        message_placeholder.markdown(response)
+
+    st.session_state.history = history
+    st.session_state.past_key_values = past_key_values
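The heart of the new version is the model.stream_chat(..., return_past_key_values=True) call, which the diff itself shows yielding (response, history, past_key_values) tuples: feeding the previous turn's past_key_values back in lets each new turn reuse the cached attention keys and values instead of re-encoding the whole conversation. A minimal console sketch of the same loop, detached from Streamlit; the checkpoint path and device placement here are assumptions, not part of this commit:

from transformers import AutoModel, AutoTokenizer

# Assumed checkpoint; any local ChatGLM2-6B directory should behave the same.
MODEL_PATH = "THUDM/chatglm2-6b"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).cuda().eval()

history, past_key_values = [], None
while True:
    query = input("\nUser: ")
    if query.strip().lower() in ("exit", "quit"):
        break
    response = ""
    # Same generator contract the web demo relies on: each iteration yields
    # the partial response so far, the updated history, and the KV cache.
    for response, history, past_key_values in model.stream_chat(
            tokenizer, query, history,
            past_key_values=past_key_values,
            max_length=8192, top_p=0.8, temperature=0.8,
            return_past_key_values=True):
        pass  # consume the stream; `response` grows on each yield
    print("ChatGLM2-6B:", response)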
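Because the rewrite swaps the third-party streamlit_chat widgets for Streamlit's built-in st.chat_message containers, the streamlit-chat package can be dropped from the requirements; the trade-off is needing a Streamlit release recent enough to ship st.chat_message. Assuming the script keeps the repo's web_demo2.py name, it launches the same way as before:

streamlit run web_demo2.py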