
Commit 7902e40

[feature] add dataset class (#775)

* [feature] add dataset class
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b8bdcd4 commit 7902e40

2 files changed: +224 -39 lines

fish_speech/configs/text2semantic_finetune.yaml (+2 -2)
@@ -26,7 +26,7 @@ tokenizer:

 # Dataset Configuration
 train_dataset:
-  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
   proto_files:
     - data/protos
   tokenizer: ${tokenizer}
@@ -36,7 +36,7 @@ train_dataset:
   interactive_prob: 0.7

 val_dataset:
-  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
+  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
   proto_files:
     - data/protos
   tokenizer: ${tokenizer}
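
Note: the `_target_` entries above are resolved when the finetune config is loaded (fish_speech's training entry point uses Hydra-style instantiation), so swapping the class path is all that is needed to switch between the map-style and iterable datasets. A minimal sketch of that mechanism, assuming `hydra.utils.instantiate` is the resolver and ignoring the config's `defaults:` handling:

# Sketch only: shows how a `_target_` entry turns into a constructed dataset.
# The exact loading path inside fish_speech's train script is assumed here,
# not taken from this diff.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("fish_speech/configs/text2semantic_finetune.yaml")
train_ds = instantiate(cfg.train_dataset)
# -> fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset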

fish_speech/datasets/semantic.py (+222 -37)
@@ -13,7 +13,7 @@
 from huggingface_hub import HfApi
 from lightning import LightningDataModule
 from torch.distributed import get_rank, get_world_size, is_initialized
-from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info

 from fish_speech.conversation import (
     CODEBOOK_PAD_TOKEN_ID,
@@ -59,7 +59,7 @@ def split_by_rank_worker(files):
     return files


-class AutoTextSemanticInstructionDataset(IterableDataset):
+class AutoTextSemanticInstructionIterableDataset(IterableDataset):
     """
     Auto Augment Dataset by Speaker

@@ -295,6 +295,214 @@ def augment(self):
         return data


+class AutoTextSemanticInstructionDataset(Dataset):
+    """
+    Auto Augment Dataset by Speaker
+
+    1. Random concatenate multiple sentences from the same speaker to form a longer sentence
+    2. Automatically normalize the text
+
+    For interactive mode, we use the following format (multiple sequences):
+    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
+
+    For non-interactive mode, we use the following format (one long sequence):
+    <s> [INST] text [/INST] ... </s>
+    """
+
+    def __init__(
+        self,
+        proto_files: list[str],
+        seed: int = 42,
+        interactive_prob: float = 0.5,
+        max_length: int = 1024,
+        tokenizer: FishTokenizer = None,
+        use_speaker: bool | float = True,
+        causal: bool = True,
+        num_codebooks: Optional[int] = None,
+        skip_text_prob: float = 0.0,
+    ):
+        """
+        Args:
+            proto_files: proto buf files if using local data
+            seed: random seed
+            interactive_prob: probability to use interactive mode
+            max_length: max length of the text
+            tokenizer: tokenizer
+            use_speaker: include speaker information in the prompt
+            causal: use causal sampling when using local data, disable will lead to random sampling
+            num_codebooks: number of codebooks, if None, it will be automatically detected
+            skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
+        """
+        super().__init__()
+
+        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
+
+        self.seed = seed
+        self.max_length = max_length
+        self.tokenizer = tokenizer
+        self.interactive_prob = interactive_prob
+        self.use_speaker = use_speaker
+        self.proto_files = proto_files
+        self.causal = causal
+        self.num_codebooks = num_codebooks
+        self.skip_text_prob = skip_text_prob
+
+        self.data = []
+        self._init_data()
+
+    def _init_data(self):
+        expanded_proto_files = []
+        for filename in self.proto_files:
+            for i in braceexpand(filename):
+                i = Path(i)
+                if i.is_file():
+                    expanded_proto_files.append(i)
+                elif i.is_dir():
+                    expanded_proto_files.extend(i.rglob("*.proto"))
+                    expanded_proto_files.extend(i.rglob("*.protos"))
+                else:
+                    raise ValueError(f"{i} is not a file or directory")
+
+        expanded_proto_files = sorted(expanded_proto_files)
+        Random(self.seed).shuffle(expanded_proto_files)
+
+        groups = []
+        shard_proto_files = split_by_rank_worker(expanded_proto_files)
+        log.info(
+            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
+        )
+
+        count = 0
+        for filename in shard_proto_files:
+            with open(filename, "rb") as f:
+                for text_data in read_pb_stream(f):
+                    groups.append(text_data)
+                    count += 1
+
+        log.info(f"Read total {count} groups of data")
+
+        for group in groups:
+            if len(group.sentences) == 0:
+                continue
+
+            samples = list(group.sentences)
+            for sentence in samples:
+                text = clean_text(random.choice(sentence.texts))
+
+                tokens, labels = self.pack_sentences(
+                    sentences=[text],
+                    semantics=[sentence.semantics],
+                    skip_text=random.random() < self.skip_text_prob,
+                )
+
+                self.data.append({"tokens": tokens, "labels": labels})
+
+        random.Random(self.seed).shuffle(self.data)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+    def pack_sentences(
+        self,
+        sentences: list[str],
+        semantics: list,
+        skip_text: bool = False,
+    ):
+        messages = [
+            Message(
+                role="system",
+                parts=[TextPart(text="Speak out the provided text.")],
+            )
+        ]
+
+        cated_sentences = " ".join(sentences)
+        if skip_text:
+            cated_sentences = "<|skip_text|>"
+
+        messages.append(
+            Message(
+                role="user",
+                parts=[TextPart(text=cated_sentences)],
+            )
+        )
+
+        vq_codes = [x.values for x in semantics[0]]
+        vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
+        vqpart = VQPart(codes=vq_codes_tensor)
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>"), vqpart],
+                cal_loss=True,
+            )
+        )
+
+        num_codebooks = (
+            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
+        )
+
+        conversation = Conversation(messages=messages)
+        encoded = conversation.encode(
+            tokenizer=self.tokenizer,
+        )
+
+        tokens_raw = encoded.tokens
+        tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
+        tokens[0] = tokens_raw
+
+        vq_parts = encoded.vq_parts
+        vq_parts = [part.to(tokens.device) for part in vq_parts]
+        vq_parts = torch.cat(vq_parts, dim=1)
+        tokens[1:, encoded.vq_mask_tokens] = vq_parts
+
+        labels_raw = encoded.labels
+        labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
+        labels[0, :] = labels_raw
+        labels[1:, encoded.vq_mask_labels] = vq_parts
+        labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID
+
+        tokens = tokens.long()
+        labels = labels.long()
+
+        assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
+        assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
+
+        return tokens, labels
+
+
+class InterleaveDataset(IterableDataset):
+    def __init__(
+        self,
+        datasets: list[IterableDataset],
+        probabilities: list[float],
+        seed: int = 42,
+    ):
+        super().__init__()
+
+        self.datasets = datasets
+        self.probabilities = probabilities
+        self.seed = seed
+
+    def __iter__(self):
+        rng = np.random.default_rng(self.seed)
+        dataset_iterators = [iter(dataset) for dataset in self.datasets]
+
+        while True:
+            # Random choice one
+            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
+            dataset_iterator = dataset_iterators[dataset_idx]
+
+            try:
+                yield next(dataset_iterator)
+            except StopIteration:
+                # Exhausted, create a new iterator
+                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
+                yield next(dataset_iterators[dataset_idx])
+
+
 @dataclass
 class TextDataCollator:
     tokenizer: FishTokenizer
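
For context, the key difference from the renamed iterable variant is that this class eagerly builds `self.data` in `_init_data()` and exposes `__len__`/`__getitem__`, so it works with a shuffling map-style `DataLoader`. A rough usage sketch; the tokenizer construction and the collator's `max_length` field are assumptions, not shown in this diff:

# Rough usage sketch of the new map-style dataset.
from torch.utils.data import DataLoader

tokenizer = ...  # an already-constructed FishTokenizer (setup not shown in this diff)

ds = AutoTextSemanticInstructionDataset(
    proto_files=["data/protos"],
    tokenizer=tokenizer,
    interactive_prob=0.7,
)

sample = ds[0]
# sample["tokens"] / sample["labels"] have shape (num_codebooks + 1, seq_len):
# row 0 carries text token ids, rows 1.. carry VQ codebook values aligned to
# the <|voice|> region; per the asserts above, non-VQ positions in rows 1..
# equal CODEBOOK_PAD_TOKEN_ID.

loader = DataLoader(
    ds,
    batch_size=8,
    shuffle=True,  # possible because the dataset supports random access
    collate_fn=TextDataCollator(tokenizer, max_length=1024),  # max_length field assumed
)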
@@ -369,41 +577,19 @@ def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
         }


-class InterleaveDataset(IterableDataset):
-    def __init__(
-        self,
-        datasets: list[IterableDataset],
-        probabilities: list[float],
-        seed: int = 42,
-    ):
-        super().__init__()
-
-        self.datasets = datasets
-        self.probabilities = probabilities
-        self.seed = seed
-
-    def __iter__(self):
-        rng = np.random.default_rng(self.seed)
-        dataset_iterators = [iter(dataset) for dataset in self.datasets]
-
-        while True:
-            # Random choice one
-            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
-            dataset_iterator = dataset_iterators[dataset_idx]
-
-            try:
-                yield next(dataset_iterator)
-            except StopIteration:
-                # Exhausted, create a new iterator
-                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
-                yield next(dataset_iterators[dataset_idx])
-
-
 class SemanticDataModule(LightningDataModule):
     def __init__(
         self,
-        train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
-        val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
+        train_dataset: Union[
+            AutoTextSemanticInstructionDataset,
+            AutoTextSemanticInstructionIterableDataset,
+            InterleaveDataset,
+        ],
+        val_dataset: Union[
+            AutoTextSemanticInstructionDataset,
+            AutoTextSemanticInstructionIterableDataset,
+            InterleaveDataset,
+        ],
         batch_size: int = 32,
         tokenizer: FishTokenizer = None,
         max_length: int = 1024,
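
With the widened Union, either dataset flavour (or an InterleaveDataset mix) can be handed to the data module. A sketch of that wiring, using only the constructor arguments visible in this hunk; anything else is assumed to keep its defaults:

# Sketch: passing either dataset type into the Lightning data module.
tokenizer = ...  # an already-constructed FishTokenizer (assumed)

datamodule = SemanticDataModule(
    train_dataset=AutoTextSemanticInstructionDataset(
        proto_files=["data/protos"], tokenizer=tokenizer
    ),
    val_dataset=AutoTextSemanticInstructionIterableDataset(
        proto_files=["data/protos"], tokenizer=tokenizer
    ),
    batch_size=32,
    tokenizer=tokenizer,
    max_length=1024,
)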
@@ -448,7 +634,6 @@ def val_dataloader(self):
         skip_text_prob=0.5,
     )

-    for i in ds:
+    for i in range(100):
         # Please uncomment line 235 to visualize the tokenized message
-        print(i)
-        break
+        print(ds[i])
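
The updated smoke test reflects the new access pattern: the iterable variant must be consumed by iterating, while the map-style class supports `len()` and indexing. Side by side, as a sketch reusing the names from this diff (tokenizer construction assumed):

# Access patterns of the two classes (sketch).
tokenizer = ...  # an already-constructed FishTokenizer (assumed)

iterable_ds = AutoTextSemanticInstructionIterableDataset(
    proto_files=["data/protos"], tokenizer=tokenizer
)
first = next(iter(iterable_ds))  # streaming only, no __len__/__getitem__

map_ds = AutoTextSemanticInstructionDataset(
    proto_files=["data/protos"], tokenizer=tokenizer
)
print(len(map_ds))  # number of packed sentences
print(map_ds[0])    # random access by index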
