Commit 96c6033

0.1.5 + minor fixes (#106)
* Modified parameter order of DecoderRNN.forward (#85)
* Updated TopKDecoder (#86)
* Fixed topk decoder.
* Use torchtext from PyPI (#87)
* Use torchtext from PyPI.
* Fixed torchtext sorting order.
* Attention is not required when only using teacher forcing in decoder (#90)
* Attention is not required when only using teacher forcing in decoder.
* Updated docs and version.
* Fixed code style.
* Bugfix (#92): fixed field arguments validation.
* Removed `initial_lr` when resuming optimizer with scheduler. (#95)
* Shuffle the training data (#97)
* 0.1.5 (#91)
* Modified parameter order of DecoderRNN.forward (#85)
* Updated TopKDecoder (#86)
* Fixed topk decoder.
* Use torchtext from PyPI (#87)
* Use torchtext from PyPI.
* Fixed torchtext sorting order.
* Attention is not required when only using teacher forcing in decoder (#90)
* Attention is not required when only using teacher forcing in decoder.
* Updated docs and version.
* Fixed code style.
* Shuffle the training data.
* Fix example of inflate function in TopKDecoder.py (#98)
* Fix example of inflate function in TopKDecoder.py.
* Fix hidden_layer size for one-directional decoder (#99)
* Fix hidden_layer size for one-directional decoder: the hidden layer size of the decoder was given as `hidden_size * 2 if bidirectional else 1`, resulting in a dimensionality error for non-bidirectional decoders. Changed `1` to `hidden_size`.
* Adapt load to allow CPU loading of GPU models (#100)
* Adapt load to allow CPU loading of GPU models: add a storage parameter to torch.load to allow loading models on a CPU that were trained on the GPU, depending on the availability of CUDA.
* Fix wrong parameter use on DecoderRNN (#103)
* Fix wrong parameter use on DecoderRNN.
1 parent: e8250fb

File tree

7 files changed: +20 −15 lines

* examples/sample.py
* seq2seq/dataset/fields.py
* seq2seq/models/DecoderRNN.py
* seq2seq/models/TopKDecoder.py
* seq2seq/models/attention.py
* seq2seq/trainer/supervised_trainer.py
* seq2seq/util/checkpoint.py

examples/sample.py

Lines changed: 1 addition & 1 deletion

@@ -100,7 +100,7 @@ def len_filter(example):
     bidirectional = True
     encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
                          bidirectional=bidirectional, variable_lengths=True)
-    decoder = DecoderRNN(len(tgt.vocab), max_len, hidden_size * 2 if bidirectional else 1,
+    decoder = DecoderRNN(len(tgt.vocab), max_len, hidden_size * 2 if bidirectional else hidden_size,
                          dropout_p=0.2, use_attention=True, bidirectional=bidirectional,
                          eos_id=tgt.eos_id, sos_id=tgt.sos_id)
     seq2seq = Seq2seq(encoder, decoder)

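Why this fix matters: a bidirectional encoder emits outputs of size 2 * hidden_size, while a unidirectional one emits hidden_size, so passing 1 for the non-bidirectional case broke the decoder's dimensions. A minimal sketch, not taken from the repository (the GRU stand-in and layer sizes are illustrative assumptions):

import torch
import torch.nn as nn

hidden_size = 128
bidirectional = False

# Stand-in for the encoder RNN: output feature size is directions * hidden_size.
encoder_rnn = nn.GRU(input_size=64, hidden_size=hidden_size,
                     batch_first=True, bidirectional=bidirectional)
out, _ = encoder_rnn(torch.zeros(2, 10, 64))  # (batch, seq, directions * hidden)

# The decoder attends over `out`, so its hidden size must match out.size(-1).
decoder_hidden_size = hidden_size * 2 if bidirectional else hidden_size
assert out.size(-1) == decoder_hidden_size  # would fail if the else-branch were 1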
seq2seq/dataset/fields.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ def __init__(self, **kwargs):
         if kwargs.get('batch_first') is False:
             logger.warning("Option batch_first has to be set to use pytorch-seq2seq. Changed to True.")
         kwargs['batch_first'] = True
-        if kwargs.get('batch_first') is False:
+        if kwargs.get('include_lengths') is False:
             logger.warning("Option include_lengths has to be set to use pytorch-seq2seq. Changed to True.")
         kwargs['include_lengths'] = True

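This corrects a copy-paste bug: the second check re-tested `batch_first` instead of `include_lengths`, so the `include_lengths` warning could never fire for the right key. A standalone sketch of the corrected pattern (the function name `validate_field_kwargs` is hypothetical; in the repository this logic lives in the field's `__init__`):

import logging

logger = logging.getLogger(__name__)

def validate_field_kwargs(**kwargs):
    # Each option is checked against its own key before being forced on.
    if kwargs.get('batch_first') is False:
        logger.warning("Option batch_first has to be set to use pytorch-seq2seq. Changed to True.")
    kwargs['batch_first'] = True
    if kwargs.get('include_lengths') is False:
        logger.warning("Option include_lengths has to be set to use pytorch-seq2seq. Changed to True.")
    kwargs['include_lengths'] = True
    return kwargs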
seq2seq/models/DecoderRNN.py

Lines changed: 1 addition & 1 deletion

@@ -131,7 +131,7 @@ def decode(step, step_output, step_attn):
            eos_batches = symbols.data.eq(self.eos_id)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
-               update_idx = ((lengths > di) & eos_batches) != 0
+               update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = len(sequence_symbols)
            return symbols

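Context for this fix: inside the `decode(step, step_output, step_attn)` callback the current decoding step is `step`; `di` referred to a variable from an enclosing loop and pointed at the wrong index. A hedged sketch of the length-update idea as a standalone helper (the helper name and the `decoded_so_far` argument are illustrative, not from the repository):

import torch

def update_lengths(lengths, symbols, step, eos_id, decoded_so_far):
    # lengths: numpy int array of per-sequence output lengths, initialized to max_len.
    eos_batches = symbols.data.eq(eos_id)
    if eos_batches.dim() > 0:
        eos_batches = eos_batches.cpu().view(-1).numpy()
        # Only shorten sequences that had not already finished before `step`.
        update_idx = ((lengths > step) & eos_batches) != 0
        lengths[update_idx] = decoded_so_far
    return lengths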
seq2seq/models/TopKDecoder.py

Lines changed: 6 additions & 7 deletions

@@ -9,7 +9,7 @@ def _inflate(tensor, times, dim):
     Args:
         tensor: A :class:`Tensor` to inflate
         times: number of repetitions
-        dimension: axis for inflation (default=0)
+        dim: axis for inflation (default=0)

     Returns:
         A :class:`Tensor`
@@ -20,17 +20,16 @@ def _inflate(tensor, times, dim):
     1 2
     3 4
     [torch.LongTensor of size 2x2]
-    >> decoder = TopKDecoder(nn.RNN(10, 20, 2), 3)
-    >> b = decoder._inflate(a, 1, dimension=1)
+    >> b = ._inflate(a, 2, dim=1)
     >> b
-    1 1 2 2
-    3 3 4 4
+    1 2 1 2
+    3 4 3 4
     [torch.LongTensor of size 2x4]
-    >> c = decoder._inflate(a, 1, dimension=0)
+    >> c = _inflate(a, 2, dim=0)
     >> c
     1 2
-    1 2
     3 4
+    1 2
     3 4
     [torch.LongTensor of size 4x2]


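A sketch of an `_inflate` consistent with the corrected docstring examples (whole-tensor tiling via `Tensor.repeat`); this is an assumption inferred from the expected outputs, not copied from the file:

import torch

def _inflate(tensor, times, dim):
    # Repeat the whole tensor `times` times along `dim`, leaving other dims alone.
    repeat_dims = [1] * tensor.dim()
    repeat_dims[dim] = times
    return tensor.repeat(*repeat_dims)

a = torch.LongTensor([[1, 2], [3, 4]])
print(_inflate(a, 2, dim=1))  # rows become [1, 2, 1, 2] and [3, 4, 3, 4]
print(_inflate(a, 2, dim=0))  # stacks a full copy: [[1, 2], [3, 4], [1, 2], [3, 4]]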
seq2seq/models/attention.py

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ class Attention(nn.Module):
     .. math::
         \begin{array}{ll}
         x = context*output \\
-        attn = exp(x_i - max_i x_i) / sum_j exp(x_j - max_i x_i) \\
+        attn = exp(x_i) / sum_j exp(x_j) \\
         output = \tanh(w * (attn * context) + b * output)
         \end{array}

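The two softmax forms are mathematically identical (subtracting the max is only a numerical-stability trick), so the docstring change simplifies the notation without changing behavior. A hedged sketch of the attention the formula describes; shapes, the concat-then-linear combination, and tensor names are assumptions for illustration:

import torch
import torch.nn.functional as F

batch, out_len, in_len, dim = 2, 3, 5, 8
output = torch.randn(batch, out_len, dim)    # decoder states
context = torch.randn(batch, in_len, dim)    # encoder outputs

scores = torch.bmm(output, context.transpose(1, 2))  # x = context * output
attn = F.softmax(scores, dim=-1)                      # attn = exp(x_i) / sum_j exp(x_j)
mix = torch.bmm(attn, context)                        # attn * context

linear_out = torch.nn.Linear(dim * 2, dim)
combined = torch.cat((mix, output), dim=2)
result = torch.tanh(linear_out(combined))             # tanh(w * [mix; output] + b)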
seq2seq/trainer/supervised_trainer.py

Lines changed: 3 additions & 1 deletion

@@ -75,7 +75,8 @@ def _train_epoches(self, data, model, n_epochs, start_epoch, start_step,
        device = None if torch.cuda.is_available() else -1
        batch_iterator = torchtext.data.BucketIterator(
            dataset=data, batch_size=self.batch_size,
-           sort=True, sort_key=lambda x: len(x.src),
+           sort=False, sort_within_batch=True,
+           sort_key=lambda x: len(x.src),
            device=device, repeat=False)

        steps_per_epoch = len(batch_iterator)
@@ -166,6 +167,7 @@ def train(self, model, data, num_epochs=5,
            resume_optim = self.optimizer.optimizer
            defaults = resume_optim.param_groups[0]
            defaults.pop('params', None)
+           defaults.pop('initial_lr', None)
            self.optimizer.optimizer = resume_optim.__class__(model.parameters(), **defaults)

            start_epoch = resume_checkpoint.epoch

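Two behaviors change here: `sort=False` with `sort_within_batch=True` shuffles the training data between epochs while still sorting each batch by source length (which packed padded sequences require), and `initial_lr` is dropped when rebuilding the optimizer on resume. The sketch below reproduces the `initial_lr` failure mode; the model and optimizer are illustrative assumptions:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
torch.optim.lr_scheduler.StepLR(optim, step_size=10)  # scheduler adds 'initial_lr' to param_groups

defaults = dict(optim.param_groups[0])
defaults.pop('params', None)
defaults.pop('initial_lr', None)  # without this, SGD(**defaults) raises an unexpected-keyword TypeError
resumed = optim.__class__(model.parameters(), **defaults)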
seq2seq/util/checkpoint.py

Lines changed: 7 additions & 3 deletions

@@ -91,9 +91,13 @@ def load(cls, path):
        Returns:
            checkpoint (Checkpoint): checkpoint object with fields copied from those stored on disk
        """
-       print("Loading checkpoints from {}".format(path))
-       resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME))
-       model = torch.load(os.path.join(path, cls.MODEL_NAME))
+       if torch.cuda.is_available():
+           resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME))
+           model = torch.load(os.path.join(path, cls.MODEL_NAME))
+       else:
+           resume_checkpoint = torch.load(os.path.join(path, cls.TRAINER_STATE_NAME), map_location=lambda storage, loc: storage)
+           model = torch.load(os.path.join(path, cls.MODEL_NAME), map_location=lambda storage, loc: storage)
+
        model.flatten_parameters()  # make RNN parameters contiguous
        with open(os.path.join(path, cls.INPUT_VOCAB_FILE), 'rb') as fin:
            input_vocab = dill.load(fin)

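The change follows the standard `map_location` pattern: remap GPU storages to CPU when CUDA is unavailable so checkpoints trained on a GPU still load. A minimal sketch outside the Checkpoint class (the file name is an illustrative placeholder):

import torch

if torch.cuda.is_available():
    state = torch.load("model_checkpoint.pt")
else:
    # Keep each storage on CPU instead of trying to restore it to its original GPU.
    state = torch.load("model_checkpoint.pt",
                       map_location=lambda storage, loc: storage)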