Pad each batch, not the whole dataset #30
base: master
Conversation
    return train_loader, valid_loader, train_sampler, valid_sampler

    def make_data_lists(args, personachat, tokenizer):

Review comment: add a docstring.
    for utterance in dialog["utterances"]:
    -    history = utterance["history"][-(2*args.max_history+1):]
    +    candidate_instances = defaultdict(list)
    +    history = utterance["history"][-(2 * args.max_history + 1):]

Review comment: could add `assert len(utterance['candidates']) >= num_candidates`.
    return instance, sequence  # TODO: second arg is never used, delete it

    def pad_and_tensorize(batch_dict, padding):

Review comment: this and `ChatDataset` should be easy to unit test.
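The unit test the reviewer suggests could look like the sketch below. The `pad_and_tensorize` body here is a torch-free stand-in (an assumption, not the PR's actual implementation) that pads every list in `batch_dict` to the batch's max length, so the padding behaviour itself can be checked in isolation:

```python
# Stand-in for pad_and_tensorize so the padding logic can be unit
# tested without torch; the real function also converts to tensors.
def pad_and_tensorize(batch_dict, padding):
    padded = {}
    for name, seqs in batch_dict.items():
        max_len = max(len(s) for s in seqs)
        padded[name] = [s + [padding] * (max_len - len(s)) for s in seqs]
    return padded

def test_pads_to_batch_max():
    out = pad_and_tensorize({"input_ids": [[1, 2, 3], [4]]}, padding=0)
    # the short row is padded only to the longest row in this batch
    assert out["input_ids"] == [[1, 2, 3], [4, 0, 0]]

test_pads_to_batch_max()
```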
    valid_dataset = ChatDataset(datasets['valid'], pad_id)

    logger.info("Build train and validation dataloaders")
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None

Review comment: (maybe) put this in `ChatDataset.to_loader(self, args, shuffle) -> sampler, loader`.

Review comment: at some point might also want to document which tensors are 3D.
    for input_name, input_array in instance.items():
    -    datasets[dataset_name][input_name].append(input_array)
    +    candidate_instances[input_name].append(input_array)
    for k in candidate_instances.keys():

Review comment: `.items()` will save some chars.
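The point of the `.items()` comment, in a minimal illustration (the dict contents here are made up):

```python
# Iterating .items() yields each value directly, avoiding a second
# dict lookup per key (and saving some characters).
candidate_instances = {"input_ids": [[1, 2]], "mc_token_ids": [[1]]}

# before: for k in candidate_instances.keys(): ... candidate_instances[k] ...
for name, arrays in candidate_instances.items():
    assert candidate_instances[name] is arrays
```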
train.py (outdated)

    for j, candidate in enumerate(utterance["candidates"][-num_candidates:]):
    -    lm_labels = bool(j == num_candidates-1)
    +    lm_labels = bool(j == num_candidates - 1)
        instance, _ = build_input_from_segments(persona, history, candidate, tokenizer, lm_labels)

Review comment: better varname?
Previously, each sequence was padded to the length of the longest sequence in the dataset.
In this PR, each batch is padded to the length of the longest sequence in the batch. This results in a 30% speedup with negligible impact on metrics.
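The difference between the two strategies can be sketched in plain Python (`pad_to` and the toy dataset are illustrative, not the PR's actual code):

```python
# Hypothetical sketch of dataset-wide vs. per-batch padding.
def pad_to(seqs, length, pad_id=0):
    return [s + [pad_id] * (length - len(s)) for s in seqs]

dataset = [[1, 2, 3, 4, 5, 6, 7, 8], [1, 2], [1, 2, 3]]

# Before: every sequence is padded to the dataset-wide max (8 tokens).
global_max = max(len(s) for s in dataset)
whole = pad_to(dataset, global_max)

# After: each batch is padded only to its own max; a batch holding the
# two short sequences now costs 3 tokens per row instead of 8.
batch = dataset[1:]
batched = pad_to(batch, max(len(s) for s in batch))

assert all(len(s) == 8 for s in whole)
assert all(len(s) == 3 for s in batched)
```

Shorter rows mean fewer wasted pad tokens per forward pass, which is where the speedup comes from.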
Code Changes
- `ChatDataset` yields example dicts like `{'input_ids': [[hist + cand1], ..., [hist + cand_n]], ...}` for the `PADDED_INPUTS`, plus `mc_token_ids` and `mc_labels` in the same format as previously.
- `ChatDataset().collate_fn(examples: list)` turns a list of example dicts into the list of 5 tensors by batching and padding them.
- `get_dataloaders` does much less.
- `convai_evaluation.py` still calls the old `pad_dataset`.

1 Epoch Sanity Check
Before Change: 85 minutes
Validation: {'accuracy': 0.7483655941545956,
'average_accuracy': 0.7483655941545956,
'average_nll': 2.6815188920676687,
'average_ppl': 14.607263311061963,
'nll': 2.6815188920676687}
After Change: 60 minutes
Validation: {'accuracy': 0.7466991411357519,
'average_accuracy': 0.7466991411357519,
'average_nll': 2.6821035040007972,
'average_ppl': 14.615805388160778,
'nll': 2.6821035040007972}
Command: