@@ -13,7 +13,7 @@
 from huggingface_hub import HfApi
 from lightning import LightningDataModule
 from torch.distributed import get_rank, get_world_size, is_initialized
-from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
 
 from fish_speech.conversation import (
     CODEBOOK_PAD_TOKEN_ID,
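
The import change is the heart of this commit: `Dataset` comes in alongside `IterableDataset` because a map-style counterpart to the streaming dataset is added below. As a refresher, a map-style `Dataset` implements `__len__` and `__getitem__`, so a `DataLoader` can shuffle and sample it by index, while an `IterableDataset` only yields a stream and must handle sharding and shuffling itself. A toy sketch of the distinction (illustrative classes, not from this repo):

    import torch
    from torch.utils.data import DataLoader, Dataset, IterableDataset

    class MapStyleToy(Dataset):
        # Map-style: random access and a known length, so the DataLoader
        # can shuffle indices for us.
        def __init__(self, items):
            self.items = items

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

    class StreamingToy(IterableDataset):
        # Iterable-style: sequential access only; any shuffling has to
        # happen inside __iter__ itself.
        def __init__(self, items):
            self.items = items

        def __iter__(self):
            yield from self.items

    loader = DataLoader(MapStyleToy(list(range(8))), batch_size=4, shuffle=True)
    for batch in loader:
        print(batch)  # shuffled batches of 4
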
@@ -59,7 +59,7 @@ def split_by_rank_worker(files):
     return files
 
 
-class AutoTextSemanticInstructionDataset(IterableDataset):
+class AutoTextSemanticInstructionIterableDataset(IterableDataset):
     """
     Auto Augment Dataset by Speaker
 
@@ -295,6 +295,214 @@ def augment(self):
         return data
 
 
+class AutoTextSemanticInstructionDataset(Dataset):
+    """
+    Auto Augment Dataset by Speaker
+
+    1. Randomly concatenate multiple sentences from the same speaker to form a longer sentence
+    2. Automatically normalize the text
+
+    For interactive mode, we use the following format (multiple sequences):
+    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
+
+    For non-interactive mode, we use the following format (one long sequence):
+    <s> [INST] text [/INST] ... </s>
+    """
+
+    def __init__(
+        self,
+        proto_files: list[str],
+        seed: int = 42,
+        interactive_prob: float = 0.5,
+        max_length: int = 1024,
+        tokenizer: FishTokenizer = None,
+        use_speaker: bool | float = True,
+        causal: bool = True,
+        num_codebooks: Optional[int] = None,
+        skip_text_prob: float = 0.0,
+    ):
+        """
+        Args:
+            proto_files: protobuf files, if using local data
+            seed: random seed
+            interactive_prob: probability of using interactive mode
+            max_length: max length of the text
+            tokenizer: tokenizer
+            use_speaker: include speaker information in the prompt
+            causal: use causal sampling when using local data; disabling it falls back to random sampling
+            num_codebooks: number of codebooks; if None, it is detected automatically
+            skip_text_prob: probability of skipping the text (audio only); applies to interactive mode only
+        """
+        super().__init__()
+
+        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
+
+        self.seed = seed
+        self.max_length = max_length
+        self.tokenizer = tokenizer
+        self.interactive_prob = interactive_prob
+        self.use_speaker = use_speaker
+        self.proto_files = proto_files
+        self.causal = causal
+        self.num_codebooks = num_codebooks
+        self.skip_text_prob = skip_text_prob
+
+        self.data = []
+        self._init_data()
+
+    def _init_data(self):
+        expanded_proto_files = []
+        for filename in self.proto_files:
+            for i in braceexpand(filename):
+                i = Path(i)
+                if i.is_file():
+                    expanded_proto_files.append(i)
+                elif i.is_dir():
+                    expanded_proto_files.extend(i.rglob("*.proto"))
+                    expanded_proto_files.extend(i.rglob("*.protos"))
+                else:
+                    raise ValueError(f"{i} is not a file or directory")
+
+        expanded_proto_files = sorted(expanded_proto_files)
+        Random(self.seed).shuffle(expanded_proto_files)
+
+        groups = []
+        shard_proto_files = split_by_rank_worker(expanded_proto_files)
+        log.info(
+            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
+        )
+
+        count = 0
+        for filename in shard_proto_files:
+            with open(filename, "rb") as f:
+                for text_data in read_pb_stream(f):
+                    groups.append(text_data)
+                    count += 1
+
+        log.info(f"Read total {count} groups of data")
+
+        for group in groups:
+            if len(group.sentences) == 0:
+                continue
+
+            samples = list(group.sentences)
+            for sentence in samples:
+                text = clean_text(random.choice(sentence.texts))
+
+                tokens, labels = self.pack_sentences(
+                    sentences=[text],
+                    semantics=[sentence.semantics],
+                    skip_text=random.random() < self.skip_text_prob,
+                )
+
+                self.data.append({"tokens": tokens, "labels": labels})
+
+        random.Random(self.seed).shuffle(self.data)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+    def pack_sentences(
+        self,
+        sentences: list[str],
+        semantics: list,
+        skip_text: bool = False,
+    ):
+        messages = [
+            Message(
+                role="system",
+                parts=[TextPart(text="Speak out the provided text.")],
+            )
+        ]
+
+        cated_sentences = " ".join(sentences)
+        if skip_text:
+            cated_sentences = "<|skip_text|>"
+
+        messages.append(
+            Message(
+                role="user",
+                parts=[TextPart(text=cated_sentences)],
+            )
+        )
+
+        vq_codes = [x.values for x in semantics[0]]
+        vq_codes_tensor = torch.tensor(vq_codes).to(torch.int32)
+        vqpart = VQPart(codes=vq_codes_tensor)
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>"), vqpart],
+                cal_loss=True,
+            )
+        )
+
+        num_codebooks = (
+            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
+        )
+
+        conversation = Conversation(messages=messages)
+        encoded = conversation.encode(
+            tokenizer=self.tokenizer,
+        )
+
+        tokens_raw = encoded.tokens
+        tokens = torch.zeros((num_codebooks + 1, len(tokens_raw)), dtype=torch.int)
+        tokens[0] = tokens_raw
+
+        vq_parts = encoded.vq_parts
+        vq_parts = [part.to(tokens.device) for part in vq_parts]
+        vq_parts = torch.cat(vq_parts, dim=1)
+        tokens[1:, encoded.vq_mask_tokens] = vq_parts
+
+        labels_raw = encoded.labels
+        labels = torch.full((num_codebooks + 1, len(labels_raw)), -100, dtype=torch.int)
+        labels[0, :] = labels_raw
+        labels[1:, encoded.vq_mask_labels] = vq_parts
+        labels[1:, -1:] = CODEBOOK_PAD_TOKEN_ID
+
+        tokens = tokens.long()
+        labels = labels.long()
+
+        assert (tokens[1:, ~(encoded.vq_mask_tokens)] == CODEBOOK_PAD_TOKEN_ID).all()
+        assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
+
+        return tokens, labels
+
+
+class InterleaveDataset(IterableDataset):
+    def __init__(
+        self,
+        datasets: list[IterableDataset],
+        probabilities: list[float],
+        seed: int = 42,
+    ):
+        super().__init__()
+
+        self.datasets = datasets
+        self.probabilities = probabilities
+        self.seed = seed
+
+    def __iter__(self):
+        rng = np.random.default_rng(self.seed)
+        dataset_iterators = [iter(dataset) for dataset in self.datasets]
+
+        while True:
+            # Randomly pick one dataset according to the sampling probabilities
+            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
+            dataset_iterator = dataset_iterators[dataset_idx]
+
+            try:
+                yield next(dataset_iterator)
+            except StopIteration:
+                # Exhausted, create a new iterator
+                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
+                yield next(dataset_iterators[dataset_idx])
+
+
 @dataclass
 class TextDataCollator:
     tokenizer: FishTokenizer
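
Note the trade-off in `_init_data`: every proto file is read and every sample is packed eagerly at construction time, which is exactly what buys the class its `__len__` and `__getitem__`. A hedged usage sketch follows; the proto path and tokenizer construction are placeholders not shown in this diff, and the collator's `max_length` argument is assumed from the surrounding file:

    # Sketch only: `tokenizer` and the proto path are assumptions.
    tokenizer = ...  # a FishTokenizer instance, built elsewhere
    ds = AutoTextSemanticInstructionDataset(
        proto_files=["data/protos"],  # hypothetical path
        tokenizer=tokenizer,
    )
    print(len(ds))                 # length is known up front
    sample = ds[0]                 # {"tokens": ..., "labels": ...}
    print(sample["tokens"].shape)  # (num_codebooks + 1, seq_len)

    loader = DataLoader(
        ds,
        batch_size=8,
        shuffle=True,  # now possible, since the dataset is indexable
        collate_fn=TextDataCollator(tokenizer, max_length=1024),
    )
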
@@ -369,41 +577,19 @@ def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
         }
 
 
-class InterleaveDataset(IterableDataset):
-    def __init__(
-        self,
-        datasets: list[IterableDataset],
-        probabilities: list[float],
-        seed: int = 42,
-    ):
-        super().__init__()
-
-        self.datasets = datasets
-        self.probabilities = probabilities
-        self.seed = seed
-
-    def __iter__(self):
-        rng = np.random.default_rng(self.seed)
-        dataset_iterators = [iter(dataset) for dataset in self.datasets]
-
-        while True:
-            # Random choice one
-            dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
-            dataset_iterator = dataset_iterators[dataset_idx]
-
-            try:
-                yield next(dataset_iterator)
-            except StopIteration:
-                # Exhausted, create a new iterator
-                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
-                yield next(dataset_iterators[dataset_idx])
-
-
 class SemanticDataModule(LightningDataModule):
     def __init__(
         self,
-        train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
-        val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
+        train_dataset: Union[
+            AutoTextSemanticInstructionDataset,
+            AutoTextSemanticInstructionIterableDataset,
+            InterleaveDataset,
+        ],
+        val_dataset: Union[
+            AutoTextSemanticInstructionDataset,
+            AutoTextSemanticInstructionIterableDataset,
+            InterleaveDataset,
+        ],
         batch_size: int = 32,
         tokenizer: FishTokenizer = None,
         max_length: int = 1024,
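
With the widened `Union`, the data module accepts the new map-style dataset, the renamed iterable one, or an interleaved mix of iterable datasets. One plausible wiring, sketched under the assumption that the renamed iterable class keeps its original constructor (paths, probabilities, and `tokenizer` below are illustrative):

    train_ds = InterleaveDataset(
        datasets=[
            AutoTextSemanticInstructionIterableDataset(["a.protos"], tokenizer=tokenizer),
            AutoTextSemanticInstructionIterableDataset(["b.protos"], tokenizer=tokenizer),
        ],
        probabilities=[0.8, 0.2],  # np.random choice requires these to sum to 1
    )
    val_ds = AutoTextSemanticInstructionDataset(["val.protos"], tokenizer=tokenizer)
    dm = SemanticDataModule(
        train_dataset=train_ds,
        val_dataset=val_ds,
        batch_size=32,
        tokenizer=tokenizer,
        max_length=1024,
    )
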
@@ -448,7 +634,6 @@ def val_dataloader(self):
         skip_text_prob=0.5,
     )
 
-    for i in ds:
+    for i in range(100):
         # Please uncomment line 235 to visualize the tokenized message
-        print(i)
-        break
+        print(ds[i])
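
One subtlety that carries over with the moved `InterleaveDataset`: its `__iter__` loops forever, transparently restarting any child iterator that raises `StopIteration`, so the consumer must bound the stream itself. A small hedged sketch, where `ds_a` and `ds_b` are hypothetical iterable datasets assumed to yield the same `{"tokens", "labels"}` dicts as the classes above:

    from itertools import islice

    mixed = InterleaveDataset(datasets=[ds_a, ds_b], probabilities=[0.5, 0.5])
    for sample in islice(mixed, 10):  # cap the infinite stream at 10 samples
        print(sample["tokens"].shape)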