Skip to content
This repository was archived by the owner on Nov 1, 2024. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions metaseq/tasks/streaming_language_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"DatasetWithShardInformation", ["dataset", "is_sharded", "shard_id", "num_shards"]
)

TEXT_DATA_EVALSETS = ["llama", "text_eval", "marmot"]
IMAGE_PREFIX = "IMGIMG"


Expand Down Expand Up @@ -679,8 +680,12 @@ def load_dataset(

dataset = torch.utils.data.ConcatDataset(datasets)

break_mode = "complete" if split != "train" else self.args.sample_break_mode
no_image_break = False if split != "train" else self.args.no_image_break
is_text = any([subset in split for subset in TEXT_DATA_EVALSETS])

# chunk into blocks of tokens
if self.has_cm3:
if self.has_cm3 and not is_text:
# We chose not to use compositional inheritance because there's a
# lot of downstream code that has isinstance checks.
# So just to be safe and not change anything we use proper inheritance.
Expand All @@ -695,19 +700,19 @@ def load_dataset(
# We generate blocks with one extra token, so that we have a target
# for the final input token. This results in slight data loss.
block_size=self.args.tokens_per_sample + 1,
break_mode=self.args.sample_break_mode,
break_mode=break_mode,
# we drop the remainder block during training
drop_last=(split == "train"),
padding_idx=self.source_dictionary.pad(),
seed=self.args.seed,
percent_full_document_rotation=self.args.cm3_percent_full_document_rotation,
no_break_image=self.args.no_break_image,
no_break_image=no_image_break,
)
else:
self.datasets[split] = DocumentToSequenceDataset(
dataset,
block_size=self.args.tokens_per_sample + 1,
break_mode=self.args.sample_break_mode,
break_mode=break_mode,
drop_last=(split == "train"),
padding_idx=self.source_dictionary.pad(),
seed=self.args.seed,
Expand Down