# -*- coding: utf-8 -*-

import random

import torch
import torch.nn as nn
import torch.nn.functional as F

class EncRNN(nn.Module):
    """Encoder: embeds source tokens and runs them through a (bi)LSTM."""

    def __init__(self, vsz, embed_dim, hidden_dim, n_layers, use_birnn, dout):
        super(EncRNN, self).__init__()
        self.embed = nn.Embedding(vsz, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, n_layers,
                           bidirectional=use_birnn)
        self.dropout = nn.Dropout(dout)

    def forward(self, inputs):
        # inputs: (src_len, batch) token ids
        embs = self.dropout(self.embed(inputs))
        # enc_outs: (src_len, batch, hidden_dim * n_directions)
        enc_outs, hidden = self.rnn(embs)
        return self.dropout(enc_outs), hidden

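# A minimal usage sketch for EncRNN (the sizes below are illustrative
# assumptions, not values fixed by this file):
#
#   enc = EncRNN(vsz=10000, embed_dim=256, hidden_dim=512,
#                n_layers=2, use_birnn=True, dout=0.2)
#   src = torch.randint(0, 10000, (35, 64))      # (src_len, batch)
#   enc_outs, hidden = enc(src)
#   # enc_outs: (35, 64, 1024)  -- hidden_dim * 2 because use_birnn=True
#   # hidden:   tuple (h, c), each of shape (n_layers * 2, 64, 512)
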
class Attention(nn.Module):
    """Luong-style attention with 'dot', 'general' and 'concat' scores."""

    def __init__(self, hidden_dim, method):
        super(Attention, self).__init__()
        self.method = method
        self.hidden_dim = hidden_dim

        if method == 'general':
            self.w = nn.Linear(hidden_dim, hidden_dim)
        elif method == 'concat':
            self.w = nn.Linear(hidden_dim*2, hidden_dim)
            # scoring vector for the 'concat' score
            self.v = nn.Parameter(torch.rand(hidden_dim))
        elif method != 'dot':
            raise ValueError(f"unknown attention method: {method}")

    def forward(self, dec_out, enc_outs):
        # dec_out: (1, batch, hidden_dim); enc_outs: (src_len, batch, hidden_dim)
        if self.method == 'dot':
            attn_energies = self.dot(dec_out, enc_outs)
        elif self.method == 'general':
            attn_energies = self.general(dec_out, enc_outs)
        elif self.method == 'concat':
            attn_energies = self.concat(dec_out, enc_outs)
        # attn_energies: (src_len, batch); normalise over source positions
        return F.softmax(attn_energies, dim=0)

    def dot(self, dec_out, enc_outs):
        return torch.sum(dec_out*enc_outs, dim=2)

    def general(self, dec_out, enc_outs):
        energy = self.w(enc_outs)
        return torch.sum(dec_out*energy, dim=2)

    def concat(self, dec_out, enc_outs):
        dec_out = dec_out.expand(enc_outs.shape[0], -1, -1)
        energy = torch.cat((dec_out, enc_outs), 2)
        return torch.sum(self.v * self.w(energy).tanh(), dim=2)

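# For reference, the three methods correspond to the scoring functions of
# Luong et al. (2015), with h_t the decoder state and h_s an encoder state:
#   dot:     score(h_t, h_s) = h_t . h_s
#   general: score(h_t, h_s) = h_t . (W h_s)
#   concat:  score(h_t, h_s) = v . tanh(W [h_t; h_s])
#
# Illustrative use (sizes are assumptions, not fixed by this file):
#   attn = Attention(hidden_dim=1024, method='general')
#   weights = attn(dec_out, enc_outs)   # (src_len, batch), columns sum to 1
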
class DecRNN(nn.Module):
    """Decoder: one LSTM step per call, followed by attention over enc_outs."""

    def __init__(self, vsz, embed_dim, hidden_dim, n_layers, use_birnn,
                 dout, attn, tied):
        super(DecRNN, self).__init__()
        # match the encoder's output size when it is bidirectional
        hidden_dim = hidden_dim*2 if use_birnn else hidden_dim

        self.embed = nn.Embedding(vsz, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, n_layers)

        self.w = nn.Linear(hidden_dim*2, hidden_dim)
        self.attn = Attention(hidden_dim, attn)

        self.out_projection = nn.Linear(hidden_dim, vsz)
        if tied:
            if embed_dim != hidden_dim:
                raise ValueError(
                    f"when using the tied flag, embed-dim:{embed_dim} "
                    f"must be equal to hidden-dim:{hidden_dim}")
            self.out_projection.weight = self.embed.weight
        self.dropout = nn.Dropout(dout)

    def forward(self, inputs, hidden, enc_outs):
        # inputs: (batch,) token ids for a single time step
        inputs = inputs.unsqueeze(0)
        embs = self.dropout(self.embed(inputs))
        dec_out, hidden = self.rnn(embs, hidden)

        # attn_weights: (batch, src_len); enc_outs: (batch, src_len, hidden_dim)
        attn_weights = self.attn(dec_out, enc_outs).transpose(1, 0)
        enc_outs = enc_outs.transpose(1, 0)
        # context: (batch, 1, hidden_dim), attention-weighted encoder states
        context = torch.bmm(attn_weights.unsqueeze(1), enc_outs)
        cats = self.w(torch.cat((dec_out, context.transpose(1, 0)), dim=2))
        pred = self.out_projection(cats.tanh().squeeze(0))
        return pred, hidden

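# A single decoding step, with illustrative (assumed) sizes:
#
#   dec = DecRNN(vsz=8000, embed_dim=1024, hidden_dim=512, n_layers=2,
#                use_birnn=True, dout=0.2, attn='general', tied=False)
#   # dec_inputs: (batch,) previous output tokens
#   # hidden:     encoder state reshaped to (n_layers, batch, hidden_dim*2)
#   # enc_outs:   (src_len, batch, hidden_dim*2)
#   pred, hidden = dec(dec_inputs, hidden, enc_outs)
#   # pred: (batch, vsz) unnormalised scores for the next token
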
class Seq2seqAttn(nn.Module):
    """Sequence-to-sequence model with attention: EncRNN -> DecRNN."""

    def __init__(self, args, fields, device):
        super().__init__()
        self.src_field, self.tgt_field = fields
        self.src_vsz = len(self.src_field[1].vocab.itos)
        self.tgt_vsz = len(self.tgt_field[1].vocab.itos)
        self.encoder = EncRNN(self.src_vsz, args.embed_dim, args.hidden_dim,
                              args.n_layers, args.bidirectional, args.dropout)
        self.decoder = DecRNN(self.tgt_vsz, args.embed_dim, args.hidden_dim,
                              args.n_layers, args.bidirectional, args.dropout,
                              args.attn, args.tied)
        self.device = device
        self.n_layers = args.n_layers
        self.hidden_dim = args.hidden_dim
        self.use_birnn = args.bidirectional

    def forward(self, srcs, tgts=None, maxlen=100, tf_ratio=0.0):
        # srcs: (src_len, batch); tgts: (tgt_len, batch) or None at inference
        slen, bsz = srcs.size()
        tlen = tgts.size(0) if isinstance(tgts, torch.Tensor) else maxlen
        tf_ratio = tf_ratio if isinstance(tgts, torch.Tensor) else 0.0

        enc_outs, hidden = self.encoder(srcs)

        dec_inputs = torch.ones_like(srcs[0]) * 2  # <eos> is mapped to id=2
        outs = []

        if self.use_birnn:
            # merge the forward/backward encoder states so they match the
            # decoder's (n_layers, batch, hidden_dim*2) layout
            def trans_hidden(hs):
                hs = hs.view(self.n_layers, 2, bsz, self.hidden_dim)
                hs = torch.stack([torch.cat((h[0], h[1]), 1) for h in hs])
                return hs
            hidden = tuple(trans_hidden(hs) for hs in hidden)

        for i in range(tlen):
            preds, hidden = self.decoder(dec_inputs, hidden, enc_outs)
            outs.append(preds)
            # teacher forcing: feed the gold token with probability tf_ratio
            use_tf = random.random() < tf_ratio
            dec_inputs = tgts[i] if use_tf else preds.max(1)[1]
        return torch.stack(outs)
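

if __name__ == "__main__":
    # Minimal smoke test, not part of the original training pipeline.
    # The hyperparameters and the fake vocab/field objects below are
    # assumptions made purely for illustration; the real code builds
    # `args` with argparse and `fields` with torchtext.
    from types import SimpleNamespace

    class _FakeVocab:
        itos = ["<unk>", "<pad>", "<eos>"] + [f"w{i}" for i in range(97)]

    class _FakeField:
        vocab = _FakeVocab()

    args = SimpleNamespace(embed_dim=32, hidden_dim=64, n_layers=1,
                           bidirectional=True, dropout=0.1,
                           attn="general", tied=False)
    fields = [("src", _FakeField()), ("tgt", _FakeField())]

    model = Seq2seqAttn(args, fields, device="cpu")
    srcs = torch.randint(0, 100, (7, 3))   # (src_len=7, batch=3)
    tgts = torch.randint(0, 100, (9, 3))   # (tgt_len=9, batch=3)
    outs = model(srcs, tgts, tf_ratio=0.5)
    print(outs.shape)                      # expected: (9, 3, 100)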