-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.py
28 lines (22 loc) · 1.32 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from torchtext import data
from torchtext import datasets
from torchtext.vocab import GloVe
class Dataloader:
    """Load the SST-2 TSV splits and expose torchtext fields, datasets and iterators.

    Attributes set by the constructor:
        TEXT, LABEL: the torchtext fields used for the ``text`` and ``label``
            columns.
        train, dev, test: the datasets read from ``SST-2/{train,dev,test}.tsv``.
        train_iter, dev_iter, test_iter: bucket iterators over those datasets.
    """

    def __init__(self, name, length, dimension, vocab_size, batch_sizes):
        """Build fields, datasets, vocabularies and iterators.

        Args:
            name: Dataset identifier. Only ``'SST'`` is supported.
            length: Passed to the text field as ``fix_length``.
            dimension: Passed to ``GloVe(name='6B', dim=...)`` for the
                pretrained vectors attached to the text vocabulary.
            vocab_size: Passed to ``TEXT.build_vocab`` as ``max_size``.
            batch_sizes: Passed straight to ``BucketIterator.splits``;
                presumably one batch size per split in (train, dev, test)
                order -- confirm against the caller.

        Raises:
            ValueError: If ``name`` is not a supported dataset.
        """
        # Fail fast: without this guard an unknown name would leave
        # self.train/dev/test undefined and surface much later as an
        # AttributeError (or as a silently half-initialised object).
        if name != 'SST':
            raise ValueError(
                "unsupported dataset name: {!r} (only 'SST' is supported)".format(name)
            )

        self.TEXT = data.Field(lower=True, fix_length=length, batch_first=True)
        # NOTE(review): a non-sequential Field still uses the default unk token,
        # so the label vocabulary presumably contains an extra '<unk>' entry --
        # confirm that downstream code accounts for the resulting label ids.
        self.LABEL = data.Field(sequential=False)

        # SST-2: the TSV files are read from the relative directory 'SST-2'.
        self.train, self.dev, self.test = data.TabularDataset.splits(
            path='SST-2',
            train='train.tsv',
            validation='dev.tsv',
            test='test.tsv',
            format='tsv',
            skip_header=True,
            fields=[('text', self.TEXT), ('label', self.LABEL)],
        )
        print("the size of train: {}, dev:{}, test:{}".format(len(self.train.examples), len(self.dev.examples), len(self.test.examples)))

        # Both vocabularies are built from the training split only.
        self.TEXT.build_vocab(
            self.train,
            vectors=GloVe(name='6B', dim=dimension),
            max_size=vocab_size,
        )
        self.LABEL.build_vocab(self.train)
        print("train.fields:", self.train.fields, self.TEXT.vocab.vectors.shape)

        self.train_iter, self.dev_iter, self.test_iter = data.BucketIterator.splits(
            (self.train, self.dev, self.test),
            batch_sizes=batch_sizes,
            sort_key=lambda x: len(x.text),
            sort_within_batch=True,
            repeat=False,
        )
        # repeat=False is already requested above; it is re-applied explicitly
        # to every iterator (previously dev_iter was left out) so all three
        # splits are treated the same way.
        for iterator in (self.train_iter, self.dev_iter, self.test_iter):
            iterator.repeat = False