Change #26

Open · wants to merge 1 commit into main

dataset.py (4 changes: 2 additions & 2 deletions)
@@ -23,8 +23,8 @@ def __len__(self):

    def __getitem__(self, idx):
        src_target_pair = self.ds[idx]
-        src_text = src_target_pair['translation'][self.src_lang]
-        tgt_text = src_target_pair['translation'][self.tgt_lang]
+        src_text = src_target_pair[self.src_lang]
+        tgt_text = src_target_pair[self.tgt_lang]

        # Transform the text into tokens
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
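
For context, a minimal sketch of the two record layouts this change switches between. The language keys ('en', 'it') and the sample sentences are placeholders, assuming the new datasource exposes each language's text as a top-level field instead of nesting it under a 'translation' key (the layout used by Hugging Face translation sets such as opus_books):

# Old layout the previous code expected: texts nested under 'translation'.
old_item = {'translation': {'en': 'Hello world', 'it': 'Ciao mondo'}}
src_text = old_item['translation']['en']

# New layout assumed by this change: one top-level key per language.
new_item = {'en': 'Hello world', 'it': 'Ciao mondo'}
src_text = new_item['en']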
train.py (30 changes: 14 additions & 16 deletions)
@@ -123,7 +123,7 @@ def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len,

def get_all_sentences(ds, lang):
    for item in ds:
-        yield item['translation'][lang]
+        yield item[lang]

def get_or_build_tokenizer(config, ds, lang):
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
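
As an aside, a minimal sketch of how get_all_sentences is typically consumed by the rest of get_or_build_tokenizer, which the diff does not show; the WordLevel model, Whitespace pre-tokenizer, trainer settings, and special tokens are assumptions, not taken from this repository:

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer

# The sentences yielded above (now item[lang] rather than item['translation'][lang])
# are streamed into the trainer, so the record-layout change decides which text
# the vocabulary is built from.
tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
tokenizer.pre_tokenizer = Whitespace()
trainer = WordLevelTrainer(special_tokens=['[UNK]', '[PAD]', '[SOS]', '[EOS]'], min_frequency=2)
tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
tokenizer.save(str(tokenizer_path))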
@@ -140,38 +140,36 @@ def get_or_build_tokenizer(config, ds, lang):

def get_ds(config):
    # It only has the train split, so we divide it overselves
-    ds_raw = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='train')
+    train_ds_raw, valid_ds_raw = load_dataset(f"{config['datasource']}", split=['train', 'valid'])

    # Build tokenizers
-    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
-    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])
+    train_tokenizer_src = get_or_build_tokenizer(config, train_ds_raw, config['lang_src'])
+    train_tokenizer_tgt = get_or_build_tokenizer(config, train_ds_raw, config['lang_tgt'])

-    # Keep 90% for training, 10% for validation
-    train_ds_size = int(0.9 * len(ds_raw))
-    val_ds_size = len(ds_raw) - train_ds_size
-    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])
+    valid_tokenizer_src = get_or_build_tokenizer(config, valid_ds_raw, config['lang_src'])
+    valid_tokenizer_tgt = get_or_build_tokenizer(config, valid_ds_raw, config['lang_tgt'])

-    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
-    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
+    train_ds = BilingualDataset(train_ds_raw, train_tokenizer_src, train_tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
+    val_ds = BilingualDataset(valid_ds_raw, valid_tokenizer_src, valid_tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

    # Find the maximum length of each sentence in the source and target sentence
    max_len_src = 0
    max_len_tgt = 0

-    for item in ds_raw:
-        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
-        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
+    for item in train_ds_raw:
+        src_ids = train_tokenizer_src.encode(item[config['lang_src']]).ids
+        tgt_ids = train_tokenizer_tgt.encode(item[config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

-    print(f'Max length of source sentence: {max_len_src}')
-    print(f'Max length of target sentence: {max_len_tgt}')
+    print(f'Max length of train source sentence: {max_len_src}')
+    print(f'Max length of train target sentence: {max_len_tgt}')


    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

-    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
+    return train_dataloader, val_dataloader, train_tokenizer_src, train_tokenizer_tgt

def get_model(config, vocab_src_len, vocab_tgt_len):
    model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'], d_model=config['d_model'])
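
Finally, a minimal sketch of the split handling get_ds now relies on; 'example/dataset' is a placeholder name, and the 'valid' split name is assumed to exist in the datasource (many Hugging Face datasets call it 'validation' instead):

from datasets import load_dataset

# Passing a list of split names returns one Dataset per name, in the same order,
# replacing the earlier single 'train' load followed by a 90/10 random_split.
train_ds_raw, valid_ds_raw = load_dataset('example/dataset', split=['train', 'valid'])
print(len(train_ds_raw), len(valid_ds_raw))

One design point worth flagging: get_or_build_tokenizer appears to key the tokenizer file only on the language, so if the train call builds and saves a tokenizer first, the later valid call for the same language would presumably reload that file rather than train a separate vocabulary on the validation split.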