From fe97990547362c45fd0a37841015e3b0a3198fa6 Mon Sep 17 00:00:00 2001
From: dinhnhat241103 <21522414@gm.uit.edu.vn>
Date: Sat, 8 Jun 2024 16:07:36 +0700
Subject: [PATCH] Trigger Build

---
 dataset.py |  4 ++--
 train.py   | 30 ++++++++++++++----------------
 2 files changed, 16 insertions(+), 18 deletions(-)

diff --git a/dataset.py b/dataset.py
index 7aa175c..fa85d92 100644
--- a/dataset.py
+++ b/dataset.py
@@ -23,8 +23,8 @@ def __len__(self):
 
     def __getitem__(self, idx):
         src_target_pair = self.ds[idx]
-        src_text = src_target_pair['translation'][self.src_lang]
-        tgt_text = src_target_pair['translation'][self.tgt_lang]
+        src_text = src_target_pair[self.src_lang]
+        tgt_text = src_target_pair[self.tgt_lang]
 
         # Transform the text into tokens
         enc_input_tokens = self.tokenizer_src.encode(src_text).ids
diff --git a/train.py b/train.py
index 6c28318..4a1404b 100644
--- a/train.py
+++ b/train.py
@@ -123,7 +123,7 @@ def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len,
 
 def get_all_sentences(ds, lang):
     for item in ds:
-        yield item['translation'][lang]
+        yield item[lang]
 
 def get_or_build_tokenizer(config, ds, lang):
     tokenizer_path = Path(config['tokenizer_file'].format(lang))
@@ -140,38 +140,36 @@ def get_or_build_tokenizer(config, ds, lang):
 
 def get_ds(config):
     # It only has the train split, so we divide it overselves
-    ds_raw = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='train')
+    train_ds_raw, valid_ds_raw = load_dataset(f"{config['datasource']}", split=['train', 'valid'])
 
     # Build tokenizers
-    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
-    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])
+    train_tokenizer_src = get_or_build_tokenizer(config, train_ds_raw, config['lang_src'])
+    train_tokenizer_tgt = get_or_build_tokenizer(config, train_ds_raw, config['lang_tgt'])
 
-    # Keep 90% for training, 10% for validation
-    train_ds_size = int(0.9 * len(ds_raw))
-    val_ds_size = len(ds_raw) - train_ds_size
-    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])
+    valid_tokenizer_src = get_or_build_tokenizer(config, valid_ds_raw, config['lang_src'])
+    valid_tokenizer_tgt = get_or_build_tokenizer(config, valid_ds_raw, config['lang_tgt'])
 
-    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
-    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
+    train_ds = BilingualDataset(train_ds_raw, train_tokenizer_src, train_tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
+    val_ds = BilingualDataset(valid_ds_raw, valid_tokenizer_src, valid_tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
 
     # Find the maximum length of each sentence in the source and target sentence
     max_len_src = 0
     max_len_tgt = 0
 
-    for item in ds_raw:
-        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
-        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
+    for item in train_ds_raw:
+        src_ids = train_tokenizer_src.encode(item[config['lang_src']]).ids
+        tgt_ids = train_tokenizer_tgt.encode(item[config['lang_tgt']]).ids
         max_len_src = max(max_len_src, len(src_ids))
         max_len_tgt = max(max_len_tgt, len(tgt_ids))
 
-    print(f'Max length of source sentence: {max_len_src}')
-    print(f'Max length of target sentence: {max_len_tgt}')
+    print(f'Max length of train source sentence: {max_len_src}')
+    print(f'Max length of train target sentence: {max_len_tgt}')
 
     train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)
 
-    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
+    return train_dataloader, val_dataloader, train_tokenizer_src, train_tokenizer_tgt
 
 def get_model(config, vocab_src_len, vocab_tgt_len):
     model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'], d_model=config['d_model'])