
Commit 65ecddf (initial commit, 0 parents)
Author: maru0kun
Message: first commit

File tree: 7 files changed, +10487 −0 lines


README.md

+31
@@ -0,0 +1,31 @@
# Effective Approaches to Attention-based Neural Machine Translation

A PyTorch implementation of an encoder-decoder model with a global attention mechanism.

## Model Details
- LSTM-based encoder-decoder model
- global attention (see Figure 2 in the original paper)
- scheduled sampling (sketched below)
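This repository exposes scheduled sampling through a fixed `--tf-ratio` (see `options.py` and the decoding loop in `model.py` below). The referenced paper instead decays the teacher-forcing probability as training progresses; the inverse-sigmoid schedule below is an illustrative sketch of that idea, not code from this commit:

```python
import math
import random

def tf_ratio_at(step, k=200.0):
    # inverse-sigmoid decay: close to 1.0 early in training, -> 0 as step grows
    return k / (k + math.exp(step / k))

# at each decoder step: feed the gold token with this probability,
# otherwise feed back the model's own previous prediction
use_tf = random.random() < tf_ratio_at(step=1000)
```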
## Usage
Training:
```bash
python train.py \
    --gpu \
    --train ./sample_data/sample_train.tsv \
    --valid ./sample_data/sample_valid.tsv \
    --tf-ratio 0.5 \
    --savedir ./checkpoints
```

Translation:
```bash
python translate.py \
    --gpu \
    --model ./checkpoints/checkpoint_best.pt \
    --input ./sample_data/sample_test.txt
```

## References
- [Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/pdf/1508.04025.pdf)
- [Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks](https://arxiv.org/pdf/1506.03099.pdf)

model.py

+130
@@ -0,0 +1,130 @@
```python
# -*- coding: utf-8 -*-

import random

import torch
import torch.nn as nn
import torch.nn.functional as F


class EncRNN(nn.Module):
    def __init__(self, vsz, embed_dim, hidden_dim, n_layers, use_birnn, dout):
        super(EncRNN, self).__init__()
        self.embed = nn.Embedding(vsz, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, n_layers,
                           bidirectional=use_birnn)
        self.dropout = nn.Dropout(dout)

    def forward(self, inputs):
        # inputs: (src_len, batch) -> embeddings: (src_len, batch, embed_dim)
        embs = self.dropout(self.embed(inputs))
        enc_outs, hidden = self.rnn(embs)
        return self.dropout(enc_outs), hidden


class Attention(nn.Module):
    """Global attention (Luong et al., 2015): dot, general, or concat score."""

    def __init__(self, hidden_dim, method):
        super(Attention, self).__init__()
        self.method = method
        self.hidden_dim = hidden_dim

        if method == 'general':
            self.w = nn.Linear(hidden_dim, hidden_dim)
        elif method == 'concat':
            self.w = nn.Linear(hidden_dim*2, hidden_dim)
            self.v = torch.nn.Parameter(torch.FloatTensor(hidden_dim))

    def forward(self, dec_out, enc_outs):
        if self.method == 'dot':
            attn_energies = self.dot(dec_out, enc_outs)
        elif self.method == 'general':
            attn_energies = self.general(dec_out, enc_outs)
        elif self.method == 'concat':
            attn_energies = self.concat(dec_out, enc_outs)
        # normalize scores over the source-time dimension
        return F.softmax(attn_energies, dim=0)

    def dot(self, dec_out, enc_outs):
        return torch.sum(dec_out*enc_outs, dim=2)

    def general(self, dec_out, enc_outs):
        energy = self.w(enc_outs)
        return torch.sum(dec_out*energy, dim=2)

    def concat(self, dec_out, enc_outs):
        dec_out = dec_out.expand(enc_outs.shape[0], -1, -1)
        energy = torch.cat((dec_out, enc_outs), 2)
        return torch.sum(self.v * self.w(energy).tanh(), dim=2)
```
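A minimal shape check for the `Attention` module above, under assumed toy dimensions; the time-major `(seq_len, batch, hidden)` layout mirrors how `DecRNN` calls `self.attn`:

```python
import torch

from model import Attention

slen, bsz, hdim = 7, 4, 16               # assumed toy sizes
enc_outs = torch.randn(slen, bsz, hdim)  # encoder states, time-major
dec_out = torch.randn(1, bsz, hdim)      # a single decoder step

attn = Attention(hdim, method='dot')
weights = attn(dec_out, enc_outs)         # (slen, bsz)
print(weights.shape, weights.sum(dim=0))  # each column sums to 1
```

model.py, continued: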
```python
class DecRNN(nn.Module):
    def __init__(self, vsz, embed_dim, hidden_dim, n_layers, use_birnn,
                 dout, attn, tied):
        super(DecRNN, self).__init__()
        # a bidirectional encoder doubles the state size fed to the decoder
        hidden_dim = hidden_dim*2 if use_birnn else hidden_dim

        self.embed = nn.Embedding(vsz, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, n_layers)

        self.w = nn.Linear(hidden_dim*2, hidden_dim)
        self.attn = Attention(hidden_dim, attn)

        self.out_projection = nn.Linear(hidden_dim, vsz)
        if tied:
            if embed_dim != hidden_dim:
                raise ValueError(
                    f"when using the tied flag, embed-dim:{embed_dim} "
                    f"must be equal to hidden-dim:{hidden_dim}")
            self.out_projection.weight = self.embed.weight
        self.dropout = nn.Dropout(dout)

    def forward(self, inputs, hidden, enc_outs):
        inputs = inputs.unsqueeze(0)  # a single decoding step: (1, batch)
        embs = self.dropout(self.embed(inputs))
        dec_out, hidden = self.rnn(embs, hidden)

        attn_weights = self.attn(dec_out, enc_outs).transpose(1, 0)
        enc_outs = enc_outs.transpose(1, 0)
        # context vector: attention-weighted sum of encoder states
        context = torch.bmm(attn_weights.unsqueeze(1), enc_outs)
        cats = self.w(torch.cat((dec_out, context.transpose(1, 0)), dim=2))
        pred = self.out_projection(cats.tanh().squeeze(0))
        return pred, hidden


class Seq2seqAttn(nn.Module):
    def __init__(self, args, fields, device):
        super().__init__()
        self.src_field, self.tgt_field = fields
        self.src_vsz = len(self.src_field[1].vocab.itos)
        self.tgt_vsz = len(self.tgt_field[1].vocab.itos)
        self.encoder = EncRNN(self.src_vsz, args.embed_dim, args.hidden_dim,
                              args.n_layers, args.bidirectional, args.dropout)
        self.decoder = DecRNN(self.tgt_vsz, args.embed_dim, args.hidden_dim,
                              args.n_layers, args.bidirectional, args.dropout,
                              args.attn, args.tied)
        self.device = device
        self.n_layers = args.n_layers
        self.hidden_dim = args.hidden_dim
        self.use_birnn = args.bidirectional

    def forward(self, srcs, tgts=None, maxlen=100, tf_ratio=0.0):
        slen, bsz = srcs.size()
        tlen = tgts.size(0) if isinstance(tgts, torch.Tensor) else maxlen
        tf_ratio = tf_ratio if isinstance(tgts, torch.Tensor) else 0.0

        enc_outs, hidden = self.encoder(srcs)

        dec_inputs = torch.ones_like(srcs[0]) * 2  # <eos> is mapped to id=2
        outs = []

        if self.use_birnn:
            # merge the forward/backward encoder states into one decoder state
            def trans_hidden(hs):
                hs = hs.view(self.n_layers, 2, bsz, self.hidden_dim)
                hs = torch.stack([torch.cat((h[0], h[1]), 1) for h in hs])
                return hs
            hidden = tuple(trans_hidden(hs) for hs in hidden)

        for i in range(tlen):
            preds, hidden = self.decoder(dec_inputs, hidden, enc_outs)
            outs.append(preds)
            # scheduled sampling: feed the gold token with probability
            # tf_ratio, otherwise feed back the model's own prediction
            use_tf = random.random() < tf_ratio
            dec_inputs = tgts[i] if use_tf else preds.max(1)[1]
        return torch.stack(outs)
```
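A hedged end-to-end smoke test for `Seq2seqAttn`; the `args` namespace and the `(name, field)` stand-ins are assumptions that only mimic what the training script presumably builds with torchtext:

```python
from argparse import Namespace

import torch

from model import Seq2seqAttn

class _Vocab:  # stand-in for a torchtext vocab
    itos = ['<unk>', '<pad>', '<eos>'] + [f'w{i}' for i in range(97)]

fields = [('src', Namespace(vocab=_Vocab())),
          ('tgt', Namespace(vocab=_Vocab()))]
args = Namespace(embed_dim=32, hidden_dim=32, n_layers=1,
                 bidirectional=False, dropout=0.0, attn='dot', tied=False)

model = Seq2seqAttn(args, fields, device='cpu')
srcs = torch.randint(3, 100, (9, 4))    # (src_len, batch)
tgts = torch.randint(3, 100, (8, 4))    # (tgt_len, batch)
outs = model(srcs, tgts, tf_ratio=0.5)  # (tgt_len, batch, tgt_vocab)
print(outs.shape)                       # torch.Size([8, 4, 100])
```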

options.py

+71
@@ -0,0 +1,71 @@
```python
# -*- coding: utf-8 -*-


def train_opts(parser):
    group = parser.add_argument_group('Training')
    group.add_argument('--train', default='./sample_data/sample_train.tsv',
                       help='path to the training data')
    group.add_argument('--valid', default='./sample_data/sample_valid.tsv',
                       help='path to the validation data')
    group.add_argument('--batch-size', type=int, default=32,
                       help='batch size')
    group.add_argument('--savedir', default='./checkpoints',
                       help='directory to save model checkpoints')
    group.add_argument('--max-epoch', type=int, default=0,
                       help='maximum number of epochs')
    group.add_argument('--max-update', type=int, default=0,
                       help='maximum number of updates')
    group.add_argument('--lr', type=float, default=0.25,
                       help='learning rate')
    group.add_argument('--min-lr', type=float, default=1e-5,
                       help='minimum learning rate')
    group.add_argument('--clip', type=float, default=0.1,
                       help='gradient clipping threshold')
    group.add_argument('--tf-ratio', type=float, default=0.5,
                       help='teacher forcing ratio')
    group.add_argument('--gpu', action='store_true',
                       help='use GPU')
    return group


def translate_opts(parser):
    group = parser.add_argument_group('Translation')
    group.add_argument('--model', default='./checkpoints/checkpoint_best.pt',
                       help='model file for translation')
    group.add_argument('--input', default='./sample_data/sample_test.txt',
                       help='input file')
    group.add_argument('--batch-size', type=int, default=32,
                       help='batch size')
    group.add_argument('--maxlen', type=int, default=100,
                       help='maximum length of an output sentence')
    group.add_argument('--gpu', action='store_true',
                       help='use GPU')
    return group


def model_opts(parser):
    group = parser.add_argument_group("Model's hyper-parameters")
    group.add_argument('--embed-dim', type=int, default=200,
                       help='dimension of word embeddings')
    group.add_argument('--src-min-freq', type=int, default=0,
                       help='''map source-side words appearing less than
                       this many times to unknown''')
    group.add_argument('--tgt-min-freq', type=int, default=0,
                       help='''map target-side words appearing less than
                       this many times to unknown''')
    group.add_argument('--rnn', choices=['lstm'], default='lstm',
                       help='RNN architecture')
    group.add_argument('--hidden-dim', type=int, default=1024,
                       help='number of hidden units per layer')
    group.add_argument('--n-layers', type=int, default=2,
                       help='number of LSTM layers')
    group.add_argument('--bidirectional', action='store_true',
                       help='use a bidirectional LSTM for the encoder')
    group.add_argument('--attn', choices=['dot', 'general', 'concat'],
                       default='dot', help='attention type')
    group.add_argument('--dropout', type=float, default=0.2,
                       help='dropout applied to layers (0 means no dropout)')
    group.add_argument('--tied', action='store_true',
                       help='tie the word embedding and softmax weights')
    return group
```
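train.py and translate.py are among the seven changed files but are not shown on this page; the wiring below is therefore an assumption about how these option groups are combined, not the actual contents of either script:

```python
import argparse

import options

parser = argparse.ArgumentParser(description='attention-based NMT training')
options.train_opts(parser)   # data paths, lr, --tf-ratio, --gpu, ...
options.model_opts(parser)   # --embed-dim, --hidden-dim, --attn, ...
args = parser.parse_args()

print(args.train, args.lr, args.attn)
```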
