This repository was archived by the owner on Jan 1, 2021. It is now read-only.

Commit 54c48f5

fixed chatbot bugs

1 parent ecde8d0 commit 54c48f5

3 files changed: +36 -85


assignments/chatbot/chatbot.py

+4 -6
@@ -87,8 +87,8 @@ def _get_buckets():
     train_buckets_scale is the inverval that'll help us
     choose a random bucket later on.
     """
-    test_buckets = data.load_data('test.enc.ids', 'test.dec.ids')
-    data_buckets = data.load_data('train.enc.ids', 'train.dec.ids')
+    test_buckets = data.load_data('test_ids.enc', 'test_ids.dec')
+    data_buckets = data.load_data('train_ids.enc', 'train_ids.dec')
     train_bucket_sizes = [len(data_buckets[b]) for b in range(len(config.BUCKETS))]
     print("Number of samples in each bucket:\n", train_bucket_sizes)
     train_total_size = sum(train_bucket_sizes)
@@ -169,8 +169,7 @@ def _get_user_input():
     """ Get user's input, which will be transformed into encoder input later """
     print("> ", end="")
     sys.stdout.flush()
-    text = sys.stdin.readline()
-    return data.tokenize_helper(text)
+    return sys.stdin.readline()

 def _find_right_bucket(length):
     """ Find the proper bucket for an encoder input based on its length """
@@ -184,6 +183,7 @@ def _construct_response(output_logits, inv_dec_vocab):

     This is a greedy decoder - outputs are just argmaxes of output_logits.
     """
+    print(output_logits[0])
     outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
     # If there is an EOS symbol in outputs, cut them at that point.
     if config.EOS_ID in outputs:
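
For reference, a self-contained illustration of the greedy decode kept above, with made-up logits: each decoder step contributes the argmax id, and the sequence is assumed to be cut at the first EOS id as the in-code comment describes.

import numpy as np

EOS_ID = 3                                                  # placeholder for config.EOS_ID
output_logits = [np.array([[0.1, 0.2, 0.05, 0.0, 0.9]]),    # argmax -> 4
                 np.array([[0.0, 2.0, 0.1, 0.3, 0.2]]),     # argmax -> 1
                 np.array([[0.1, 0.2, 0.3, 5.0, 0.0]])]     # argmax -> 3 (EOS)
outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
if EOS_ID in outputs:
    outputs = outputs[:outputs.index(EOS_ID)]               # cut at EOS, per the comment above
print(outputs)                                              # [4, 1]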
@@ -200,8 +200,6 @@ def chat():
     model = ChatBotModel(True, batch_size=1)
     model.build_graph()

-    # saver = tf.train.import_meta_graph('checkpoints/chatbot-30.meta')
-
     saver = tf.train.Saver()

     with tf.Session() as sess:

assignments/chatbot/config.py

+1 -1
@@ -23,7 +23,7 @@
 PROCESSED_PATH = 'processed'
 CPT_PATH = 'checkpoints'

-THRESHOLD = 1
+THRESHOLD = 2

 PAD_ID = 0
 UNK_ID = 1
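
For reference, a toy illustration of what raising THRESHOLD from 1 to 2 does in data.build_vocab (see the data.py diff below): vocabulary words are written in frequency order and the loop stops at the first word seen fewer than THRESHOLD times, so singleton words now fall back to <unk>. The counts here are made up.

THRESHOLD = 2
counts = {'you': 10, 'hello': 3, 'xylophone': 1}       # made-up frequencies
kept = []
for word in sorted(counts, key=counts.get, reverse=True):
    if counts[word] < THRESHOLD:
        break                                          # same early exit as build_vocab
    kept.append(word)
print(kept)                                            # ['you', 'hello']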

assignments/chatbot/data.py

+31 -78
@@ -26,14 +26,21 @@
 def get_lines():
     id2line = {}
     file_path = os.path.join(config.DATA_PATH, config.LINE_FILE)
+    print(config.LINE_FILE)
     with open(file_path, 'r', errors='ignore') as f:
-        lines = f.readlines()
-        for i, line in enumerate(lines):
-            parts = line.split(' +++$+++ ')
-            if len(parts) == 5:
-                if parts[4][-1] == '\n':
-                    parts[4] = parts[4][:-1]
-                id2line[parts[0]] = parts[4]
+        # lines = f.readlines()
+        # for line in lines:
+        i = 0
+        try:
+            for line in f:
+                parts = line.split(' +++$+++ ')
+                if len(parts) == 5:
+                    if parts[4][-1] == '\n':
+                        parts[4] = parts[4][:-1]
+                    id2line[parts[0]] = parts[4]
+                i += 1
+        except UnicodeDecodeError:
+            print(i, line)
     return id2line

 def get_convos():
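
For reference, a standalone illustration of the ' +++$+++ ' split that get_lines() performs, on a made-up record in the movie-lines format (field 0 is the line id, field 4 the utterance):

sample = "L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n"
parts = sample.split(' +++$+++ ')
if len(parts) == 5 and parts[4][-1] == '\n':
    parts[4] = parts[4][:-1]                  # strip the trailing newline, as above
print({parts[0]: parts[4]})                   # {'L1045': 'They do not!'}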
@@ -54,64 +61,13 @@ def get_convos():
 def question_answers(id2line, convos):
     """ Divide the dataset into two sets: questions and answers. """
     questions, answers = [], []
-    seen_questions, seen_answers = set(), set()
-    repeated = 0
     for convo in convos:
         for index, line in enumerate(convo[:-1]):
-            if not convo[index] in id2line or not convo[index+1] in id2line:
-                continue
-            q = id2line[convo[index]]
-            a = id2line[convo[index + 1]]
-            # if q in seen_questions or a in seen_answers:
-            if q in seen_questions:
-                print('Q:', q)
-                print('A:', a)
-                repeated += 1
-                continue
-            questions.append(q)
-            answers.append(a)
-            seen_questions.add(q)
-            # seen_answers.add(a)
+            questions.append(id2line[convo[index]])
+            answers.append(id2line[convo[index + 1]])
     assert len(questions) == len(answers)
-    print('Total repeated:', repeated)
     return questions, answers

-def tokenize_helper(line):
-    tokens = basic_tokenizer(line)
-    text = ' '.join(tokens)
-    for a, b in config.CONTRACTIONS:
-        text = text.replace(a, b)
-    return text
-
-def tokenize_data():
-    print('Tokenizing the data ...')
-    # filenames = ['test.enc', 'test.dec', 'train.enc', 'train.dec']
-    modes = ['train', 'test']
-    seen_questions = set()
-    for mode in modes:
-        q_file = os.path.join(config.PROCESSED_PATH, mode + '.enc')
-        a_file = os.path.join(config.PROCESSED_PATH, mode + '.dec')
-        q_out = open(os.path.join(config.PROCESSED_PATH, mode + '.enc.tok'), 'w')
-        a_out = open(os.path.join(config.PROCESSED_PATH, mode + '.dec.tok'), 'w')
-
-        q_lines = open(q_file, 'r').readlines()
-        a_lines = open(a_file, 'r').readlines()
-        n = len(q_lines)
-        repeated = 0
-
-        for i in range(n):
-            q, a = q_lines[i], a_lines[i]
-            q_clean = tokenize_helper(q)
-            if q_clean in seen_questions:
-                print(q_clean)
-                repeated += 1
-                continue
-            seen_questions.add(q_clean)
-            q_out.write(q_clean + '\n')
-            a_clean = tokenize_helper(a)
-            a_out.write(a_clean + '\n')
-        print('Total repeated in', mode, ':', repeated)
-
 def prepare_dataset(questions, answers):
     # create path to store all the train & test encoder & decoder
     make_dir(config.PROCESSED_PATH)
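
For reference, an illustration (with made-up line ids) of the simplified pairing that question_answers() now does: every line in a conversation becomes the question for the line that follows it. Note that, unlike the removed version, ids missing from id2line are no longer skipped, so a KeyError is possible on incomplete data.

id2line = {'L1': 'Hi.', 'L2': 'Hello.', 'L3': 'How are you?'}   # made-up data
convo = ['L1', 'L2', 'L3']
questions = [id2line[convo[i]] for i in range(len(convo) - 1)]
answers = [id2line[convo[i + 1]] for i in range(len(convo) - 1)]
print(list(zip(questions, answers)))   # [('Hi.', 'Hello.'), ('Hello.', 'How are you?')]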
@@ -122,7 +78,7 @@ def prepare_dataset(questions, answers):
     filenames = ['train.enc', 'train.dec', 'test.enc', 'test.dec']
     files = []
     for filename in filenames:
-        files.append(open(os.path.join(config.PROCESSED_PATH, filename), 'w'))
+        files.append(open(os.path.join(config.PROCESSED_PATH, filename),'w'))

     for i in range(len(questions)):
         if i in test_ids:
@@ -142,14 +98,13 @@ def make_dir(path):
     except OSError:
         pass

-def basic_tokenizer(line, normalize_digits=False):
+def basic_tokenizer(line, normalize_digits=True):
     """ A basic tokenizer to tokenize text into tokens.
     Feel free to change this to suit your need. """
     line = re.sub('<u>', '', line)
     line = re.sub('</u>', '', line)
     line = re.sub('\[', '', line)
     line = re.sub('\]', '', line)
-    line = line.replace('`', "'")
     words = []
     _WORD_SPLIT = re.compile("([.,!?\"'-<>:;)(])")
     _DIGIT_RE = re.compile(r"\d")
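
For reference, a small sketch of what the new normalize_digits=True default is meant to do, assuming tokens matched by _DIGIT_RE have their digits collapsed to a single placeholder character (the exact replacement character is not shown in this hunk): numbers like '1995' and '2004' then share one vocabulary entry.

import re

_DIGIT_RE = re.compile(r"\d")
for token in ['room', '101', 'b4']:
    print(re.sub(_DIGIT_RE, '#', token))   # room, ###, b# -- '#' is an assumed placeholder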
@@ -162,15 +117,14 @@ def basic_tokenizer(line, normalize_digits=False):
             words.append(token)
     return words

-def build_vocab(filename, normalize_digits=False):
+def build_vocab(filename, normalize_digits=True):
     in_path = os.path.join(config.PROCESSED_PATH, filename)
-    out_path = os.path.join(config.PROCESSED_PATH, 'vocab.{}'.format(filename[-7:-4]))
+    out_path = os.path.join(config.PROCESSED_PATH, 'vocab.{}'.format(filename[-3:]))

     vocab = {}
     with open(in_path, 'r') as f:
         for line in f.readlines():
-            tokens = line.split()
-            for token in tokens:
+            for token in basic_tokenizer(line):
                 if not token in vocab:
                     vocab[token] = 0
                 vocab[token] += 1
@@ -184,29 +138,29 @@ def build_vocab(filename, normalize_digits=False):
     index = 4
     for word in sorted_vocab:
         if vocab[word] < config.THRESHOLD:
-            with open('config.py', 'a') as cf:
-                if 'enc' in filename:
-                    cf.write('ENC_VOCAB = ' + str(index) + '\n')
-                else:
-                    cf.write('DEC_VOCAB = ' + str(index) + '\n')
             break
         f.write(word + '\n')
         index += 1
+    with open('config.py', 'a') as cf:
+        if filename[-3:] == 'enc':
+            cf.write('ENC_VOCAB = ' + str(index) + '\n')
+        else:
+            cf.write('DEC_VOCAB = ' + str(index) + '\n')

 def load_vocab(vocab_path):
     with open(vocab_path, 'r') as f:
         words = f.read().splitlines()
     return words, {words[i]: i for i in range(len(words))}

 def sentence2id(vocab, line):
-    return [vocab.get(token, vocab['<unk>']) for token in line]
+    return [vocab.get(token, vocab['<unk>']) for token in basic_tokenizer(line)]

 def token2id(data, mode):
     """ Convert all the tokens in the data into their corresponding
     index in the vocabulary. """
     vocab_path = 'vocab.' + mode
-    in_path = data + '.' + mode + '.tok'
-    out_path = data + '.' + mode + '.ids'
+    in_path = data + '.' + mode
+    out_path = data + '_ids.' + mode

     _, vocab = load_vocab(os.path.join(config.PROCESSED_PATH, vocab_path))
     in_file = open(os.path.join(config.PROCESSED_PATH, in_path), 'r')
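
For reference, a quick sanity sketch of the new file naming in token2id: inputs are now read straight from '<split>.<mode>' (the '.tok' intermediates are gone) and the id files are written as '<split>_ids.<mode>', which is exactly what chatbot.py's _get_buckets() loads after this commit. ids_filename is a hypothetical helper used only for this check.

def ids_filename(split, mode):
    return split + '_ids.' + mode          # mirrors out_path in token2id above

assert ids_filename('train', 'enc') == 'train_ids.enc'   # matches data.load_data in chatbot.py
assert ids_filename('test', 'dec') == 'test_ids.dec'
print('id filenames line up with chatbot.py')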
@@ -233,8 +187,8 @@ def prepare_raw_data():

 def process_data():
     print('Preparing data to be model-ready ...')
-    build_vocab('train.enc.tok')
-    build_vocab('train.dec.tok')
+    build_vocab('train.enc')
+    build_vocab('train.dec')
     token2id('train', 'enc')
     token2id('train', 'dec')
     token2id('test', 'enc')
@@ -304,5 +258,4 @@ def get_batch(data_bucket, bucket_id, batch_size=1):

 if __name__ == '__main__':
     prepare_raw_data()
-    tokenize_data()
     process_data()
