diff --git a/README.md b/README.md
index f93768c38a..03de115b22 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,8 @@ https://pytorch.org/examples/
 - [PyTorch Module Transformations using fx](./fx/README.md)
 - Distributed PyTorch examples with [Distributed Data Parallel](./distributed/ddp/README.md) and [RPC](./distributed/rpc)
 - [Several examples illustrating the C++ Frontend](cpp)
-- [Image Classification Using Forward-Forward ](./mnist_forward_forward/README.md)
+- [Image Classification Using Forward-Forward](./mnist_forward_forward/README.md)
+- [Language Translation using Transformers](./language_translation/README.md)
diff --git a/language_translation/README.md b/language_translation/README.md
new file mode 100644
index 0000000000..32833fa00c
--- /dev/null
+++ b/language_translation/README.md
@@ -0,0 +1,49 @@
+# Language Translation
+
+This example shows how one might use transformers for language translation. In particular, this implementation is loosely based on the [Attention is All You Need paper](https://arxiv.org/abs/1706.03762).
+
+## Requirements
+
+We need a tokenizer for each of our languages. Torchtext does include a tokenizer for English, but we need more languages than that. We can get these tokenizers via `spacy`:
+
+```bash
+python3 -m spacy download <language>
+python3 -m spacy download en
+python3 -m spacy download de
+```
+
+spaCy supports many languages. For a full accounting of supported languages, please look [here](https://spacy.io/usage/models). This example defaults to translating from German to English.
+
+Torchtext is also required:
+
+```bash
+pip install torchtext
+```
+
+Running these two commands will get you started:
+
+```bash
+pip install -r requirements.txt
+python3 -m spacy download <language>
+```
+
+## Usage
+
+This example has a number of flags that you can set to change its behavior and training. You can see all of them by running:
+
+```bash
+python3 main.py -h
+```
+
+In general, all of the settings have sensible defaults; in particular, the default is to translate from German to English. To *train* the model, you only need the first command below; the second shows how to pick any other language pair:
+
+```bash
+python3 main.py
+python3 main.py --src en --tgt fr # For English-to-French translation
+```
+
+For model inference, you can use this command:
+
+```bash
+python3 main.py --inference --model_path <path_to_model>
+```
+
+After some loading time, this will open an interactive interface where you can type in whatever sentence you want translated.
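For a quick sanity check that the spaCy downloads above worked, you can ask torchtext for the same tokenizers this example uses. A minimal sketch; the sample sentence and its tokenization are illustrative, and the `de` shortcut must have been installed as shown in the README:

```python
from torchtext.data.utils import get_tokenizer

# Wraps spacy.load("de") under the hood, so the "de" package must be installed
de_tokenizer = get_tokenizer("spacy", language="de")
print(de_tokenizer("Ein kleines Haus."))  # e.g. ['Ein', 'kleines', 'Haus', '.']
```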
diff --git a/language_translation/main.py b/language_translation/main.py
new file mode 100644
index 0000000000..2b4fbb94c3
--- /dev/null
+++ b/language_translation/main.py
@@ -0,0 +1,306 @@
+from time import time # Track how long an epoch takes
+import os # Creating and finding files/directories
+import logging # Logging tools
+from datetime import date # Logging the date for model versioning
+
+import torch # For ML
+from tqdm import tqdm # For fancy progress bars
+
+from src.model import Translator # Our model
+from src.data import get_data, create_mask, generate_square_subsequent_mask # Loading data and data preprocessing
+from argparse import ArgumentParser # For args
+
+# Train on the GPU if possible
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Function to generate output sequence using greedy algorithm
+def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol):
+
+    # Move to device
+    src = src.to(DEVICE)
+    src_mask = src_mask.to(DEVICE)
+
+    # Encode input
+    memory = model.encode(src, src_mask)
+
+    # Output will be stored here
+    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
+
+    # For each element in our translation (which could range up to the maximum translation length)
+    for _ in range(max_len-1):
+
+        # Decode the encoded representation of the input
+        memory = memory.to(DEVICE)
+        tgt_mask = (generate_square_subsequent_mask(ys.size(0), DEVICE).type(torch.bool)).to(DEVICE)
+        out = model.decode(ys, memory, tgt_mask)
+
+        # Reshape
+        out = out.transpose(0, 1)
+
+        # Convert to probabilities and take the max of these probabilities
+        prob = model.ff(out[:, -1])
+        _, next_word = torch.max(prob, dim=1)
+        next_word = next_word.item()
+
+        # Append the predicted token to the output sequence so far
+        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
+        if next_word == end_symbol:
+            break
+
+    return ys
+
+# Opens a user interface where users can translate an arbitrary sentence
+def inference(opts):
+
+    # Get training data, tokenizer and vocab
+    # objects as well as any special symbols we added to our dataset
+    _, _, src_vocab, tgt_vocab, src_transform, _, special_symbols = get_data(opts)
+
+    src_vocab_size = len(src_vocab)
+    tgt_vocab_size = len(tgt_vocab)
+
+    # Create model
+    model = Translator(
+        num_encoder_layers=opts.enc_layers,
+        num_decoder_layers=opts.dec_layers,
+        embed_size=opts.embed_size,
+        num_heads=opts.attn_heads,
+        src_vocab_size=src_vocab_size,
+        tgt_vocab_size=tgt_vocab_size,
+        dim_feedforward=opts.dim_feedforward,
+        dropout=opts.dropout
+    ).to(DEVICE)
+
+    # Load in weights
+    model.load_state_dict(torch.load(opts.model_path))
+
+    # Set to inference
+    model.eval()
+
+    # Accept input and keep translating until they quit
+    while True:
+        print("> ", end="")
+
+        sentence = input()
+
+        # Convert to tokens
+        src = src_transform(sentence).view(-1, 1)
+        num_tokens = src.shape[0]
+
+        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
+
+        # Decode
+        tgt_tokens = greedy_decode(
+            model, src, src_mask, max_len=num_tokens+5, start_symbol=special_symbols["<bos>"], end_symbol=special_symbols["<eos>"]
+        ).flatten()
+
+        # Convert to list of tokens
+        output_as_list = list(tgt_tokens.cpu().numpy())
+
+        # Convert tokens to words
+        output_list_words = tgt_vocab.lookup_tokens(output_as_list)
+
+        # Remove special tokens and convert to string
+        translation = " ".join(output_list_words).replace("<bos>", "").replace("<eos>", "")
+
+        print(translation)
+
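`greedy_decode` above simply takes the arg-max token at every step and stops once the end symbol is produced. The same idea, stripped of the transformer (a toy sketch; all names and values here are illustrative):

```python
import torch

def toy_greedy_decode(next_token_logits, bos=2, eos=3, max_len=10):
    # next_token_logits(prefix) -> logits over the vocabulary for the next token
    prefix = [bos]
    for _ in range(max_len - 1):
        next_token = int(torch.argmax(next_token_logits(prefix)))
        prefix.append(next_token)
        if next_token == eos:  # stop once the end-of-sentence symbol appears
            break
    return prefix

# With a dummy "model" that always predicts token 3 (<eos>):
print(toy_greedy_decode(lambda prefix: torch.tensor([0.0, 0.0, 0.0, 1.0])))  # [2, 3]
```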
+# Train the model for 1 epoch
+def train(model, train_dl, loss_fn, optim, special_symbols, opts):
+
+    # Object for accumulating losses
+    losses = 0
+
+    # Put model into training mode
+    model.train()
+    for src, tgt in tqdm(train_dl, ascii=True):
+
+        src = src.to(DEVICE)
+        tgt = tgt.to(DEVICE)
+
+        # We need to reshape the input slightly to fit into the transformer
+        tgt_input = tgt[:-1, :]
+
+        # Create masks
+        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, special_symbols["<pad>"], DEVICE)
+
+        # Pass into model, get probability over the vocab out
+        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
+
+        # Reset gradients before we try to compute the gradients over the loss
+        optim.zero_grad()
+
+        # Get original shape back
+        tgt_out = tgt[1:, :]
+
+        # Compute loss and gradient over that loss
+        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
+        loss.backward()
+
+        # Step weights
+        optim.step()
+
+        # Accumulate a running loss for reporting
+        losses += loss.item()
+
+        if opts.dry_run:
+            break
+
+    # Return the average loss
+    return losses / len(list(train_dl))
+
+# Check the model accuracy on the validation dataset
+def validate(model, valid_dl, loss_fn, special_symbols):
+
+    # Object for accumulating losses
+    losses = 0
+
+    # Put the model into evaluation mode
+    model.eval()
+
+    for src, tgt in tqdm(valid_dl):
+
+        src = src.to(DEVICE)
+        tgt = tgt.to(DEVICE)
+
+        # We need to reshape the input slightly to fit into the transformer
+        tgt_input = tgt[:-1, :]
+
+        # Create masks
+        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, special_symbols["<pad>"], DEVICE)
+
+        # Pass into model, get probability over the vocab out
+        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
+
+        # Get original shape back, compute loss, accumulate that loss
+        tgt_out = tgt[1:, :]
+        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
+        losses += loss.item()
+
+    # Return the average loss
+    return losses / len(list(valid_dl))
+
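Both loops above shift the target by one position (`tgt[:-1]` in, `tgt[1:]` out): the decoder sees everything up to step t and is trained to predict step t+1 (teacher forcing). With made-up token ids:

```python
import torch

# One sentence as a (seq_len, batch=1) column, the layout this example uses
tgt = torch.tensor([[2], [17], [42], [3]])  # <bos>, "ein", "Haus", <eos>
tgt_input = tgt[:-1, :]  # decoder input:      <bos>, "ein", "Haus"
tgt_out = tgt[1:, :]     # prediction targets: "ein", "Haus", <eos>
```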
starting training!") + + # Set up our learning tools + loss_fn = torch.nn.CrossEntropyLoss(ignore_index=special_symbols[""]) + + # These special values are from the "Attention is all you need" paper + optim = torch.optim.Adam(model.parameters(), lr=opts.lr, betas=(0.9, 0.98), eps=1e-9) + + best_val_loss = 1e6 + + for idx, epoch in enumerate(range(1, opts.epochs+1)): + + start_time = time() + train_loss = train(model, train_dl, loss_fn, optim, special_symbols, opts) + epoch_time = time() - start_time + val_loss = validate(model, valid_dl, loss_fn, special_symbols) + + # Once training is done, we want to save out the model + if val_loss < best_val_loss: + best_val_loss = val_loss + logging.info("New best model, saving...") + torch.save(model.state_dict(), opts.logging_dir + "best.pt") + + torch.save(model.state_dict(), opts.logging_dir + "last.pt") + + logger.info(f"Epoch: {epoch}\n\tTrain loss: {train_loss:.3f}\n\tVal loss: {val_loss:.3f}\n\tEpoch time = {epoch_time:.1f} seconds\n\tETA = {epoch_time*(opts.epochs-idx-1):.1f} seconds") + +if __name__ == "__main__": + + parser = ArgumentParser( + prog="Machine Translator training and inference", + ) + + # Inference mode + parser.add_argument("--inference", action="store_true", + help="Set true to run inference") + parser.add_argument("--model_path", type=str, + help="Path to the model to run inference on") + + # Translation settings + parser.add_argument("--src", type=str, default="de", + help="Source language (translating FROM this language)") + parser.add_argument("--tgt", type=str, default="en", + help="Target language (translating TO this language)") + + # Training settings + parser.add_argument("-e", "--epochs", type=int, default=30, + help="Epochs") + parser.add_argument("--lr", type=float, default=1e-4, + help="Default learning rate") + parser.add_argument("--batch", type=int, default=128, + help="Batch size") + parser.add_argument("--backend", type=str, default="cpu", + help="Batch size") + + # Transformer settings + parser.add_argument("--attn_heads", type=int, default=8, + help="Number of attention heads") + parser.add_argument("--enc_layers", type=int, default=5, + help="Number of encoder layers") + parser.add_argument("--dec_layers", type=int, default=5, + help="Number of decoder layers") + parser.add_argument("--embed_size", type=int, default=512, + help="Size of the language embedding") + parser.add_argument("--dim_feedforward", type=int, default=512, + help="Feedforward dimensionality") + parser.add_argument("--dropout", type=float, default=0.1, + help="Transformer dropout") + + # Logging settings + parser.add_argument("--logging_dir", type=str, default="./" + str(date.today()) + "/", + help="Where the output of this program should be placed") + + # Just for continuous integration + parser.add_argument("--dry_run", action="store_true") + + args = parser.parse_args() + + DEVICE = torch.device("cuda" if args.backend == "gpu" and torch.cuda.is_available() else "cpu") + + if args.inference: + inference(args) + else: + main(args) diff --git a/language_translation/requirements.txt b/language_translation/requirements.txt new file mode 100644 index 0000000000..0e98d6f3b1 --- /dev/null +++ b/language_translation/requirements.txt @@ -0,0 +1,5 @@ +torch +torchtext +torchdata +spacy +portalocker diff --git a/language_translation/src/data.py b/language_translation/src/data.py new file mode 100644 index 0000000000..c1c4c7f545 --- /dev/null +++ b/language_translation/src/data.py @@ -0,0 +1,134 @@ +import torch +from torch.nn.utils.rnn import 
diff --git a/language_translation/src/data.py b/language_translation/src/data.py
new file mode 100644
index 0000000000..c1c4c7f545
--- /dev/null
+++ b/language_translation/src/data.py
@@ -0,0 +1,134 @@
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
+from torchtext.data.utils import get_tokenizer
+from torchtext.vocab import build_vocab_from_iterator
+from torchtext.datasets import Multi30k, multi30k
+
+# Yields tokenized samples from one side (src or tgt) of the dataset
+def _yield_tokens(iterable_data, tokenizer, src):
+
+    # Iterable data stores the samples as (src, tgt) so this will help us select just one language or the other
+    index = 0 if src else 1
+
+    for data in iterable_data:
+        yield tokenizer(data[index])
+
+# Get data, tokenizer, text transform, vocab objs, etc. Everything we
+# need to start training the model
+def get_data(opts):
+
+    src_lang = opts.src
+    tgt_lang = opts.tgt
+
+    multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
+    multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"
+
+    # Define the special tokens: "unknown", "padding", "beginning of sentence", and "end of sentence"
+    special_symbols = {
+        "<unk>": 0,
+        "<pad>": 1,
+        "<bos>": 2,
+        "<eos>": 3
+    }
+
+    # Get training examples from torchtext (the multi30k dataset)
+    train_iterator = Multi30k(split="train", language_pair=(src_lang, tgt_lang))
+    valid_iterator = Multi30k(split="valid", language_pair=(src_lang, tgt_lang))
+
+    # Grab a tokenizer for these languages
+    src_tokenizer = get_tokenizer("spacy", src_lang)
+    tgt_tokenizer = get_tokenizer("spacy", tgt_lang)
+
+    # Build a vocabulary object for each language
+    # (src=True selects the source side of each (src, tgt) sample, src=False the target side)
+    src_vocab = build_vocab_from_iterator(
+        _yield_tokens(train_iterator, src_tokenizer, True),
+        min_freq=1,
+        specials=list(special_symbols.keys()),
+        special_first=True
+    )
+
+    tgt_vocab = build_vocab_from_iterator(
+        _yield_tokens(train_iterator, tgt_tokenizer, False),
+        min_freq=1,
+        specials=list(special_symbols.keys()),
+        special_first=True
+    )
+
+    src_vocab.set_default_index(special_symbols["<unk>"])
+    tgt_vocab.set_default_index(special_symbols["<unk>"])
+
+    # Helper function to sequentially apply transformations
+    def _seq_transform(*transforms):
+        def func(txt_input):
+            for transform in transforms:
+                txt_input = transform(txt_input)
+            return txt_input
+        return func
+
+    # Function to add BOS/EOS and create tensor for input sequence indices
+    def _tensor_transform(token_ids):
+        return torch.cat(
+            (torch.tensor([special_symbols["<bos>"]]),
+             torch.tensor(token_ids),
+             torch.tensor([special_symbols["<eos>"]]))
+        )
+
+    src_lang_transform = _seq_transform(src_tokenizer, src_vocab, _tensor_transform)
+    tgt_lang_transform = _seq_transform(tgt_tokenizer, tgt_vocab, _tensor_transform)
+
+    # Now we want to convert the torchtext data pipeline to a dataloader. We
+    # will need to collate batches
+    def _collate_fn(batch):
+        src_batch, tgt_batch = [], []
+        for src_sample, tgt_sample in batch:
+            src_batch.append(src_lang_transform(src_sample.rstrip("\n")))
+            tgt_batch.append(tgt_lang_transform(tgt_sample.rstrip("\n")))
+
+        src_batch = pad_sequence(src_batch, padding_value=special_symbols["<pad>"])
+        tgt_batch = pad_sequence(tgt_batch, padding_value=special_symbols["<pad>"])
+        return src_batch, tgt_batch
+
+    # Create the dataloader
+    train_dataloader = DataLoader(train_iterator, batch_size=opts.batch, collate_fn=_collate_fn)
+    valid_dataloader = DataLoader(valid_iterator, batch_size=opts.batch, collate_fn=_collate_fn)
+
+    return train_dataloader, valid_dataloader, src_vocab, tgt_vocab, src_lang_transform, tgt_lang_transform, special_symbols
+
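The helpers that follow build the two kinds of attention masks the transformer needs: a causal mask over the target and padding masks over both sides. `generate_square_subsequent_mask` is the standard causal mask; for `size=4` it evaluates to:

```python
import torch

size = 4
mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```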
+def generate_square_subsequent_mask(size, device):
+    mask = (torch.triu(torch.ones((size, size), device=device)) == 1).transpose(0, 1)
+    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+    return mask
+
+# Create masks for input into model
+def create_mask(src, tgt, pad_idx, device):
+
+    # Get sequence length
+    src_seq_len = src.shape[0]
+    tgt_seq_len = tgt.shape[0]
+
+    # Generate the mask
+    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device)
+    src_mask = torch.zeros((src_seq_len, src_seq_len), device=device).type(torch.bool)
+
+    # Mark the padding positions in the original input
+    src_padding_mask = (src == pad_idx).transpose(0, 1)
+    tgt_padding_mask = (tgt == pad_idx).transpose(0, 1)
+    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
+
+# A small test to make sure our data loads in correctly
+if __name__ == "__main__":
+
+    class Opts:
+        def __init__(self):
+            self.src = "en"
+            self.tgt = "de"
+            self.batch = 128
+
+    opts = Opts()
+
+    train_dl, valid_dl, src_vocab, tgt_vocab, src_lang_transform, tgt_lang_transform, special_symbols = get_data(opts)
+
+    print(f"{opts.src} vocab size: {len(src_vocab)}")
+    print(f"{opts.tgt} vocab size: {len(tgt_vocab)}")
+
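`src/model.py` below starts with the sinusoidal positional encoding from the paper. This condensed version shows the table it precomputes (tiny sizes, purely for illustration):

```python
import math
import torch

emb_size, maxlen = 8, 5
den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
pos = torch.arange(0, maxlen).reshape(maxlen, 1)
pe = torch.zeros(maxlen, emb_size)
pe[:, 0::2] = torch.sin(pos * den)  # even columns: sine
pe[:, 1::2] = torch.cos(pos * den)  # odd columns: cosine
print(pe.shape)  # torch.Size([5, 8]) -- one row of features per position
```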
diff --git a/language_translation/src/model.py b/language_translation/src/model.py
new file mode 100644
index 0000000000..ec4a28ba1b
--- /dev/null
+++ b/language_translation/src/model.py
@@ -0,0 +1,98 @@
+import math
+
+import torch
+from torch.nn import functional as F
+from torch import nn
+
+class PositionalEncoding(nn.Module):
+    def __init__(
+        self,
+        emb_size,
+        dropout,
+        maxlen=5000
+    ):
+        super(PositionalEncoding, self).__init__()
+        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
+        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
+        pos_embedding = torch.zeros((maxlen, emb_size))
+        pos_embedding[:, 0::2] = torch.sin(pos * den)
+        pos_embedding[:, 1::2] = torch.cos(pos * den)
+        pos_embedding = pos_embedding.unsqueeze(-2)
+
+        self.dropout = nn.Dropout(dropout)
+        self.register_buffer('pos_embedding', pos_embedding)
+
+    def forward(self, token_embedding):
+        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
+
+class Translator(nn.Module):
+    def __init__(
+        self,
+        num_encoder_layers,
+        num_decoder_layers,
+        embed_size,
+        num_heads,
+        src_vocab_size,
+        tgt_vocab_size,
+        dim_feedforward,
+        dropout
+    ):
+        super(Translator, self).__init__()
+
+        # Source and target embeddings must produce vectors of the same size (embed_size)
+        self.src_embedding = nn.Embedding(src_vocab_size, embed_size)
+        self.tgt_embedding = nn.Embedding(tgt_vocab_size, embed_size)
+
+        self.pos_enc = PositionalEncoding(embed_size, dropout)
+
+        self.transformer = nn.Transformer(
+            d_model=embed_size,
+            nhead=num_heads,
+            num_encoder_layers=num_encoder_layers,
+            num_decoder_layers=num_decoder_layers,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout
+        )
+
+        self.ff = nn.Linear(embed_size, tgt_vocab_size)
+
+        self._init_weights()
+
+    def _init_weights(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(self, src, trg, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
+
+        src_emb = self.pos_enc(self.src_embedding(src))
+        tgt_emb = self.pos_enc(self.tgt_embedding(trg))
+
+        outs = self.transformer(
+            src_emb,
+            tgt_emb,
+            src_mask,
+            tgt_mask,
+            None,
+            src_padding_mask,
+            tgt_padding_mask,
+            memory_key_padding_mask
+        )
+
+        return self.ff(outs)
+
+    def encode(self, src, src_mask):
+
+        embed = self.src_embedding(src)
+
+        pos_enc = self.pos_enc(embed)
+
+        return self.transformer.encoder(pos_enc, src_mask)
+
+    def decode(self, tgt, memory, tgt_mask):
+
+        embed = self.tgt_embedding(tgt)
+
+        pos_enc = self.pos_enc(embed)
+
+        return self.transformer.decoder(pos_enc, memory, tgt_mask)
diff --git a/run_python_examples.sh b/run_python_examples.sh
index a9ff393e80..c5665def13 100755
--- a/run_python_examples.sh
+++ b/run_python_examples.sh
@@ -13,6 +13,11 @@ BASE_DIR=`pwd`"/"`dirname $0`
 EXAMPLES=`echo $1 | sed -e 's/ //g'`
 
+# Redirect 'python' calls to 'python3'
+python() {
+  command python3 "$@"
+}
+
 USE_CUDA=$(python -c "import torchvision, torch; print(torch.cuda.is_available())")
 case $USE_CUDA in
   "True")
@@ -91,6 +96,13 @@ function imagenet() {
   python main.py --epochs 1 sample/ || error "imagenet example failed"
 }
 
+function language_translation() {
+  start
+  python -m spacy download en || error "couldn't download en package from spacy"
+  python -m spacy download de || error "couldn't download de package from spacy"
+  python main.py -e 1 --enc_layers 1 --dec_layers 1 --backend cpu --logging_dir output/ --dry_run || error "language translation example failed"
+}
+
 function mnist() {
   start
   python main.py --epochs 1 --dry-run || error "mnist example failed"
@@ -195,6 +207,7 @@ function clean() {
     imagenet/lsun/ \
     imagenet/model_best.pth.tar \
     imagenet/sample/ \
+    language_translation/output/ \
     snli/.data/ \
     snli/.vector_cache/ \
     snli/results/ \
@@ -215,6 +228,7 @@ function run_all() {
   distributed
   fast_neural_style
   imagenet
+  language_translation
   mnist
   mnist_forward_forward
   mnist_hogwild
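To smoke-test the `Translator` module in isolation, one might run something like this from the `language_translation` directory (a sketch with arbitrary sizes, not part of the PR):

```python
import torch
from src.model import Translator

model = Translator(num_encoder_layers=1, num_decoder_layers=1, embed_size=16,
                   num_heads=2, src_vocab_size=100, tgt_vocab_size=100,
                   dim_feedforward=32, dropout=0.1)

src = torch.randint(0, 100, (7, 2))  # (src_seq_len, batch)
tgt = torch.randint(0, 100, (5, 2))  # (tgt_seq_len, batch)
src_mask = torch.zeros(7, 7).type(torch.bool)
tgt_mask = torch.zeros(5, 5).type(torch.bool)  # all-False: no masking, fine for a shape check

logits = model(src, tgt, src_mask, tgt_mask, None, None, None)
print(logits.shape)  # torch.Size([5, 2, 100]) -- (tgt_seq_len, batch, tgt_vocab_size)
```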