|
17 | 17 | # model_name = "microsoft/DialoGPT-small"
|
18 | 18 | tokenizer = AutoTokenizer.from_pretrained(model_name)
|
19 | 19 | model = AutoModelForCausalLM.from_pretrained(model_name)
|
20 |
| - |
| 20 | +print("====Greedy search chat====") |
21 | 21 | # chatting 5 times with greedy search
|
22 | 22 | for step in range(5):
|
23 | 23 | # take user input
|
|
35 | 35 | #print the output
|
36 | 36 | output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
|
37 | 37 | print(f"DialoGPT: {output}")
|
38 |
| - |
| 38 | +print("====Beam search chat====") |
39 | 39 | # chatting 5 times with beam search
|
40 | 40 | for step in range(5):
|
41 | 41 | # take user input
|
|
55 | 55 | #print the output
|
56 | 56 | output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
|
57 | 57 | print(f"DialoGPT: {output}")
|
58 |
| - |
| 58 | +print("====Sampling chat====") |
59 | 59 | # chatting 5 times with sampling
|
60 | 60 | for step in range(5):
|
61 | 61 | # take user input
|
|
75 | 75 | #print the output
|
76 | 76 | output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
|
77 | 77 | print(f"DialoGPT: {output}")
|
78 |
| - |
| 78 | +print("====Sampling chat with tweaking temperature====") |
79 | 79 | # chatting 5 times with sampling & tweaking temperature
|
80 | 80 | for step in range(5):
|
81 | 81 | # take user input
|
|
96 | 96 | #print the output
|
97 | 97 | output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
|
98 | 98 | print(f"DialoGPT: {output}")
|
99 |
| - |
| 99 | +print("====Top-K sampling chat with tweaking temperature====") |
100 | 100 | # chatting 5 times with Top K sampling & tweaking temperature
|
101 | 101 | for step in range(5):
|
102 | 102 | # take user input
|
|
117 | 117 | #print the output
|
118 | 118 | output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
|
119 | 119 | print(f"DialoGPT: {output}")
|
120 |
| - |
| 120 | +print("====Nucleus sampling (top-p) chat with tweaking temperature====") |
121 | 121 | # chatting 5 times with nucleus sampling & tweaking temperature
|
122 | 122 | for step in range(5):
|
123 | 123 | # take user input
|
|
139 | 139 | #print the output
|
140 | 140 | output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
|
141 | 141 | print(f"DialoGPT: {output}")
|
142 |
| - |
| 142 | +print("====chatting 5 times with nucleus & top-k sampling & tweaking temperature & multiple sentences====") |
143 | 143 | # chatting 5 times with nucleus & top-k sampling & tweaking temperature & multiple
|
144 | 144 | # sentences
|
145 | 145 | for step in range(5):
|
|
155 | 155 | max_length=1000,
|
156 | 156 | do_sample=True,
|
157 | 157 | top_p=0.95,
|
158 |
| - top_k=50,Y |
| 158 | + top_k=50, |
159 | 159 | temperature=0.75,
|
160 | 160 | num_return_sequences=5,
|
161 | 161 | pad_token_id=tokenizer.eos_token_id
|
|
0 commit comments