1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # @Time : 17-4-27 下午8:43
4
+ # @Author : Tianyu Liu
5
+
6
+ import tensorflow as tf
7
+ import time
8
+ import numpy as np
9
+ import re , time , os
10
+
11
+
12
class DataLoader(object):
    """Load id-encoded table-to-text data and serve it as zero-padded batches.

    One data split consists of five parallel text files (one sample per line,
    space-separated integer ids): the summary, the table values ("text"), the
    field labels, the forward positions and the reverse positions.

    Attributes:
        train_data_path / test_data_path / dev_data_path: lists with the five
            file paths of each split, in the order expected by ``load_data``.
        limits: keep only the first ``limits`` samples of a split when > 0.
        man_text_len: maximum encoder length used by ``batch_iter`` (the
            attribute name is kept as-is for backward compatibility).
        test_set: the loaded test split (see ``load_data`` for its layout).
    """

    # File suffixes of one split: summary, text, field, pos, rpos.
    _SPLIT_SUFFIXES = ('.summary.id', '.box.val.id', '.box.lab.id',
                       '.box.pos', '.box.rpos')
    # Ids shared with the special tokens registered in _Vocab.
    _PAD_ID = 0
    _END_ID = 2
    _UNK_ID = 3
    # Positions parsed from "<field>_<n>" prefixes are clipped to this value.
    _MAX_FIELD_POS = 30
    # Matches the trailing "_<n>" index of a field prefix such as "name_2".
    _FIELD_INDEX = re.compile(r"_[1-9]\d*$")

    class _Vocab(object):
        """Word and field-key vocabularies read from plain-text vocab files.

        Ids 0-3 are reserved for PAD, START_TOKEN, END_TOKEN and UNK_TOKEN;
        entries of the file receive ids 4, 5, ... in file order.
        """

        def __init__(self, word_vocab_path, field_vocab_path):
            self._word2id = self._load(word_vocab_path)
            self._id2word = {idx: word for word, idx in self._word2id.items()}
            self._key2id = self._load(field_vocab_path)
            self._id2key = {idx: key for key, idx in self._key2id.items()}

        @staticmethod
        def _load(path):
            """Return a token -> id dict built from the first column of *path*."""
            mapping = {'PAD': 0, 'START_TOKEN': 1, 'END_TOKEN': 2, 'UNK_TOKEN': 3}
            next_id = 4
            with open(path, 'r') as handle:
                for line in handle:
                    columns = line.split()
                    if not columns:
                        # A blank line carries no token; skip it instead of
                        # failing with IndexError.
                        continue
                    mapping[columns[0]] = next_id
                    next_id += 1
            return mapping

        def word2id(self, word):
            """Return the id of *word*, or the UNK id (3) when unknown."""
            return self._word2id.get(word, 3)

        def id2word(self, idx):
            """Return the word registered for id *idx* (KeyError if absent)."""
            return self._id2word[int(idx)]

        def key2id(self, key):
            """Return the id of field *key*, or the UNK id (3) when unknown."""
            return self._key2id.get(key, 3)

        def id2key(self, idx):
            """Return the field key registered for id *idx* (KeyError if absent)."""
            return self._id2key[int(idx)]

    def __init__(self, data_dir, limits):
        """Record the split paths under *data_dir* and load the test split.

        Args:
            data_dir: root directory containing ``train/``, ``test/`` and
                ``valid/`` sub-directories.
            limits: when > 0, only the first ``limits`` samples are loaded.
        """
        self.train_data_path = self._split_paths(data_dir, 'train')
        self.test_data_path = self._split_paths(data_dir, 'test')
        self.dev_data_path = self._split_paths(data_dir, 'valid')
        self.limits = limits
        self.man_text_len = 100
        start_time = time.time()

        print('Reading datasets ...')
        # Only the test split is loaded at construction time; the train and
        # dev splits can be loaded on demand with
        # self.load_data(self.train_data_path) / self.load_data(self.dev_data_path).
        self.test_set = self.load_data(self.test_data_path)
        print('Reading datasets comsumes %.3f seconds' % (time.time() - start_time))

    @classmethod
    def _split_paths(cls, data_dir, split):
        """Return the five file paths of *split* ('train', 'test', 'valid')."""
        return [data_dir + '/' + split + '/' + split + suffix
                for suffix in cls._SPLIT_SUFFIXES]

    @staticmethod
    def _read_lines(path, limit):
        """Return the lines of *path*, truncated to *limit* entries when > 0."""
        # The context manager closes the handle deterministically.
        with open(path, 'r') as handle:
            content = handle.read().strip()
        lines = content.split('\n') if content else []
        if limit > 0:
            lines = lines[:limit]
        return lines

    @staticmethod
    def _to_ids(lines):
        """Convert whitespace-separated id strings to lists of ints."""
        # split() without an argument tolerates repeated blanks and turns a
        # blank line into an empty sample instead of raising on int('').
        return [[int(token) for token in line.split()] for line in lines]

    def load_data(self, path):
        """Load one split.

        Args:
            path: sequence of five file paths in the order
                (summary, text, field, pos, rpos).

        Returns:
            Tuple ``(summaries, texts, fields, poses, rposes)``; each element
            is a list with one list of ints per sample.

        Raises:
            ValueError: if a token in any file is not an integer.
        """
        summary_path, text_path, field_path, pos_path, rpos_path = path
        summaries = self._read_lines(summary_path, self.limits)
        texts = self._read_lines(text_path, self.limits)
        fields = self._read_lines(field_path, self.limits)
        poses = self._read_lines(pos_path, self.limits)
        rposes = self._read_lines(rpos_path, self.limits)

        # Debug output of the first sample, guarded so an empty file does not
        # raise IndexError.
        if summaries:
            print(summaries[0].strip().split(' '))
        if texts:
            print('................texts................. %s' % texts[0])

        return (self._to_ids(summaries), self._to_ids(texts),
                self._to_ids(fields), self._to_ids(poses),
                self._to_ids(rposes))

    @classmethod
    def _parse_box_line(cls, line):
        """Split one tab-separated "field_n:word" line into words, labels, positions."""
        words, labels, positions = [], [], []
        for item in line.split('\t'):
            parts = item.split(':')
            if len(parts) != 2:
                # Skip items with extra colons (as before) and also items
                # with no colon, which previously raised ValueError.
                continue
            prefix, word = parts
            if '<none>' in word or word.strip() == '' or prefix.strip() == '':
                continue
            label = cls._FIELD_INDEX.sub("", prefix)
            if label.strip() == "":
                continue
            words.append(word)
            labels.append(label)
            if cls._FIELD_INDEX.search(prefix):
                field_pos = int(prefix.split('_')[-1])
                positions.append(min(field_pos, cls._MAX_FIELD_POS))
            else:
                # A prefix without a numeric suffix is a single-token field.
                positions.append(1)
        return words, labels, positions

    @staticmethod
    def _reverse_positions(positions):
        """Return *positions* with every run that starts at 1 reversed.

        Example: [1, 2, 3, 1, 1, 2] -> [3, 2, 1, 1, 2, 1].
        """
        reversed_pos = []
        run = []
        for position in positions:
            if int(position) == 1 and run:
                # A new field starts: flush the previous run back to front.
                reversed_pos.extend(run[::-1])
                run = []
            run.append(position)
        reversed_pos.extend(run[::-1])
        return reversed_pos

    def single_test(self, box_path="original_data/test.box",
                    word_vocab_path="original_data/word_vocab.txt",
                    field_vocab_path="original_data/field_vocab.txt"):
        """Yield a single encoder-only batch built from the first box record.

        Args:
            box_path: file with one tab-separated "field_n:word" record per
                line; only the first line is used.
            word_vocab_path: word vocabulary file (token in the first column).
            field_vocab_path: field vocabulary file (token in the first column).

        Yields:
            One dict with the same keys as the batches of ``batch_iter``; the
            ``enc_*`` lists hold exactly one sample, the ``dec_*`` lists stay
            empty.
        """
        with open(box_path, 'r') as handle:
            first_record = handle.read().strip().split('\n')[0]
        print(first_record)

        words, labels, pos = self._parse_box_line(first_record)
        rpos = self._reverse_positions(pos)

        vocab = self._Vocab(word_vocab_path, field_vocab_path)
        text = [vocab.word2id(word) for word in words]
        print(text)
        field = [vocab.key2id(label) for label in labels]
        print(field)
        print(pos)
        print(rpos)

        batch_data = {'enc_in': [], 'enc_fd': [], 'enc_pos': [], 'enc_rpos': [], 'enc_len': [],
                      'dec_in': [], 'dec_len': [], 'dec_out': []}
        batch_data['enc_in'].append(text)
        batch_data['enc_len'].append(len(text))
        batch_data['enc_fd'].append(field)
        batch_data['enc_pos'].append(pos)
        batch_data['enc_rpos'].append(rpos)

        yield batch_data

    @classmethod
    def _pad_to(cls, seq, width):
        """Return *seq* as a list zero-padded or truncated to exactly *width*."""
        padded = list(seq) + [cls._PAD_ID] * (width - len(seq))
        return padded[:width]

    def batch_iter(self, data, batch_size, shuffle):
        """Iterate over *data* in zero-padded batches.

        Args:
            data: tuple ``(summaries, texts, fields, poses, rposes)`` as
                returned by ``load_data``.
            batch_size: number of samples per batch (the last batch may be
                smaller).
            shuffle: when true, samples are visited in a random order drawn
                from ``np.random``.

        Yields:
            Dict of lists with keys ``enc_in``, ``enc_fd``, ``enc_pos``,
            ``enc_rpos`` (padded to the longest text of the batch, capped at
            ``self.man_text_len``), ``enc_len``, ``dec_in`` (padded summary),
            ``dec_len`` and ``dec_out`` (summary followed by the end id 2,
            then padding).

        Raises:
            AssertionError: if text, field, pos and rpos of a sample differ
                in length.
        """
        summaries, texts, fields, poses, rposes = data

        data_size = len(summaries)
        # Ceiling division: one extra batch for a trailing partial batch.
        num_batches = (data_size + batch_size - 1) // batch_size

        if shuffle:
            order = np.random.permutation(np.arange(data_size))
            # Reorder the Python lists directly.  Converting ragged lists to
            # np.array fails on recent NumPy, and equally long samples would
            # become a 2-D int array on which "+ [2]" adds instead of appends.
            summaries = [summaries[i] for i in order]
            texts = [texts[i] for i in order]
            fields = [fields[i] for i in order]
            poses = [poses[i] for i in order]
            rposes = [rposes[i] for i in order]

        for batch_num in range(num_batches):
            start_index = batch_num * batch_size
            end_index = min(start_index + batch_size, data_size)
            batch_summaries = summaries[start_index:end_index]
            batch_texts = texts[start_index:end_index]
            max_summary_len = max(len(sample) for sample in batch_summaries)
            max_text_len = max(len(sample) for sample in batch_texts)
            # Encoder inputs are padded to the longest text in the batch but
            # never beyond man_text_len.
            enc_width = min(max_text_len, self.man_text_len)
            batch_data = {'enc_in': [], 'enc_fd': [], 'enc_pos': [], 'enc_rpos': [], 'enc_len': [],
                          'dec_in': [], 'dec_len': [], 'dec_out': []}

            for summary, text, field, pos, rpos in zip(
                    batch_summaries, batch_texts,
                    fields[start_index:end_index],
                    poses[start_index:end_index],
                    rposes[start_index:end_index]):
                summary_len = len(summary)
                text_len = len(text)
                assert text_len == len(field)
                assert len(pos) == len(field)
                assert len(rpos) == len(pos)

                summary_padding = [self._PAD_ID] * (max_summary_len - summary_len)
                gold = list(summary) + [self._END_ID] + summary_padding
                dec_in = list(summary) + summary_padding

                batch_data['enc_in'].append(self._pad_to(text, enc_width))
                batch_data['enc_len'].append(min(text_len, enc_width))
                batch_data['enc_fd'].append(self._pad_to(field, enc_width))
                batch_data['enc_pos'].append(self._pad_to(pos, enc_width))
                batch_data['enc_rpos'].append(self._pad_to(rpos, enc_width))
                batch_data['dec_in'].append(dec_in)
                batch_data['dec_len'].append(summary_len)
                batch_data['dec_out'].append(gold)
            yield batch_data
0 commit comments