Commit 44141c9

new version
1 parent 77cbf92 commit 44141c9

23 files changed (+616, -199 lines)

config/grid_search_cnn.ini

+9
@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+
+[COMMON]
+model = lstm;basic_cnn;kim_cnn;multi_cnn;inception_cnn;fasttext;rcnn;bilstm
+keep_dropout=0.8;0.9
+batch_size=64;32;128
+learning_rate=0.01;0.001
+optimizer = adam;rmsprop
+dataset = imdb
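
The semicolon-separated values above are candidate lists, one grid dimension per option; main.py reads them through utils.parse_grid_parameters and sweeps every combination. A minimal sketch of what such a parser could look like (the real helper lives in utils.py and may differ; the numeric coercion and the COMMON section name are assumptions here):

# Hypothetical sketch of utils.parse_grid_parameters; the actual helper may differ.
# It turns "keep_dropout=0.8;0.9" into {"keep_dropout": [0.8, 0.9]} so that
# itertools.product can enumerate all hyper-parameter combinations.
from collections import OrderedDict
from configparser import ConfigParser

def parse_grid_parameters(path, section="COMMON"):
    config = ConfigParser()
    config.read(path)
    pools = OrderedDict()
    for key, raw in config.items(section):
        values = []
        for token in raw.split(";"):
            token = token.strip()
            try:
                # coerce numeric candidates such as batch_size or learning_rate
                values.append(int(token) if token.isdigit() else float(token))
            except ValueError:
                values.append(token)  # keep strings such as model names
        pools[key] = values
    return pools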

config/imdb.ini

+1 -1
@@ -1,3 +1,3 @@
 [COMMON]
-dataset = imdb
+dataset = imdb;sst
 

dataHelper.py

-1
@@ -9,7 +9,6 @@
 import random
 import time
 from utils import log_time_delta
-from tqdm import tqdm
 from dataloader import Dataset
 import torch
 from torch.autograd import Variable

main.py

+132 -67
@@ -4,80 +4,145 @@
 from __future__ import division
 from __future__ import print_function
 
-import torch
-from torch.autograd import Variable
-import torch.optim as optim
 import numpy as np
-
+import pandas as pd
 from six.moves import cPickle
+import time,os,random
+import itertools
 
-import opts
-import models
+import torch
+from torch.autograd import Variable
+import torch.optim as optim
 import torch.nn as nn
-import utils
 import torch.nn.functional as F
-from torchtext import data
-from torchtext import datasets
-from torchtext.vocab import Vectors, GloVe, CharNGram, FastText
 from torch.nn.modules.loss import NLLLoss,MultiLabelSoftMarginLoss,MultiLabelMarginLoss,BCELoss
-import dataHelper
-import time,os
-
-
-from_torchtext = False
-
-opt = opts.parse_opt()
-#opt.proxy="http://xxxx.xxxx.com:8080"
-
-
-if "CUDA_VISIBLE_DEVICES" not in os.environ.keys():
-    os.environ["CUDA_VISIBLE_DEVICES"] =opt.gpu
-#opt.model ='lstm'
-#opt.model ='capsule'
 
-if from_torchtext:
-    train_iter, test_iter = utils.loadData(opt)
-else:
-    import dataHelper as helper
-    train_iter, test_iter = dataHelper.loadData(opt)
-
-opt.lstm_layers=2
-
-model=models.setup(opt)
-if torch.cuda.is_available():
-    model.cuda()
-model.train()
-print("# parameters:", sum(param.numel() for param in model.parameters() if param.requires_grad))
-optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.learning_rate)
-optimizer.zero_grad()
-loss_fun = F.cross_entropy
-
-#batch = next(iter(train_iter))
-
-#x=batch.text[0]
-
-#x=batch.text[0] #64x200
-
-#print(utils.evaluation(model,test_iter))
-for i in range(opt.max_epoch):
-    for epoch,batch in enumerate(train_iter):
-        start= time.time()
-
-        text = batch.text[0] if from_torchtext else batch.text
-        predicted = model(text)
+import opts
+import models
+import utils
 
-        loss= loss_fun(predicted,batch.label)
 
-        loss.backward()
-        utils.clip_gradient(optimizer, opt.grad_clip)
-        optimizer.step()
-        if epoch% 100==0:
-            if torch.cuda.is_available():
-                print("%d iteration %d epoch with loss : %.5f in %.4f seconds" % (i,epoch,loss.cpu().item(),time.time()-start))
-            else:
-                print("%d iteration %d epoch with loss : %.5f in %.4f seconds" % (i,epoch,loss.data.numpy()[0],time.time()-start))
+timeStamp = time.strftime("%Y%m%d%H%M%S", time.localtime(int(time.time()) ))
+performance_log_file = os.path.join("log","result"+timeStamp+ ".csv")
+if not os.path.exists(performance_log_file):
+    with open(performance_log_file,"w") as f:
+        f.write("argument\n")
+        f.close()
+
+
+def train(opt,train_iter, test_iter,verbose=True):
+    global_start= time.time()
+    logger = utils.getLogger()
+    model=models.setup(opt)
+    if torch.cuda.is_available():
+        model.cuda()
+    params = [param for param in model.parameters() if param.requires_grad] #filter(lambda p: p.requires_grad, model.parameters())
+
+    model_info =";".join( [str(k)+":"+ str(v) for k,v in opt.__dict__.items() if type(v) in (str,int,float,list,bool)])
+    logger.info("# parameters:" + str(sum(param.numel() for param in params)))
+    logger.info(model_info)
+
+
+    model.train()
+    optimizer = utils.getOptimizer(params,name=opt.optimizer, lr=opt.learning_rate,scheduler= utils.get_lr_scheduler(opt.lr_scheduler))
+    optimizer.zero_grad()
+    loss_fun = F.cross_entropy
+
+    filename = None
+    percisions=[]
+    for i in range(opt.max_epoch):
+        for epoch,batch in enumerate(train_iter):
+            start= time.time()
+
+            text = batch.text[0] if opt.from_torchtext else batch.text
+            predicted = model(text)
+
+            loss= loss_fun(predicted,batch.label)
+
+            loss.backward()
+            utils.clip_gradient(optimizer, opt.grad_clip)
+            optimizer.step()
+
+            if verbose:
+                if torch.cuda.is_available():
+                    logger.info("%d iteration %d epoch with loss : %.5f in %.4f seconds" % (i,epoch,loss.cpu().data.numpy(),time.time()-start))
+                else:
+                    logger.info("%d iteration %d epoch with loss : %.5f in %.4f seconds" % (i,epoch,loss.data.numpy()[0],time.time()-start))
 
-    percision=utils.evaluation(model,test_iter,from_torchtext)
-    print("%d iteration with percision %.4f" % (i,percision))
-
-
+        percision=utils.evaluation(model,test_iter,opt.from_torchtext)
+        if verbose:
+            logger.info("%d iteration with percision %.4f" % (i,percision))
+        if len(percisions)==0 or percision > max(percisions):
+            if filename:
+                os.remove(filename)
+            filename = model.save(metric=percision)
+        percisions.append(percision)
+
+#    while(utils.is_writeable(performance_log_file)):
+    df = pd.read_csv(performance_log_file,index_col=0,sep="\t")
+    df.loc[model_info,opt.dataset] = max(percisions)
+    df.to_csv(performance_log_file,sep="\t")
+    logger.info(model_info +" with time :"+ str( time.time()-global_start)+" ->" +str( max(percisions) ) )
+    print(model_info +" with time :"+ str( time.time()-global_start)+" ->" +str( max(percisions) ) )
+
+def main():
+    from_torchtext = False
+    if "CUDA_VISIBLE_DEVICES" not in os.environ.keys():
+        os.environ["CUDA_VISIBLE_DEVICES"] =opt.gpu
+    #opt.model ='lstm'
+    #opt.model ='capsule'
+    if from_torchtext:
+        train_iter, test_iter = utils.loadData(opt)
+    else:
+        import dataHelper
+        train_iter, test_iter = dataHelper.loadData(opt)
+
+    model=models.setup(opt)
+    print(opt.model)
+    if torch.cuda.is_available():
+        model.cuda()
+
+
+
+    train(opt,train_iter, test_iter)
+
+if __name__=="__main__":
+    parameter_pools = utils.parse_grid_parameters("config/grid_search_cnn.ini")
+
+#    parameter_pools={
+#        "model":["lstm","cnn","fasttext"],
+#        "keep_dropout":[0.8,0.9,1.0],
+#        "batch_size":[32,64,128],
+#        "learning_rate":[100,10,1,1e-1,1e-2,1e-3],
+#        "optimizer":["adam"],
+#        "lr_scheduler":[None]
+#        }
+    opt = opts.parse_opt()
+    if "CUDA_VISIBLE_DEVICES" not in os.environ.keys():
+        os.environ["CUDA_VISIBLE_DEVICES"] =opt.gpu
+    train_iter, test_iter = utils.loadData(opt)
+#    if from_torchtext:
+#        train_iter, test_iter = utils.loadData(opt)
+#    else:
+#        import dataHelper
+#        train_iter, test_iter = dataHelper.loadData(opt)
+    if False:
+        model=models.setup(opt)
+        print(opt.model)
+        if torch.cuda.is_available():
+            model.cuda()
+        train(opt,train_iter, test_iter)
+    else:
+
+        pool =[ arg for arg in itertools.product(*parameter_pools.values())]
+        random.shuffle(pool)
+        args=[arg for i,arg in enumerate(pool) if i%opt.gpu_num==opt.gpu]
+
+        for arg in args:
+            olddataset = opt.dataset
+            for k,v in zip(parameter_pools.keys(),arg):
+                opt.__setattr__(k,v)
+            if "dataset" in parameter_pools and olddataset != opt.dataset:
+                train_iter, test_iter = utils.loadData(opt)
+            train(opt,train_iter, test_iter,verbose=False)
+
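
The rewritten training loop no longer hard-codes optim.Adam; it asks utils.getOptimizer for whatever optimizer the grid config selected (adam or rmsprop above) and optionally attaches a learning-rate scheduler. A rough sketch of such a dispatcher, assuming the real helper in utils.py may use a different signature or support more optimizers:

# Hypothetical sketch of utils.getOptimizer; illustrative only.
import torch.optim as optim

def getOptimizer(params, name="adam", lr=1e-3, scheduler=None):
    # map the optimizer name coming from the grid config to a torch.optim class
    name = name.lower().strip()
    if name == "adam":
        optimizer = optim.Adam(params, lr=lr)
    elif name == "rmsprop":
        optimizer = optim.RMSprop(params, lr=lr)
    elif name == "sgd":
        optimizer = optim.SGD(params, lr=lr, momentum=0.9)
    else:
        raise ValueError("unknown optimizer: %s" % name)
    if scheduler is not None:
        # scheduler is assumed to be a factory, e.g. returned by utils.get_lr_scheduler,
        # that wraps the freshly built optimizer
        scheduler(optimizer)
    return optimizer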

models/BERTFast.py

+11 -5
@@ -3,16 +3,18 @@
 import numpy as np
 from torch import nn
 from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
-
-class BERTFast(nn.Module):
+from models.BaseModel import BaseModel
+class BERTFast(BaseModel):
     def __init__(self, opt ):
-        super(BERTFast, self).__init__()
+        super(BERTFast, self).__init__(opt)
         self.model_name = 'bert'
         self.opt=opt
 
         self.fc = nn.Linear(768, opt.label_size)
 
-        self.bert_model = BertModel.from_pretrained('bert-base-uncased')
+        self.bert_model = BertModel.from_pretrained('bert-base-uncased')
+        for param in self.bert_model.parameters():
+            param.requires_grad=self.opt.bert_trained
         self.content_fc = nn.Sequential(
             nn.Linear(768,100),
             nn.BatchNorm1d(100),
@@ -22,12 +24,16 @@ def __init__(self, opt ):
 #            nn.ReLU(inplace=True),
             nn.Linear(100,opt.label_size)
         )
+        self.hidden2label = nn.Linear(768, opt.label_size)
+        self.properties.update(
+                {"bert_trained":self.opt.bert_trained
+                })
 
 
     def forward(self, content):
         encoded, _ = self.bert_model(content)
         encoded_doc = t.mean(encoded[-1],dim=1)
-        logits = self.content_fc(encoded_doc)
+        logits = self.hidden2label(encoded_doc)
         return logits
 
 import argparse
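
The new bert_trained flag decides whether the BERT encoder is fine-tuned: with bert_trained=False only the added hidden2label head receives gradients. A small self-contained illustration of the same freezing pattern (names and numbers here are illustrative, not part of the repository):

# Freeze a pretrained BERT encoder and train only a classification head;
# this mirrors the loop over self.bert_model.parameters() in BERTFast.
import torch
from pytorch_pretrained_bert import BertModel

bert = BertModel.from_pretrained('bert-base-uncased')
bert_trained = False                      # corresponds to opt.bert_trained
for param in bert.parameters():
    param.requires_grad = bert_trained

head = torch.nn.Linear(768, 2)            # stands in for self.hidden2label

# only the head's parameters remain trainable
trainable = [p for p in list(bert.parameters()) + list(head.parameters()) if p.requires_grad]
print(sum(p.numel() for p in trainable))  # 768*2 + 2 = 1538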

models/BaseModel.py

+62
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+
+import torch as t
+
+import numpy as np
+from torch import nn
+from collections import OrderedDict
+import os
+class BaseModel(nn.Module):
+    def __init__(self, opt ):
+        super(BaseModel, self).__init__()
+        self.model_name = 'BaseModel'
+        self.opt=opt
+
+        self.encoder = nn.Embedding(opt.vocab_size,opt.embedding_dim)
+        if opt.__dict__.get("embeddings",None) is not None:
+            self.encoder.weight=nn.Parameter(opt.embeddings,requires_grad=opt.embedding_training)
+        self.fc = nn.Linear(opt.embedding_dim, opt.label_size)
+
+        self.properties = {"model_name":self.__class__.__name__,
+                "embedding_dim":self.opt.embedding_dim,
+                "embedding_training":self.opt.embedding_training,
+                "max_seq_len":self.opt.max_seq_len,
+                "batch_size":self.opt.batch_size,
+                "learning_rate":self.opt.learning_rate,
+                "keep_dropout":self.opt.keep_dropout,
+                }
+
+    def forward(self,content):
+        content_=t.mean(self.encoder(content),dim=1)
+        out=self.fc(content_.view(content_.size(0),-1))
+        return out
+
+
+
+    def save(self,save_dir="saved_model",metric=None):
+        if not os.path.exists(save_dir):
+            os.mkdir(save_dir)
+        self.model_info = "__".join([k+"_"+str(v) if type(v)!=list else k+"_"+str(v)[1:-1].replace(",","_").replace(",","") for k,v in self.properties.items() ])
+        if metric:
+            path = os.path.join(save_dir, str(metric) +"__"+ self.model_info)
+        else:
+            path = os.path.join(save_dir,self.model_info)
+        t.save(self,path)
+        return path
+
+
+
+if __name__ == '__main__':
+    import sys
+    sys.path.append(r"..")
+    import opts
+    opt=opts.parse_opt()
+    opt.vocab_size=2501
+    opt.embedding_dim=300
+    opt.label_size=3
+    m = BaseModel(opt)
+
+    content = t.autograd.Variable(t.arange(0,2500).view(10,250)).long()
+    o = m(content)
+    print(o.size())
+    path = m.save()
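
BaseModel.save stores the whole module with t.save(self, path) and encodes the metric plus key hyper-parameters in the file name, which is what lets train() in main.py keep only the best checkpoint per configuration. A minimal sketch of reloading such a checkpoint for evaluation (the saved_model directory and the metric-prefix parsing follow the convention above; treat the snippet as illustrative):

# Reload the best checkpoint written by BaseModel.save(metric=...).
# Requires the model classes to be importable, as with any full-module torch.save.
import glob, os
import torch

checkpoints = glob.glob(os.path.join("saved_model", "*"))
# file names start with "<metric>__", so rank checkpoints by that leading number
best_path = max(checkpoints, key=lambda p: float(os.path.basename(p).split("__")[0]))
model = torch.load(best_path)
model.eval()  # disable dropout/batch-norm updates for evaluation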
