Skip to content

Commit 255b532

Browse files
committed
fix syntax error on dialogpt tutorial
1 parent 0ed8f6f commit 255b532

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

Diff for: machine-learning/nlp/chatbot-transformers/dialogpt.py

+8-8
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
17 17
# model_name = "microsoft/DialoGPT-small"
18 18
tokenizer = AutoTokenizer.from_pretrained(model_name)
19 19
model = AutoModelForCausalLM.from_pretrained(model_name)
20 -
20 +
print("====Greedy search chat====")
21 21
# chatting 5 times with greedy search
22 22
for step in range(5):
23 23
# take user input
@@ -35,7 +35,7 @@
35 35
#print the output
36 36
output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
37 37
print(f"DialoGPT: {output}")
38 -
38 +
print("====Beam search chat====")
39 39
# chatting 5 times with beam search
40 40
for step in range(5):
41 41
# take user input
@@ -55,7 +55,7 @@
55 55
#print the output
56 56
output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
57 57
print(f"DialoGPT: {output}")
58 -
58 +
print("====Sampling chat====")
59 59
# chatting 5 times with sampling
60 60
for step in range(5):
61 61
# take user input
@@ -75,7 +75,7 @@
75 75
#print the output
76 76
output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
77 77
print(f"DialoGPT: {output}")
78 -
78 +
print("====Sampling chat with tweaking temperature====")
79 79
# chatting 5 times with sampling & tweaking temperature
80 80
for step in range(5):
81 81
# take user input
@@ -96,7 +96,7 @@
96 96
#print the output
97 97
output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
98 98
print(f"DialoGPT: {output}")
99 -
99 +
print("====Top-K sampling chat with tweaking temperature====")
100 100
# chatting 5 times with Top K sampling & tweaking temperature
101 101
for step in range(5):
102 102
# take user input
@@ -117,7 +117,7 @@
117 117
#print the output
118 118
output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
119 119
print(f"DialoGPT: {output}")
120 -
120 +
print("====Nucleus sampling (top-p) chat with tweaking temperature====")
121 121
# chatting 5 times with nucleus sampling & tweaking temperature
122 122
for step in range(5):
123 123
# take user input
@@ -139,7 +139,7 @@
139 139
#print the output
140 140
output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
141 141
print(f"DialoGPT: {output}")
142 -
142 +
print("====chatting 5 times with nucleus & top-k sampling & tweaking temperature & multiple sentences====")
143 143
# chatting 5 times with nucleus & top-k sampling & tweaking temperature & multiple
144 144
# sentences
145 145
for step in range(5):
@@ -155,7 +155,7 @@
155 155
max_length=1000,
156 156
do_sample=True,
157 157
top_p=0.95,
158 -
top_k=50,Y
158 +
top_k=50,
159 159
temperature=0.75,
160 160
num_return_sequences=5,
161 161
pad_token_id=tokenizer.eos_token_id

0 commit comments

Comments (0)