From fbd28ecd3f559cafb25cf776bb216d31c69f9fb6 Mon Sep 17 00:00:00 2001
From: Luca Della Libera <34525085+lucadellalib@users.noreply.github.com>
Date: Mon, 5 Feb 2024 01:49:43 -0500
Subject: [PATCH 1/5] Update README.md

---
 benchmarks/CL_MASR/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmarks/CL_MASR/README.md b/benchmarks/CL_MASR/README.md
index 0ce63a236..18b2d0395 100644
--- a/benchmarks/CL_MASR/README.md
+++ b/benchmarks/CL_MASR/README.md
@@ -134,7 +134,7 @@ If you use the CL-MASR benchmark, please cite:
 ```bibtex
 @article{dellalibera2023clmasr,
-    author = {Luca Della Libera, Pooneh Mousavi, Salah Zaiem, Cem Subakan, Mirco Ravanelli},
+    author = {Luca Della Libera and Pooneh Mousavi and Salah Zaiem and Cem Subakan and Mirco Ravanelli},
     title = {{CL-MASR}: A Continual Learning Benchmark for Multilingual {ASR}},
     journal = {arXiv preprint arXiv:2310.16931},
     year = {2023},

From 44d316708c74bdedc881cec00369a345f254b34f Mon Sep 17 00:00:00 2001
From: Luca Della Libera
Date: Thu, 3 Jul 2025 21:48:53 -0400
Subject: [PATCH 2/5] Update

---
 .../enhancement/{utils.py => common.py}       |   53 +-
 .../enhancement/conformer/custom_model.py     |    1 -
 .../enhancement/conformer/metrics/dnsmos.py   |    1 -
 .../enhancement/conformer/metrics/dwer.py     |    1 -
 .../enhancement/conformer/metrics/spk_sim.py  |    1 -
 .../enhancement/conformer/train_dac.py        |  158 --
 .../conformer/train_discrete_ssl.py           |  185 ---
 .../enhancement/conformer/train_encodec.py    |  337 ----
 .../conformer/train_speech_tokenizer.py       |  151 --
 .../VoiceBank/enhancement/conformer/utils.py  |    1 -
 .../conformer/voicebank_prepare.py            |    1 -
 .../enhancement/crdnn/custom_model.py         |    1 -
 .../enhancement/crdnn/metrics/dnsmos.py       |    1 -
 .../enhancement/crdnn/metrics/dwer.py         |    1 -
 .../enhancement/crdnn/metrics/spk_sim.py      |    1 -
 .../enhancement/crdnn/train_continuous_ssl.py |  345 -----
 .../VoiceBank/enhancement/crdnn/train_dac.py  |  158 --
 .../enhancement/crdnn/train_discrete_ssl.py   |  185 ---
 .../crdnn/train_speech_tokenizer.py           |  151 --
 .../DASB/VoiceBank/enhancement/crdnn/utils.py |    1 -
 .../enhancement/crdnn/voicebank_prepare.py    |    1 -
 .../VoiceBank/enhancement/custom_model.py     |    1 +
 .../CRDNN/train.yaml}                         |  203 +--
 .../CRDNN}/train_continuous_wavlm.yaml        |   51 +-
 .../enhancement/hparams/CRDNN/train_dac.yaml  |  210 +++
 .../CRDNN/train_encodec.yaml}                 |   88 +-
 .../hparams/CRDNN/train_hubert.yaml           |  219 +++
 .../enhancement/hparams/CRDNN/train_mimi.yaml |  210 +++
 .../CRDNN/train_speech_tokenizer.yaml}        |   94 +-
 .../hparams/CRDNN/train_sqcodec.yaml          |  211 +++
 .../hparams/CRDNN/train_wav2vec2.yaml         |  219 +++
 .../hparams/CRDNN/train_wavlm.yaml            |  219 +++
 .../CRDNN/train_wavtokenizer.yaml}            |  101 +-
 .../Conformer/train.yaml}                     |  201 +--
 .../Conformer}/train_continuous_wavlm.yaml    |   43 +-
 .../hparams/Conformer/train_dac.yaml          |  200 +++
 .../Conformer/train_encodec.yaml}             |   80 +-
 .../hparams/Conformer/train_hubert.yaml       |  209 +++
 .../hparams/Conformer/train_mimi.yaml         |  200 +++
 .../Conformer/train_speech_tokenizer.yaml     |  199 +++
 .../Conformer/train_sqcodec.yaml}             |   88 +-
 .../hparams/Conformer/train_wav2vec2.yaml     |  209 +++
 .../hparams/Conformer/train_wavlm.yaml        |  209 +++
 .../Conformer/train_wavtokenizer.yaml}        |   91 +-
 .../VoiceBank/enhancement/metrics/spk_sim.py  |   40 +-
 .../VoiceBank/enhancement/model/ __init__.py  |    0
 .../enhancement/model/custom_model.py         |  111 ++
 .../VoiceBank/enhancement/model/sq_codec.py   | 1361 +++++++++++++++++
 .../{crdnn/train_encodec.py => train.py}      |  103 +-
 .../{conformer => }/train_continuous_ssl.py   |   29 +-
 .../VoiceBank/enhancement/utils/__init__.py   |    0
 .../enhancement/utils/aggregate_results.py   |  149 ++
 .../enhancement/utils/audio_tokens.py        |  193 +++
 .../DASB/VoiceBank/enhancement/utils/data.py |   91 ++
 .../DASB/VoiceBank/enhancement/utils/eval.py | 1028 +++++++++++++
 .../enhancement/utils/preparation.py         |  470 ++++++
 .../enhancement/utils/tokenizer_interface.py |  515 +++++++
 .../VoiceBank/enhancement/utils/tokens.py    |  411 +++++
 .../enhancement/voicebank_prepare.py         |    1 +
 .../DASB/VoiceBank/extraction/extract.py     |  105 ++
 .../VoiceBank/extraction/hparams/dac.yaml    |   70 +
 .../extraction/hparams/discrete_ssl.yaml     |  112 ++
 .../VoiceBank/extraction/hparams/encodec.yaml |   67 +
 .../VoiceBank/extraction/hparams/mimi.yaml   |   61 +
 .../extraction/hparams/speech_tokenizer.yaml |   59 +
 .../VoiceBank/extraction/hparams/sqcodec.yaml |   61 +
 .../extraction/hparams/wavtokenizer.yaml     |   64 +
 .../VoiceBank/extraction/voicebank_prepare.py |    1 +
 68 files changed, 7931 insertions(+), 2461 deletions(-)
 rename benchmarks/DASB/VoiceBank/enhancement/{utils.py => common.py} (86%)
 delete mode 120000 benchmarks/DASB/VoiceBank/enhancement/conformer/custom_model.py
 delete mode 120000 benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/dnsmos.py
 delete mode 120000 benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/dwer.py
 delete mode 120000 benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/spk_sim.py
 delete mode 100644 benchmarks/DASB/VoiceBank/enhancement/conformer/train_dac.py
 delete mode 100644 benchmarks/DASB/VoiceBank/enhancement/conformer/train_discrete_ssl.py
 delete mode 100644 benchmarks/DASB/VoiceBank/enhancement/conformer/train_encodec.py
 delete mode 100644 benchmarks/DASB/VoiceBank/enhancement/conformer/train_speech_tokenizer.py
 delete mode 120000 benchmarks/DASB/VoiceBank/enhancement/conformer/utils.py
 delete mode 120000 benchmarks/DASB/VoiceBank/enhancement/conformer/voicebank_prepare.py
 delete mode 120000 benchmarks/DASB/VoiceBank/enhancement/crdnn/custom_model.py
 delete mode 120000 benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/dnsmos.py
 delete mode 120000 benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/dwer.py
 delete mode 120000 benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/spk_sim.py
 delete mode 100644 benchmarks/DASB/VoiceBank/enhancement/crdnn/train_continuous_ssl.py
 delete mode 100644 benchmarks/DASB/VoiceBank/enhancement/crdnn/train_dac.py
 delete mode 100644 benchmarks/DASB/VoiceBank/enhancement/crdnn/train_discrete_ssl.py
 delete mode 100644 benchmarks/DASB/VoiceBank/enhancement/crdnn/train_speech_tokenizer.py
 delete mode 120000 benchmarks/DASB/VoiceBank/enhancement/crdnn/utils.py
 delete mode 120000 benchmarks/DASB/VoiceBank/enhancement/crdnn/voicebank_prepare.py
 create mode 120000 benchmarks/DASB/VoiceBank/enhancement/custom_model.py
 rename benchmarks/DASB/VoiceBank/enhancement/{crdnn/hparams/train_discrete_ssl.yaml => hparams/CRDNN/train.yaml} (54%)
 rename benchmarks/DASB/VoiceBank/enhancement/{crdnn/hparams => hparams/CRDNN}/train_continuous_wavlm.yaml (77%)
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_dac.yaml
 rename benchmarks/DASB/VoiceBank/enhancement/{crdnn/hparams/train_speech_tokenizer.yaml => hparams/CRDNN/train_encodec.yaml} (72%)
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_hubert.yaml
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_mimi.yaml
 rename benchmarks/DASB/VoiceBank/enhancement/{crdnn/hparams/train_dac.yaml => hparams/CRDNN/train_speech_tokenizer.yaml} (68%)
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_sqcodec.yaml
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wav2vec2.yaml
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavlm.yaml
 rename benchmarks/DASB/VoiceBank/enhancement/{crdnn/hparams/train_encodec.yaml => hparams/CRDNN/train_wavtokenizer.yaml} (68%)
 rename benchmarks/DASB/VoiceBank/enhancement/{conformer/hparams/train_discrete_ssl.yaml => hparams/Conformer/train.yaml} (51%)
 rename benchmarks/DASB/VoiceBank/enhancement/{conformer/hparams => hparams/Conformer}/train_continuous_wavlm.yaml (77%)
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_dac.yaml
 rename benchmarks/DASB/VoiceBank/enhancement/{conformer/hparams/train_speech_tokenizer.yaml => hparams/Conformer/train_encodec.yaml} (71%)
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_hubert.yaml
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_mimi.yaml
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_speech_tokenizer.yaml
 rename benchmarks/DASB/VoiceBank/enhancement/{conformer/hparams/train_dac.yaml => hparams/Conformer/train_sqcodec.yaml} (67%)
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wav2vec2.yaml
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavlm.yaml
 rename benchmarks/DASB/VoiceBank/enhancement/{conformer/hparams/train_encodec.yaml => hparams/Conformer/train_wavtokenizer.yaml} (68%)
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/model/ __init__.py
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/model/custom_model.py
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/model/sq_codec.py
 rename benchmarks/DASB/VoiceBank/enhancement/{crdnn/train_encodec.py => train.py} (73%)
 rename benchmarks/DASB/VoiceBank/enhancement/{conformer => }/train_continuous_ssl.py (89%)
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/utils/__init__.py
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/utils/aggregate_results.py
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/utils/audio_tokens.py
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/utils/data.py
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/utils/eval.py
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/utils/preparation.py
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/utils/tokenizer_interface.py
 create mode 100644 benchmarks/DASB/VoiceBank/enhancement/utils/tokens.py
 create mode 120000 benchmarks/DASB/VoiceBank/enhancement/voicebank_prepare.py
 create mode 100644 benchmarks/DASB/VoiceBank/extraction/extract.py
 create mode 100644 benchmarks/DASB/VoiceBank/extraction/hparams/dac.yaml
 create mode 100644 benchmarks/DASB/VoiceBank/extraction/hparams/discrete_ssl.yaml
 create mode 100644 benchmarks/DASB/VoiceBank/extraction/hparams/encodec.yaml
 create mode 100644 benchmarks/DASB/VoiceBank/extraction/hparams/mimi.yaml
 create mode 100644 benchmarks/DASB/VoiceBank/extraction/hparams/speech_tokenizer.yaml
 create mode 100644 benchmarks/DASB/VoiceBank/extraction/hparams/sqcodec.yaml
 create mode 100644 benchmarks/DASB/VoiceBank/extraction/hparams/wavtokenizer.yaml
 create mode 120000 benchmarks/DASB/VoiceBank/extraction/voicebank_prepare.py

diff --git a/benchmarks/DASB/VoiceBank/enhancement/utils.py b/benchmarks/DASB/VoiceBank/enhancement/common.py
similarity index 86%
rename from
benchmarks/DASB/VoiceBank/enhancement/utils.py rename to benchmarks/DASB/VoiceBank/enhancement/common.py index 2e740db09..227b839fe 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/utils.py +++ b/benchmarks/DASB/VoiceBank/enhancement/common.py @@ -4,8 +4,6 @@ * Luca Della Libera 2024 """ -import os - import speechbrain as sb import torch import torchaudio @@ -100,30 +98,6 @@ def dataio_prepare( It also defines the data processing pipeline through user-defined functions. """ - if isinstance(train_csv, (list, tuple)): - csvs = [os.path.basename(x) for x in train_csv] - save_folder = os.path.dirname(train_csv[0]) - merge_csvs( - save_folder, csvs, "train.csv", - ) - train_csv = os.path.join(save_folder, "train.csv") - - if isinstance(valid_csv, (list, tuple)): - csvs = [os.path.basename(x) for x in valid_csv] - save_folder = os.path.dirname(valid_csv[0]) - merge_csvs( - save_folder, csvs, "valid.csv", - ) - valid_csv = os.path.join(save_folder, "valid.csv") - - if isinstance(test_csv, (list, tuple)): - csvs = [os.path.basename(x) for x in test_csv] - save_folder = os.path.dirname(test_csv[0]) - merge_csvs( - save_folder, csvs, "test.csv", - ) - test_csv = os.path.join(save_folder, "test.csv") - train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( csv_path=train_csv, replacements={"DATA_ROOT": data_folder}, ) @@ -155,10 +129,33 @@ def dataio_prepare( ) datasets = [train_data, valid_data, test_data] + output_keys = ["id"] + + if "tokens_loader_in" in hparams and "tokens_loader_out" in hparams: + # Define tokens pipeline: + takes = ["id"] + provides = ["in_toks", "out_toks"] + output_keys += provides + + tokens_loader_in = hparams["tokens_loader_in"] + tokens_loader_out = hparams["tokens_loader_out"] + num_codebooks = hparams["num_codebooks"] + + def toks_pipeline(id): + in_toks = tokens_loader_in.tokens_by_uttid(id, num_codebooks=num_codebooks) + yield in_toks + + out_toks = tokens_loader_out.tokens_by_uttid(id, num_codebooks=num_codebooks) + yield out_toks + + sb.dataio.dataset.add_dynamic_item( + datasets, toks_pipeline, takes, provides + ) # Define audio pipeline takes = ["clean_wav", "noisy_wav"] provides = ["in_sig", "out_sig"] + output_keys += provides def audio_pipeline(clean_wav, noisy_wav): # Clean signal @@ -185,11 +182,11 @@ def audio_pipeline(clean_wav, noisy_wav): yield out_sig sb.dataio.dataset.add_dynamic_item( - [train_data, valid_data, test_data], audio_pipeline, takes, provides + datasets, audio_pipeline, takes, provides ) # Set output - sb.dataio.dataset.set_output_keys(datasets, ["id"] + provides) + sb.dataio.dataset.set_output_keys(datasets, output_keys) return train_data, valid_data, test_data diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/custom_model.py b/benchmarks/DASB/VoiceBank/enhancement/conformer/custom_model.py deleted file mode 120000 index 4b3f08ebb..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/custom_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../model/custom_model.py \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/dnsmos.py b/benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/dnsmos.py deleted file mode 120000 index 4ed89e197..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/dnsmos.py +++ /dev/null @@ -1 +0,0 @@ -../../metrics/dnsmos.py \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/dwer.py b/benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/dwer.py deleted file mode 120000 index 
fe0803d67..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/dwer.py +++ /dev/null @@ -1 +0,0 @@ -../../metrics/dwer.py \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/spk_sim.py b/benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/spk_sim.py deleted file mode 120000 index 961df0736..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/metrics/spk_sim.py +++ /dev/null @@ -1 +0,0 @@ -../../metrics/spk_sim.py \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/train_dac.py b/benchmarks/DASB/VoiceBank/enhancement/conformer/train_dac.py deleted file mode 100644 index 2fe919e6a..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/train_dac.py +++ /dev/null @@ -1,158 +0,0 @@ -#!/usr/bin/env/python - -"""Recipe for training a transformer-based speech enhancement system using DAC audio representations. - -To run this recipe: -> python train_dac.py hparams/.yaml - -Authors - * Luca Della Libera 2024 -""" - -import os -import sys -import warnings - -import speechbrain as sb -import torch -from hyperpyyaml import load_hyperpyyaml -from speechbrain.utils.distributed import run_on_main - -from train_encodec import Enhancement as EnhancementEncodec - - -class Enhancement(EnhancementEncodec): - @torch.no_grad() - def sig_to_toks(self, sig, lens): - # sig: [B, T] - self.hparams.codec.to(self.device).eval() - toks, _ = self.hparams.codec( - sig[:, None], n_quantizers=self.hparams.num_codebooks - ) # [B, K, N] - toks = toks.movedim(-1, -2) # [B, N, K] - return toks - - @torch.no_grad() - def toks_to_sig(self, toks): - # toks: [B, N, K] - self.hparams.codec.to(self.device).eval() - qfeats, _, _ = self.hparams.codec.quantizer.from_codes( - toks.movedim(-1, -2) # [B, K, N] - ) - sig = self.hparams.codec.decode(qfeats)[:, 0] # [B, T] - return sig - - -if __name__ == "__main__": - # Command-line interface - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Filter warnings - warnings.filterwarnings("once") - warnings.filterwarnings("ignore", module="torch") - - # If --distributed_launch then create ddp_init_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset preparation - from voicebank_prepare import prepare_voicebank as prepare_data - - prepare_data_kwargs = { - "data_folder": hparams["data_folder"], - "save_folder": hparams["save_folder"], - "splits": hparams["splits"], - "num_valid_speakers": hparams["num_valid_speakers"], - } - - # Due to DDP, do the preparation ONLY on the main Python process - run_on_main(prepare_data, kwargs=prepare_data_kwargs) - - # Create the datasets objects - from utils import dataio_prepare - - train_data, valid_data, test_data = dataio_prepare( - debug=run_opts.get("debug", False), **hparams - ) - - # Pretrain the specified modules - if "pretrainer" in hparams: - run_on_main(hparams["pretrainer"].collect_files) - run_on_main(hparams["pretrainer"].load_collected) - - # Use pretrained embeddings - if hparams["pretrain_embedding"]: - # See https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/nn/quantize.py#L200 - toks = torch.arange(hparams["vocab_size"], 
device=run_opts["device"]) - toks = ( - toks[:, None, None].expand(-1, hparams["num_codebooks"], -1).clone() - ) # [C, K, 1] - hparams["codec"].to(run_opts["device"]).eval() - with torch.no_grad(): - z_q, z_p, _ = hparams["codec"].quantizer.from_codes(toks) - z_ps = z_p.split(z_p.shape[1] // toks.shape[1], dim=1) # [C, D, 1] * K - z_qs = [] - for i, z_p_i in enumerate(z_ps): - with torch.no_grad(): - z_q_i = ( - hparams["codec"].quantizer.quantizers[i].out_proj(z_p_i) - ) # [C, H, 1] - z_qs.append(z_q_i) - assert (z_q == sum(z_qs)).all() - # Embeddings pre-projections: size = 8 - # embs = torch.cat(z_ps)[:, :, 0] - # Embeddings post-projections: size = 1024 - embs = torch.cat(z_qs)[:, :, 0] # [CK, H] - hparams["embedding"].embedding.weight.data.copy_(embs) - - # Log number of parameters/buffers - codec_params = sum( - [x.numel() for x in hparams["codec"].state_dict().values()] - ) - model_params = sum( - [ - x.numel() - for module in hparams["modules"].values() - for x in module.state_dict().values() - ] - ) - hparams["train_logger"].log_stats( - stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", - "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", - }, - ) - - # Trainer initialization - brain = Enhancement( - modules=hparams["modules"], - opt_class=hparams["opt_class"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # Train - brain.fit( - brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_kwargs"], - valid_loader_kwargs=hparams["valid_dataloader_kwargs"], - ) - - # Test - brain.hparams.dwer_file = os.path.join(hparams["output_folder"], "dwer.txt") - brain.evaluate( - test_data, - min_key="TER", - test_loader_kwargs=hparams["test_dataloader_kwargs"], - ) diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/train_discrete_ssl.py b/benchmarks/DASB/VoiceBank/enhancement/conformer/train_discrete_ssl.py deleted file mode 100644 index f6811b5fa..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/train_discrete_ssl.py +++ /dev/null @@ -1,185 +0,0 @@ -#!/usr/bin/env/python - -"""Recipe for training a transformer-based speech enhancement system using discrete SSL audio representations. 
- -To run this recipe: -> python train_discrete_ssl.py hparams/.yaml - -Authors - * Luca Della Libera 2024 -""" - -import os -import sys -import warnings - -import speechbrain as sb -import torch -from hyperpyyaml import load_hyperpyyaml -from speechbrain.utils.distributed import run_on_main - -from train_encodec import Enhancement as EnhancementEncodec - - -# To use in configuration files -def len_(SSL_layers, vocab_size): - return len(SSL_layers) * vocab_size - - -class Enhancement(EnhancementEncodec): - @torch.no_grad() - def sig_to_toks(self, sig, lens): - # sig: [B, T] - self.hparams.codec_quantizer.to(self.device).eval() - toks, _, _ = self.hparams.codec_quantizer( - sig, - lens, - SSL_layers=self.hparams.SSL_layers, - deduplicates=[False] * len(self.hparams.SSL_layers), - bpe_tokenizers=[None] * len(self.hparams.SSL_layers), - ) # [B, N, K] - return toks - - @torch.no_grad() - def toks_to_sig(self, toks): - # toks: [B, N, K] - self.hparams.codec_vocoder.device = self.device - self.hparams.codec_vocoder.to(self.device).eval() - - # Add offset for embedding layer - all_layer_ids = [1, 3, 7, 12, 18, 23] - offsets = torch.arange( - 0, - len(all_layer_ids) * self.hparams.vocab_size, - self.hparams.vocab_size, - device=self.device, - ) - offset_idxes = [all_layer_ids.index(x) for x in self.hparams.SSL_layers] - offsets = offsets[offset_idxes] - toks = toks + offsets + 1 - - # Handle missing codebooks - if len(self.hparams.SSL_layers) < len(all_layer_ids): - full_toks = torch.zeros( - *toks.shape[:2], - len(all_layer_ids), - dtype=toks.dtype, - device=self.device, - ) - for i, idx in enumerate(offset_idxes): - full_toks[..., idx] = toks[..., i] - toks = full_toks - - self.hparams.codec_vocoder.tokenize = False - sig = self.hparams.codec_vocoder(toks)[:, 0] # [B, T] - return sig - - -if __name__ == "__main__": - # Command-line interface - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Filter warnings - warnings.filterwarnings("once") - warnings.filterwarnings("ignore", module="torch") - - # If --distributed_launch then create ddp_init_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset preparation - from voicebank_prepare import prepare_voicebank as prepare_data - - prepare_data_kwargs = { - "data_folder": hparams["data_folder"], - "save_folder": hparams["save_folder"], - "splits": hparams["splits"], - "num_valid_speakers": hparams["num_valid_speakers"], - } - - # Due to DDP, do the preparation ONLY on the main Python process - run_on_main(prepare_data, kwargs=prepare_data_kwargs) - - # Create the datasets objects - from utils import dataio_prepare - - train_data, valid_data, test_data = dataio_prepare( - debug=run_opts.get("debug", False), **hparams - ) - - # Pretrain the specified modules - if "pretrainer" in hparams: - run_on_main(hparams["pretrainer"].collect_files) - run_on_main(hparams["pretrainer"].load_collected) - - # Use pretrained embeddings - if hparams["pretrain_embedding"]: - # See https://github.com/speechbrain/speechbrain/blob/60062c2536e8122253d6ad0e681208f554528950/speechbrain/lobes/models/huggingface_transformers/discrete_ssl.py#L197 - hparams["codec_quantizer"].to(run_opts["device"]).eval() - embs = [] - for layer_num, vocabulary in 
zip( - hparams["codec_quantizer"].ssl_layer_ids, - hparams["codec_quantizer"].vocabularies, - ): - if layer_num not in hparams["SSL_layers"]: - continue - embs.append( - torch.as_tensor( - vocabulary, dtype=torch.float32, device=run_opts["device"] - ) - ) - embs = torch.cat(embs) - hparams["embedding"].embedding.weight.data.copy_(embs) - - # Log number of parameters/buffers - codec_params = sum( - [x.numel() for x in hparams["codec_quantizer"].state_dict().values()] - + [x.numel() for x in hparams["codec_vocoder"].state_dict().values()] - ) - model_params = sum( - [ - x.numel() - for module in hparams["modules"].values() - for x in module.state_dict().values() - ] - ) - hparams["train_logger"].log_stats( - stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", - "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", - }, - ) - - # Trainer initialization - brain = Enhancement( - modules=hparams["modules"], - opt_class=hparams["opt_class"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # Train - brain.fit( - brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_kwargs"], - valid_loader_kwargs=hparams["valid_dataloader_kwargs"], - ) - - # Test - brain.hparams.dwer_file = os.path.join(hparams["output_folder"], "dwer.txt") - brain.evaluate( - test_data, - min_key="TER", - test_loader_kwargs=hparams["test_dataloader_kwargs"], - ) diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/train_encodec.py b/benchmarks/DASB/VoiceBank/enhancement/conformer/train_encodec.py deleted file mode 100644 index e9a0a7649..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/train_encodec.py +++ /dev/null @@ -1,337 +0,0 @@ -#!/usr/bin/env/python - -"""Recipe for training a transformer-based speech enhancement system using EnCodec audio representations. 
- -To run this recipe: -> python train_encodec.py hparams/.yaml - -Authors - * Luca Della Libera 2024 -""" - -import os -import sys -import warnings - -import speechbrain as sb -import torch -from hyperpyyaml import load_hyperpyyaml -from speechbrain.dataio.dataio import write_audio -from speechbrain.utils.distributed import if_main_process, run_on_main - - -_CACHE = {} - - -class Enhancement(sb.Brain): - @torch.no_grad() - def sig_to_toks(self, sig, lens): - # sig: [B, T] - self.hparams.codec.to(self.device).eval() - toks, _ = self.hparams.codec.encode(sig, lens) # [B, N, K] - return toks - - @torch.no_grad() - def toks_to_sig(self, toks): - # toks: [B, N, K] - self.hparams.codec.to(self.device).eval() - sig = self.hparams.codec.decode(toks)[:, 0] # [B, T] - return sig - - def compute_forward(self, batch, stage): - """Forward pass.""" - batch = batch.to(self.device) - in_sig, in_lens = batch.in_sig # [B, T] - out_sig, out_lens = batch.out_sig # [B, T] - - # Augment if specified - if stage == sb.Stage.TRAIN and self.hparams.augment: - in_sig, in_lens = self.hparams.augmentation(in_sig, in_lens) - - # Extract tokens (cache them at first epoch if augmentation is disabled) - key = tuple(sorted(batch.id)) - try: - in_toks, out_toks = _CACHE[key] - in_toks = in_toks.to(self.device) - out_toks = out_toks.to(self.device) - except KeyError: - assert (in_lens == out_lens).all() - sig = torch.cat([in_sig, out_sig]) # [B2, T] - lens = torch.cat([in_lens, out_lens]) # [B2, T] - toks = self.sig_to_toks(sig, lens) # [B2, N, K] - in_toks, out_toks = toks.split( - [len(in_sig), len(out_sig)] - ) # [B, N, K], [B, N, K] - out_toks = out_toks.reshape( - len(in_sig), -1, self.hparams.num_codebooks, - ) # [B, N, K] - if self.hparams.use_cache and (not self.hparams.augment): - _CACHE[key] = in_toks.cpu(), out_toks.cpu() - - # Avoid in-place modification from embedding layer - in_toks = in_toks.clone() - - # Forward embedding + attention - in_embs = self.modules.embedding(in_toks) # [B, N, K, H] - att_w = self.modules.attention_mlp(in_embs) # [B, N, K, 1] - in_embs = torch.matmul(att_w.transpose(2, -1), in_embs).squeeze( - -2 - ) # [B, N, H] - - # Forward encoder - hyp_embs = self.modules.encoder.encode(in_embs, in_lens) # [B, N, H] - - # Forward head - log_probs = ( - self.modules.head(hyp_embs) - .reshape( - len(hyp_embs), - -1, - self.hparams.num_codebooks, - self.hparams.vocab_size, - ) - .log_softmax(dim=-1) - ) # [B, N, K, C] - - return log_probs, out_toks - - def compute_objectives(self, predictions, batch, stage): - """Computes the objectives.""" - log_probs, out_toks = predictions # [B, N, K, C], [B, N, K] - - IDs = batch.id - in_sig, _ = batch.in_sig - out_sig, out_lens = batch.out_sig - - # Cross-entropy loss - loss = self.hparams.ce_loss( - log_probs.flatten(start_dim=1, end_dim=2), # [B, NK, C] - out_toks.flatten(start_dim=1), # [B, NK] - length=out_lens, - ) - - # Compute TER - if stage != sb.Stage.TRAIN: - self.ter_metric.append( - IDs, - log_probs.flatten(start_dim=1, end_dim=2), - out_toks.flatten(start_dim=1), - out_lens, - ) - - # Vocode - if stage == sb.Stage.TEST and self.hparams.compute_metrics: - hyp_toks = log_probs.argmax(dim=-1) # [B, N, K] - self.vocode(IDs, in_sig, out_sig, hyp_toks, out_toks, out_lens) - - return loss - - @torch.no_grad() - def vocode(self, IDs, in_sig, out_sig, hyp_toks, out_toks, lens): - hyp_sig = self.toks_to_sig(hyp_toks) # [B, T] - rec_sig = self.toks_to_sig(out_toks) # [B, T] - - # Adjust length - if out_sig.shape[-1] > hyp_sig.shape[-1]: - pad = [0, 
out_sig.shape[-1] - hyp_sig.shape[-1]] - hyp_sig = torch.nn.functional.pad( - hyp_sig, pad, mode="replicate" - ) # [B, T_out] - rec_sig = torch.nn.functional.pad( - rec_sig, pad, mode="replicate" - ) # [B, T_out] - elif out_sig.shape[-1] < hyp_sig.shape[-1]: - hyp_sig = hyp_sig.narrow(-1, 0, out_sig.shape[-1]) # [B, T_out] - rec_sig = rec_sig.narrow(-1, 0, out_sig.shape[-1]) # [B, T_out] - - self.dnsmos_metric.append(IDs, hyp_sig, lens) - self.rec_dnsmos_metric.append(IDs, rec_sig, lens) - self.ref_dnsmos_metric.append(IDs, out_sig, lens) - self.dwer_metric.append(IDs, hyp_sig, out_sig, lens) - self.wavlm_sim_metric.append(IDs, hyp_sig, out_sig, lens) - self.ecapatdnn_sim_metric.append(IDs, hyp_sig, out_sig, lens) - - if self.hparams.save_audios: - save_folder = os.path.join(self.hparams.output_folder, "audios") - os.makedirs(save_folder, exist_ok=True) - for i in range(len(IDs)): - write_audio( - os.path.join(save_folder, f"{IDs[i]}_hyp.wav"), - hyp_sig[i].cpu(), - self.hparams.sample_rate, - ) - write_audio( - os.path.join(save_folder, f"{IDs[i]}_rec.wav"), - rec_sig[i].cpu(), - self.hparams.sample_rate, - ) - write_audio( - os.path.join(save_folder, f"{IDs[i]}_ref.wav"), - out_sig[i].cpu(), - self.hparams.sample_rate, - ) - write_audio( - os.path.join(save_folder, f"{IDs[i]}_in.wav"), - in_sig[i].cpu(), - self.hparams.sample_rate, - ) - - def on_stage_start(self, stage, epoch=None): - """Gets called at the beginning of each epoch.""" - super().on_stage_start(stage, epoch) - if stage != sb.Stage.TRAIN: - self.ter_metric = self.hparams.ter_computer() - if stage == sb.Stage.TEST and self.hparams.compute_metrics: - self.dnsmos_metric = self.hparams.dnsmos_computer() - self.rec_dnsmos_metric = self.hparams.dnsmos_computer() - self.ref_dnsmos_metric = self.hparams.dnsmos_computer() - self.dwer_metric = self.hparams.dwer_computer() - self.wavlm_sim_metric = self.hparams.wavlm_sim_computer() - self.ecapatdnn_sim_metric = self.hparams.ecapatdnn_sim_computer() - - def on_stage_end(self, stage, stage_loss, epoch=None): - """Gets called at the end of each epoch.""" - # Compute/store important stats - stage_stats = {"loss": stage_loss} - - if stage == sb.Stage.TRAIN: - self.train_stats = stage_stats - else: - stage_stats["TER"] = self.ter_metric.summarize("average") * 100 - - # Perform end-of-iteration operations, like annealing, logging, etc. 
- if stage == sb.Stage.VALID: - _, lr = self.hparams.scheduler(stage_stats["TER"]) - sb.nnet.schedulers.update_learning_rate(self.optimizer, lr) - steps = self.optimizer_step - self.hparams.train_logger.log_stats( - stats_meta={"epoch": epoch, "lr": lr, "steps": steps}, - train_stats=self.train_stats, - valid_stats=stage_stats, - ) - self.checkpointer.save_and_keep_only( - meta={"TER": stage_stats["TER"]}, - min_keys=["TER"], - num_to_keep=self.hparams.keep_checkpoints, - ) - - elif stage == sb.Stage.TEST: - if self.hparams.compute_metrics: - stage_stats["DNSMOS"] = self.dnsmos_metric.summarize("average") - stage_stats["RecDNSMOS"] = self.rec_dnsmos_metric.summarize( - "average" - ) - stage_stats["RefDNSMOS"] = self.ref_dnsmos_metric.summarize( - "average" - ) - stage_stats["dWER"] = self.dwer_metric.summarize("error_rate") - stage_stats["WavLMSim"] = self.wavlm_sim_metric.summarize( - "average" - ) - stage_stats[ - "ECAPATDNNSim" - ] = self.ecapatdnn_sim_metric.summarize("average") - self.hparams.train_logger.log_stats( - stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, - test_stats=stage_stats, - ) - if if_main_process(): - # Save dWER - if self.hparams.compute_metrics: - with open(self.hparams.dwer_file, "w") as w: - self.dwer_metric.write_stats(w) - - -if __name__ == "__main__": - # Command-line interface - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Filter warnings - warnings.filterwarnings("once") - warnings.filterwarnings("ignore", module="torch") - - # If --distributed_launch then create ddp_init_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset preparation - from voicebank_prepare import prepare_voicebank as prepare_data - - prepare_data_kwargs = { - "data_folder": hparams["data_folder"], - "save_folder": hparams["save_folder"], - "splits": hparams["splits"], - "num_valid_speakers": hparams["num_valid_speakers"], - } - - # Due to DDP, do the preparation ONLY on the main Python process - run_on_main(prepare_data, kwargs=prepare_data_kwargs) - - # Create the datasets objects - from utils import dataio_prepare - - train_data, valid_data, test_data = dataio_prepare( - debug=run_opts.get("debug", False), **hparams - ) - - # Pretrain the specified modules - if "pretrainer" in hparams: - run_on_main(hparams["pretrainer"].collect_files) - run_on_main(hparams["pretrainer"].load_collected) - - # Use pretrained embeddings - if hparams["pretrain_embedding"]: - embs = hparams["codec"].vocabulary.reshape(-1, hparams["embedding_dim"]) - hparams["embedding"].embedding.weight.data.copy_(embs) - - # Log number of parameters/buffers - codec_params = sum( - [x.numel() for x in hparams["codec"].state_dict().values()] - ) - model_params = sum( - [ - x.numel() - for module in hparams["modules"].values() - for x in module.state_dict().values() - ] - ) - hparams["train_logger"].log_stats( - stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", - "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", - }, - ) - - # Trainer initialization - brain = Enhancement( - modules=hparams["modules"], - opt_class=hparams["opt_class"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # Train - 
brain.fit( - brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_kwargs"], - valid_loader_kwargs=hparams["valid_dataloader_kwargs"], - ) - - # Test - brain.hparams.dwer_file = os.path.join(hparams["output_folder"], "dwer.txt") - brain.evaluate( - test_data, - min_key="TER", - test_loader_kwargs=hparams["test_dataloader_kwargs"], - ) diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/train_speech_tokenizer.py b/benchmarks/DASB/VoiceBank/enhancement/conformer/train_speech_tokenizer.py deleted file mode 100644 index c25d78e26..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/train_speech_tokenizer.py +++ /dev/null @@ -1,151 +0,0 @@ -#!/usr/bin/env/python - -"""Recipe for training a transformer-based speech enhancement system using SpeechTokenizer audio representations. - -To run this recipe: -> python train_speech_tokenizer.py hparams/.yaml - -Authors - * Luca Della Libera 2024 -""" - -import os -import sys -import warnings - -import speechbrain as sb -import torch -from hyperpyyaml import load_hyperpyyaml -from speechbrain.utils.distributed import run_on_main - -from train_encodec import Enhancement as EnhancementEncodec - - -class Enhancement(EnhancementEncodec): - @torch.no_grad() - def sig_to_toks(self, sig, lens): - # sig: [B, T] - self.hparams.codec.to(self.device).eval() - toks = self.hparams.codec(sig)[ - : self.hparams.num_codebooks - ] # [K, B, N] - toks = toks.movedim(-3, -1) # [B, N, K] - return toks - - @torch.no_grad() - def toks_to_sig(self, toks): - # toks: [B, N, K] - self.hparams.codec.to(self.device).eval() - toks = toks.movedim(-1, -3) # [K, B, N] - sig = self.hparams.codec.decode(toks) # [B, T] - return sig - - -if __name__ == "__main__": - # Command-line interface - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Filter warnings - warnings.filterwarnings("once") - warnings.filterwarnings("ignore", module="torch") - - # If --distributed_launch then create ddp_init_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset preparation - from voicebank_prepare import prepare_voicebank as prepare_data - - prepare_data_kwargs = { - "data_folder": hparams["data_folder"], - "save_folder": hparams["save_folder"], - "splits": hparams["splits"], - "num_valid_speakers": hparams["num_valid_speakers"], - } - - # Due to DDP, do the preparation ONLY on the main Python process - run_on_main(prepare_data, kwargs=prepare_data_kwargs) - - # Create the datasets objects - from utils import dataio_prepare - - train_data, valid_data, test_data = dataio_prepare( - debug=run_opts.get("debug", False), **hparams - ) - - # Pretrain the specified modules - if "pretrainer" in hparams: - run_on_main(hparams["pretrainer"].collect_files) - run_on_main(hparams["pretrainer"].load_collected) - - # Use pretrained embeddings - if hparams["pretrain_embedding"]: - # See https://github.com/ZhangXInFD/SpeechTokenizer/blob/a9f88dc72642b600654a62861e34342babae6c71/speechtokenizer/quantization/core_vq.py#L360 - toks = torch.arange(hparams["vocab_size"], device=run_opts["device"]) - toks = ( - toks[None, :, None].expand(hparams["num_codebooks"], -1, -1).clone() - ) # [K, C, 1] - 
hparams["codec"].to(run_opts["device"]).eval() - embs = [] - for i, indices in enumerate(toks): - layer = hparams["codec"].model.quantizer.vq.layers[i] - with torch.no_grad(): - quantized = layer.decode(indices) # [C, H, 1] - embs.append(quantized) - assert ( - hparams["codec"].model.quantizer.decode(toks) == sum(embs) - ).all() - embs = torch.cat(embs)[:, :, 0] # [CK, H] - hparams["embedding"].embedding.weight.data.copy_(embs) - - # Log number of parameters/buffers - codec_params = sum( - [x.numel() for x in hparams["codec"].state_dict().values()] - ) - model_params = sum( - [ - x.numel() - for module in hparams["modules"].values() - for x in module.state_dict().values() - ] - ) - hparams["train_logger"].log_stats( - stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", - "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", - }, - ) - - # Trainer initialization - brain = Enhancement( - modules=hparams["modules"], - opt_class=hparams["opt_class"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # Train - brain.fit( - brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_kwargs"], - valid_loader_kwargs=hparams["valid_dataloader_kwargs"], - ) - - # Test - brain.hparams.dwer_file = os.path.join(hparams["output_folder"], "dwer.txt") - brain.evaluate( - test_data, - min_key="TER", - test_loader_kwargs=hparams["test_dataloader_kwargs"], - ) diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/utils.py b/benchmarks/DASB/VoiceBank/enhancement/conformer/utils.py deleted file mode 120000 index 50fbc6d8f..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/utils.py +++ /dev/null @@ -1 +0,0 @@ -../utils.py \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/voicebank_prepare.py b/benchmarks/DASB/VoiceBank/enhancement/conformer/voicebank_prepare.py deleted file mode 120000 index 66cb2e6cc..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/voicebank_prepare.py +++ /dev/null @@ -1 +0,0 @@ -../../voicebank_prepare.py \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/custom_model.py b/benchmarks/DASB/VoiceBank/enhancement/crdnn/custom_model.py deleted file mode 120000 index 4b3f08ebb..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/custom_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../model/custom_model.py \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/dnsmos.py b/benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/dnsmos.py deleted file mode 120000 index 4ed89e197..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/dnsmos.py +++ /dev/null @@ -1 +0,0 @@ -../../metrics/dnsmos.py \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/dwer.py b/benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/dwer.py deleted file mode 120000 index fe0803d67..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/dwer.py +++ /dev/null @@ -1 +0,0 @@ -../../metrics/dwer.py \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/spk_sim.py b/benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/spk_sim.py deleted file mode 120000 index 961df0736..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/metrics/spk_sim.py +++ /dev/null @@ -1 +0,0 @@ -../../metrics/spk_sim.py \ No newline at end of file diff --git 
a/benchmarks/DASB/VoiceBank/enhancement/crdnn/train_continuous_ssl.py b/benchmarks/DASB/VoiceBank/enhancement/crdnn/train_continuous_ssl.py deleted file mode 100644 index a4c55687c..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/train_continuous_ssl.py +++ /dev/null @@ -1,345 +0,0 @@ -#!/usr/bin/env/python - -"""Recipe for training a transformer-based speech enhancement system using continuous SSL audio representations. - -To run this recipe: -> python train_continuous_ssl.py hparams/.yaml - -Authors - * Luca Della Libera 2024 -""" - -import os -import sys -import warnings - -import speechbrain as sb -import torch -from hyperpyyaml import load_hyperpyyaml -from speechbrain.dataio.dataio import write_audio -from speechbrain.utils.distributed import if_main_process, run_on_main - - -_CACHE = {} - - -# To use in configuration files -def len_(SSL_layers, embedding_dim): - return len(SSL_layers) * embedding_dim - - -class Enhancement(sb.Brain): - @torch.no_grad() - def sig_to_embs(self, sig, lens): - # sig: [B, T] - self.hparams.ssl_model.to(self.device).eval() - embs = self.hparams.ssl_model(sig, lens)[ - self.hparams.SSL_layers - ] # [K, B, N, H] - embs = embs.movedim(0, -2) # [B, N, K, H] - return embs - - @torch.no_grad() - def embs_to_sig(self, embs): - # embs: [B, N, K, H] - self.hparams.ssl_vocoder.device = self.device - self.hparams.ssl_vocoder.to(self.device).eval() - - # Handle missing codebooks - all_layer_ids = [1, 3, 7, 12, 18, 23] - if len(self.hparams.SSL_layers) < len(all_layer_ids): - offset_idxes = [ - all_layer_ids.index(x) for x in self.hparams.SSL_layers - ] - full_embs = torch.zeros( - *embs.shape[:2], - len(all_layer_ids), - embs.shape[-1], - dtype=embs.dtype, - device=self.device, - ) - for i, idx in enumerate(offset_idxes): - full_embs[..., idx, :] = embs[..., i, :] - embs = full_embs - - self.hparams.ssl_vocoder.tokenize = False - sig = self.hparams.ssl_vocoder(embs)[:, 0] # [B, T] - return sig - - def compute_forward(self, batch, stage): - """Forward pass.""" - batch = batch.to(self.device) - in_sig, in_lens = batch.in_sig # [B, T] - out_sig, out_lens = batch.out_sig # [B, T] - - # Augment if specified - if stage == sb.Stage.TRAIN and self.hparams.augment: - in_sig, in_lens = self.hparams.augmentation(in_sig, in_lens) - - # Extract features (cache them at first epoch if augmentation is disabled) - key = tuple(sorted(batch.id)) - try: - in_embs, out_embs = _CACHE[key] - in_embs = in_embs.to(self.device) - out_embs = out_embs.to(self.device) - except KeyError: - assert (in_lens == out_lens).all() - sig = torch.cat([in_sig, out_sig]) # [B2, T] - lens = torch.cat([in_lens, out_lens]) # [B2, T] - embs = self.sig_to_embs(sig, lens) # [B2, N, K, H] - in_embs, out_embs = embs.split( - [len(in_sig), len(out_sig)] - ) # [B, N, K, H], [B, N, K, H] - out_embs = out_embs.reshape( - len(in_sig), - -1, - self.hparams.num_codebooks, - self.hparams.embedding_dim, - ) # [B, N, K, H] - if self.hparams.use_cache and (not self.hparams.augment): - _CACHE[key] = in_embs.cpu(), out_embs.cpu() - - # Avoid in-place modification from attention - in_embs = in_embs.clone() - - # Forward attention - att_w = self.modules.attention_mlp(in_embs) # [B, N, K, 1] - in_embs = torch.matmul(att_w.transpose(2, -1), in_embs).squeeze( - -2 - ) # [B, N, H] - - # Forward encoder - hyp_embs = self.modules.encoder(in_embs) - - # Forward head - hyp_embs = self.modules.head(hyp_embs).reshape( - len(hyp_embs), - -1, - self.hparams.num_codebooks, - self.hparams.embedding_dim, - ) # [B, N, K, H] - - 
return hyp_embs, out_embs - - def compute_objectives(self, predictions, batch, stage): - """Computes the objectives.""" - hyp_embs, out_embs = predictions # [B, N, K, H], [B, N, K, H] - - IDs = batch.id - in_sig, _ = batch.in_sig - out_sig, out_lens = batch.out_sig - - # Reconstruction loss - loss = self.hparams.rec_loss( - hyp_embs.flatten(start_dim=1, end_dim=-2), # [B, NK, H] - out_embs.flatten(start_dim=1, end_dim=-2), # [B, NK, H] - length=out_lens, - ) - - # Vocode - if stage == sb.Stage.TEST and self.hparams.compute_metrics: - self.vocode(IDs, in_sig, out_sig, hyp_embs, out_embs, out_lens) - - return loss - - @torch.no_grad() - def vocode(self, IDs, in_sig, out_sig, hyp_embs, out_embs, lens): - hyp_sig = self.embs_to_sig(hyp_embs) # [B, T] - rec_sig = self.embs_to_sig(out_embs) # [B, T] - - # Adjust length - if out_sig.shape[-1] > hyp_sig.shape[-1]: - pad = [0, out_sig.shape[-1] - hyp_sig.shape[-1]] - hyp_sig = torch.nn.functional.pad( - hyp_sig, pad, mode="replicate" - ) # [B, T_out] - rec_sig = torch.nn.functional.pad( - rec_sig, pad, mode="replicate" - ) # [B, T_out] - elif out_sig.shape[-1] < hyp_sig.shape[-1]: - hyp_sig = hyp_sig.narrow(-1, 0, out_sig.shape[-1]) # [B, T_out] - rec_sig = rec_sig.narrow(-1, 0, out_sig.shape[-1]) # [B, T_out] - - self.dnsmos_metric.append(IDs, hyp_sig, lens) - self.rec_dnsmos_metric.append(IDs, rec_sig, lens) - self.ref_dnsmos_metric.append(IDs, out_sig, lens) - self.dwer_metric.append(IDs, hyp_sig, out_sig, lens) - self.wavlm_sim_metric.append(IDs, hyp_sig, out_sig, lens) - self.ecapatdnn_sim_metric.append(IDs, hyp_sig, out_sig, lens) - - if self.hparams.save_audios: - save_folder = os.path.join(self.hparams.output_folder, "audios") - os.makedirs(save_folder, exist_ok=True) - for i in range(len(IDs)): - write_audio( - os.path.join(save_folder, f"{IDs[i]}_hyp.wav"), - hyp_sig[i].cpu(), - self.hparams.sample_rate, - ) - write_audio( - os.path.join(save_folder, f"{IDs[i]}_rec.wav"), - rec_sig[i].cpu(), - self.hparams.sample_rate, - ) - write_audio( - os.path.join(save_folder, f"{IDs[i]}_ref.wav"), - out_sig[i].cpu(), - self.hparams.sample_rate, - ) - write_audio( - os.path.join(save_folder, f"{IDs[i]}_in.wav"), - in_sig[i].cpu(), - self.hparams.sample_rate, - ) - - def on_stage_start(self, stage, epoch=None): - """Gets called at the beginning of each epoch.""" - super().on_stage_start(stage, epoch) - if stage == sb.Stage.TEST and self.hparams.compute_metrics: - self.dnsmos_metric = self.hparams.dnsmos_computer() - self.rec_dnsmos_metric = self.hparams.dnsmos_computer() - self.ref_dnsmos_metric = self.hparams.dnsmos_computer() - self.dwer_metric = self.hparams.dwer_computer() - self.wavlm_sim_metric = self.hparams.wavlm_sim_computer() - self.ecapatdnn_sim_metric = self.hparams.ecapatdnn_sim_computer() - - def on_stage_end(self, stage, stage_loss, epoch=None): - """Gets called at the end of each epoch.""" - # Compute/store important stats - stage_stats = {"loss": stage_loss} - - if stage == sb.Stage.TRAIN: - self.train_stats = stage_stats - - # Perform end-of-iteration operations, like annealing, logging, etc. 
- if stage == sb.Stage.VALID: - _, lr = self.hparams.scheduler(stage_stats["loss"]) - sb.nnet.schedulers.update_learning_rate(self.optimizer, lr) - steps = self.optimizer_step - self.hparams.train_logger.log_stats( - stats_meta={"epoch": epoch, "lr": lr, "steps": steps}, - train_stats=self.train_stats, - valid_stats=stage_stats, - ) - self.checkpointer.save_and_keep_only( - meta={"loss": stage_stats["loss"]}, - min_keys=["loss"], - num_to_keep=self.hparams.keep_checkpoints, - ) - - elif stage == sb.Stage.TEST: - if self.hparams.compute_metrics: - stage_stats["DNSMOS"] = self.dnsmos_metric.summarize("average") - stage_stats["RecDNSMOS"] = self.rec_dnsmos_metric.summarize( - "average" - ) - stage_stats["RefDNSMOS"] = self.ref_dnsmos_metric.summarize( - "average" - ) - stage_stats["dWER"] = self.dwer_metric.summarize("error_rate") - stage_stats["WavLMSim"] = self.wavlm_sim_metric.summarize( - "average" - ) - stage_stats[ - "ECAPATDNNSim" - ] = self.ecapatdnn_sim_metric.summarize("average") - self.hparams.train_logger.log_stats( - stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, - test_stats=stage_stats, - ) - if if_main_process(): - # Save dWER - if self.hparams.compute_metrics: - with open(self.hparams.dwer_file, "w") as w: - self.dwer_metric.write_stats(w) - - -if __name__ == "__main__": - # Command-line interface - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Filter warnings - warnings.filterwarnings("once") - warnings.filterwarnings("ignore", module="torch") - - # If --distributed_launch then create ddp_init_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset preparation - from voicebank_prepare import prepare_voicebank as prepare_data - - prepare_data_kwargs = { - "data_folder": hparams["data_folder"], - "save_folder": hparams["save_folder"], - "splits": hparams["splits"], - "num_valid_speakers": hparams["num_valid_speakers"], - } - - # Due to DDP, do the preparation ONLY on the main Python process - run_on_main(prepare_data, kwargs=prepare_data_kwargs) - - # Create the datasets objects - from utils import dataio_prepare - - train_data, valid_data, test_data = dataio_prepare( - debug=run_opts.get("debug", False), **hparams - ) - - # Pretrain the specified modules - if "pretrainer" in hparams: - run_on_main(hparams["pretrainer"].collect_files) - run_on_main(hparams["pretrainer"].load_collected) - - # Log number of parameters/buffers - ssl_params = sum( - [x.numel() for x in hparams["ssl_model"].state_dict().values()] - + [x.numel() for x in hparams["ssl_vocoder"].state_dict().values()] - ) - model_params = sum( - [ - x.numel() - for module in hparams["modules"].values() - for x in module.state_dict().values() - ] - ) - hparams["train_logger"].log_stats( - stats_meta={ - f"SSL parameters/buffers (M)": f"{ssl_params / 1e6:.2f}", - "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", - }, - ) - - # Trainer initialization - brain = Enhancement( - modules=hparams["modules"], - opt_class=hparams["opt_class"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # Train - brain.fit( - brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_kwargs"], - 
valid_loader_kwargs=hparams["valid_dataloader_kwargs"], - ) - - # Test - brain.hparams.dwer_file = os.path.join(hparams["output_folder"], "dwer.txt") - brain.evaluate( - test_data, - min_key="loss", - test_loader_kwargs=hparams["test_dataloader_kwargs"], - ) diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/train_dac.py b/benchmarks/DASB/VoiceBank/enhancement/crdnn/train_dac.py deleted file mode 100644 index 2fe919e6a..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/train_dac.py +++ /dev/null @@ -1,158 +0,0 @@ -#!/usr/bin/env/python - -"""Recipe for training a transformer-based speech enhancement system using DAC audio representations. - -To run this recipe: -> python train_dac.py hparams/.yaml - -Authors - * Luca Della Libera 2024 -""" - -import os -import sys -import warnings - -import speechbrain as sb -import torch -from hyperpyyaml import load_hyperpyyaml -from speechbrain.utils.distributed import run_on_main - -from train_encodec import Enhancement as EnhancementEncodec - - -class Enhancement(EnhancementEncodec): - @torch.no_grad() - def sig_to_toks(self, sig, lens): - # sig: [B, T] - self.hparams.codec.to(self.device).eval() - toks, _ = self.hparams.codec( - sig[:, None], n_quantizers=self.hparams.num_codebooks - ) # [B, K, N] - toks = toks.movedim(-1, -2) # [B, N, K] - return toks - - @torch.no_grad() - def toks_to_sig(self, toks): - # toks: [B, N, K] - self.hparams.codec.to(self.device).eval() - qfeats, _, _ = self.hparams.codec.quantizer.from_codes( - toks.movedim(-1, -2) # [B, K, N] - ) - sig = self.hparams.codec.decode(qfeats)[:, 0] # [B, T] - return sig - - -if __name__ == "__main__": - # Command-line interface - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Filter warnings - warnings.filterwarnings("once") - warnings.filterwarnings("ignore", module="torch") - - # If --distributed_launch then create ddp_init_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset preparation - from voicebank_prepare import prepare_voicebank as prepare_data - - prepare_data_kwargs = { - "data_folder": hparams["data_folder"], - "save_folder": hparams["save_folder"], - "splits": hparams["splits"], - "num_valid_speakers": hparams["num_valid_speakers"], - } - - # Due to DDP, do the preparation ONLY on the main Python process - run_on_main(prepare_data, kwargs=prepare_data_kwargs) - - # Create the datasets objects - from utils import dataio_prepare - - train_data, valid_data, test_data = dataio_prepare( - debug=run_opts.get("debug", False), **hparams - ) - - # Pretrain the specified modules - if "pretrainer" in hparams: - run_on_main(hparams["pretrainer"].collect_files) - run_on_main(hparams["pretrainer"].load_collected) - - # Use pretrained embeddings - if hparams["pretrain_embedding"]: - # See https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/nn/quantize.py#L200 - toks = torch.arange(hparams["vocab_size"], device=run_opts["device"]) - toks = ( - toks[:, None, None].expand(-1, hparams["num_codebooks"], -1).clone() - ) # [C, K, 1] - hparams["codec"].to(run_opts["device"]).eval() - with torch.no_grad(): - z_q, z_p, _ = hparams["codec"].quantizer.from_codes(toks) - z_ps = z_p.split(z_p.shape[1] // 
toks.shape[1], dim=1) # [C, D, 1] * K - z_qs = [] - for i, z_p_i in enumerate(z_ps): - with torch.no_grad(): - z_q_i = ( - hparams["codec"].quantizer.quantizers[i].out_proj(z_p_i) - ) # [C, H, 1] - z_qs.append(z_q_i) - assert (z_q == sum(z_qs)).all() - # Embeddings pre-projections: size = 8 - # embs = torch.cat(z_ps)[:, :, 0] - # Embeddings post-projections: size = 1024 - embs = torch.cat(z_qs)[:, :, 0] # [CK, H] - hparams["embedding"].embedding.weight.data.copy_(embs) - - # Log number of parameters/buffers - codec_params = sum( - [x.numel() for x in hparams["codec"].state_dict().values()] - ) - model_params = sum( - [ - x.numel() - for module in hparams["modules"].values() - for x in module.state_dict().values() - ] - ) - hparams["train_logger"].log_stats( - stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", - "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", - }, - ) - - # Trainer initialization - brain = Enhancement( - modules=hparams["modules"], - opt_class=hparams["opt_class"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # Train - brain.fit( - brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_kwargs"], - valid_loader_kwargs=hparams["valid_dataloader_kwargs"], - ) - - # Test - brain.hparams.dwer_file = os.path.join(hparams["output_folder"], "dwer.txt") - brain.evaluate( - test_data, - min_key="TER", - test_loader_kwargs=hparams["test_dataloader_kwargs"], - ) diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/train_discrete_ssl.py b/benchmarks/DASB/VoiceBank/enhancement/crdnn/train_discrete_ssl.py deleted file mode 100644 index f6811b5fa..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/train_discrete_ssl.py +++ /dev/null @@ -1,185 +0,0 @@ -#!/usr/bin/env/python - -"""Recipe for training a transformer-based speech enhancement system using discrete SSL audio representations. 
- -To run this recipe: -> python train_discrete_ssl.py hparams/.yaml - -Authors - * Luca Della Libera 2024 -""" - -import os -import sys -import warnings - -import speechbrain as sb -import torch -from hyperpyyaml import load_hyperpyyaml -from speechbrain.utils.distributed import run_on_main - -from train_encodec import Enhancement as EnhancementEncodec - - -# To use in configuration files -def len_(SSL_layers, vocab_size): - return len(SSL_layers) * vocab_size - - -class Enhancement(EnhancementEncodec): - @torch.no_grad() - def sig_to_toks(self, sig, lens): - # sig: [B, T] - self.hparams.codec_quantizer.to(self.device).eval() - toks, _, _ = self.hparams.codec_quantizer( - sig, - lens, - SSL_layers=self.hparams.SSL_layers, - deduplicates=[False] * len(self.hparams.SSL_layers), - bpe_tokenizers=[None] * len(self.hparams.SSL_layers), - ) # [B, N, K] - return toks - - @torch.no_grad() - def toks_to_sig(self, toks): - # toks: [B, N, K] - self.hparams.codec_vocoder.device = self.device - self.hparams.codec_vocoder.to(self.device).eval() - - # Add offset for embedding layer - all_layer_ids = [1, 3, 7, 12, 18, 23] - offsets = torch.arange( - 0, - len(all_layer_ids) * self.hparams.vocab_size, - self.hparams.vocab_size, - device=self.device, - ) - offset_idxes = [all_layer_ids.index(x) for x in self.hparams.SSL_layers] - offsets = offsets[offset_idxes] - toks = toks + offsets + 1 - - # Handle missing codebooks - if len(self.hparams.SSL_layers) < len(all_layer_ids): - full_toks = torch.zeros( - *toks.shape[:2], - len(all_layer_ids), - dtype=toks.dtype, - device=self.device, - ) - for i, idx in enumerate(offset_idxes): - full_toks[..., idx] = toks[..., i] - toks = full_toks - - self.hparams.codec_vocoder.tokenize = False - sig = self.hparams.codec_vocoder(toks)[:, 0] # [B, T] - return sig - - -if __name__ == "__main__": - # Command-line interface - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Filter warnings - warnings.filterwarnings("once") - warnings.filterwarnings("ignore", module="torch") - - # If --distributed_launch then create ddp_init_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset preparation - from voicebank_prepare import prepare_voicebank as prepare_data - - prepare_data_kwargs = { - "data_folder": hparams["data_folder"], - "save_folder": hparams["save_folder"], - "splits": hparams["splits"], - "num_valid_speakers": hparams["num_valid_speakers"], - } - - # Due to DDP, do the preparation ONLY on the main Python process - run_on_main(prepare_data, kwargs=prepare_data_kwargs) - - # Create the datasets objects - from utils import dataio_prepare - - train_data, valid_data, test_data = dataio_prepare( - debug=run_opts.get("debug", False), **hparams - ) - - # Pretrain the specified modules - if "pretrainer" in hparams: - run_on_main(hparams["pretrainer"].collect_files) - run_on_main(hparams["pretrainer"].load_collected) - - # Use pretrained embeddings - if hparams["pretrain_embedding"]: - # See https://github.com/speechbrain/speechbrain/blob/60062c2536e8122253d6ad0e681208f554528950/speechbrain/lobes/models/huggingface_transformers/discrete_ssl.py#L197 - hparams["codec_quantizer"].to(run_opts["device"]).eval() - embs = [] - for layer_num, vocabulary in 
zip( - hparams["codec_quantizer"].ssl_layer_ids, - hparams["codec_quantizer"].vocabularies, - ): - if layer_num not in hparams["SSL_layers"]: - continue - embs.append( - torch.as_tensor( - vocabulary, dtype=torch.float32, device=run_opts["device"] - ) - ) - embs = torch.cat(embs) - hparams["embedding"].embedding.weight.data.copy_(embs) - - # Log number of parameters/buffers - codec_params = sum( - [x.numel() for x in hparams["codec_quantizer"].state_dict().values()] - + [x.numel() for x in hparams["codec_vocoder"].state_dict().values()] - ) - model_params = sum( - [ - x.numel() - for module in hparams["modules"].values() - for x in module.state_dict().values() - ] - ) - hparams["train_logger"].log_stats( - stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", - "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", - }, - ) - - # Trainer initialization - brain = Enhancement( - modules=hparams["modules"], - opt_class=hparams["opt_class"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # Train - brain.fit( - brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_kwargs"], - valid_loader_kwargs=hparams["valid_dataloader_kwargs"], - ) - - # Test - brain.hparams.dwer_file = os.path.join(hparams["output_folder"], "dwer.txt") - brain.evaluate( - test_data, - min_key="TER", - test_loader_kwargs=hparams["test_dataloader_kwargs"], - ) diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/train_speech_tokenizer.py b/benchmarks/DASB/VoiceBank/enhancement/crdnn/train_speech_tokenizer.py deleted file mode 100644 index c25d78e26..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/train_speech_tokenizer.py +++ /dev/null @@ -1,151 +0,0 @@ -#!/usr/bin/env/python - -"""Recipe for training a transformer-based speech enhancement system using SpeechTokenizer audio representations. 
- -To run this recipe: -> python train_speech_tokenizer.py hparams/.yaml - -Authors - * Luca Della Libera 2024 -""" - -import os -import sys -import warnings - -import speechbrain as sb -import torch -from hyperpyyaml import load_hyperpyyaml -from speechbrain.utils.distributed import run_on_main - -from train_encodec import Enhancement as EnhancementEncodec - - -class Enhancement(EnhancementEncodec): - @torch.no_grad() - def sig_to_toks(self, sig, lens): - # sig: [B, T] - self.hparams.codec.to(self.device).eval() - toks = self.hparams.codec(sig)[ - : self.hparams.num_codebooks - ] # [K, B, N] - toks = toks.movedim(-3, -1) # [B, N, K] - return toks - - @torch.no_grad() - def toks_to_sig(self, toks): - # toks: [B, N, K] - self.hparams.codec.to(self.device).eval() - toks = toks.movedim(-1, -3) # [K, B, N] - sig = self.hparams.codec.decode(toks) # [B, T] - return sig - - -if __name__ == "__main__": - # Command-line interface - hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) - with open(hparams_file) as fin: - hparams = load_hyperpyyaml(fin, overrides) - - # Filter warnings - warnings.filterwarnings("once") - warnings.filterwarnings("ignore", module="torch") - - # If --distributed_launch then create ddp_init_group with the right communication protocol - sb.utils.distributed.ddp_init_group(run_opts) - - # Create experiment directory - sb.create_experiment_directory( - experiment_directory=hparams["output_folder"], - hyperparams_to_save=hparams_file, - overrides=overrides, - ) - - # Dataset preparation - from voicebank_prepare import prepare_voicebank as prepare_data - - prepare_data_kwargs = { - "data_folder": hparams["data_folder"], - "save_folder": hparams["save_folder"], - "splits": hparams["splits"], - "num_valid_speakers": hparams["num_valid_speakers"], - } - - # Due to DDP, do the preparation ONLY on the main Python process - run_on_main(prepare_data, kwargs=prepare_data_kwargs) - - # Create the datasets objects - from utils import dataio_prepare - - train_data, valid_data, test_data = dataio_prepare( - debug=run_opts.get("debug", False), **hparams - ) - - # Pretrain the specified modules - if "pretrainer" in hparams: - run_on_main(hparams["pretrainer"].collect_files) - run_on_main(hparams["pretrainer"].load_collected) - - # Use pretrained embeddings - if hparams["pretrain_embedding"]: - # See https://github.com/ZhangXInFD/SpeechTokenizer/blob/a9f88dc72642b600654a62861e34342babae6c71/speechtokenizer/quantization/core_vq.py#L360 - toks = torch.arange(hparams["vocab_size"], device=run_opts["device"]) - toks = ( - toks[None, :, None].expand(hparams["num_codebooks"], -1, -1).clone() - ) # [K, C, 1] - hparams["codec"].to(run_opts["device"]).eval() - embs = [] - for i, indices in enumerate(toks): - layer = hparams["codec"].model.quantizer.vq.layers[i] - with torch.no_grad(): - quantized = layer.decode(indices) # [C, H, 1] - embs.append(quantized) - assert ( - hparams["codec"].model.quantizer.decode(toks) == sum(embs) - ).all() - embs = torch.cat(embs)[:, :, 0] # [CK, H] - hparams["embedding"].embedding.weight.data.copy_(embs) - - # Log number of parameters/buffers - codec_params = sum( - [x.numel() for x in hparams["codec"].state_dict().values()] - ) - model_params = sum( - [ - x.numel() - for module in hparams["modules"].values() - for x in module.state_dict().values() - ] - ) - hparams["train_logger"].log_stats( - stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", - "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", - }, - ) - - # 
Trainer initialization - brain = Enhancement( - modules=hparams["modules"], - opt_class=hparams["opt_class"], - hparams=hparams, - run_opts=run_opts, - checkpointer=hparams["checkpointer"], - ) - - # Train - brain.fit( - brain.hparams.epoch_counter, - train_data, - valid_data, - train_loader_kwargs=hparams["train_dataloader_kwargs"], - valid_loader_kwargs=hparams["valid_dataloader_kwargs"], - ) - - # Test - brain.hparams.dwer_file = os.path.join(hparams["output_folder"], "dwer.txt") - brain.evaluate( - test_data, - min_key="TER", - test_loader_kwargs=hparams["test_dataloader_kwargs"], - ) diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/utils.py b/benchmarks/DASB/VoiceBank/enhancement/crdnn/utils.py deleted file mode 120000 index 50fbc6d8f..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/utils.py +++ /dev/null @@ -1 +0,0 @@ -../utils.py \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/voicebank_prepare.py b/benchmarks/DASB/VoiceBank/enhancement/crdnn/voicebank_prepare.py deleted file mode 120000 index 66cb2e6cc..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/voicebank_prepare.py +++ /dev/null @@ -1 +0,0 @@ -../../voicebank_prepare.py \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/custom_model.py b/benchmarks/DASB/VoiceBank/enhancement/custom_model.py new file mode 120000 index 000000000..6cb6f37ba --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/custom_model.py @@ -0,0 +1 @@ +../../model/custom_model.py \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_discrete_ssl.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train.yaml similarity index 54% rename from benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_discrete_ssl.yaml rename to benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train.yaml index 940b65d9f..d7b663ad6 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_discrete_ssl.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train.yaml @@ -1,9 +1,9 @@ # ########################################################################################### -# Model: CRDNN with discrete SSL audio representations +# Model: CRDNN with discrete audio representations # Authors: Luca Della Libera 2024 # ########################################################################################### -experiment_name: discrete_ssl +run_name: encodec # Seed needs to be set at top of YAML seed: 0 @@ -11,19 +11,22 @@ __set_seed: !apply:torch.manual_seed [!ref ] # Data preparation data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache train_csv: !ref /trainset_28spk_wav.csv valid_csv: !ref /validset_wav.csv test_csv: !ref /testset_wav.csv splits: [trainset_28spk_wav, validset_wav, testset_wav] num_valid_speakers: 2 +testing: True # Output folders -output_folder: !ref results// +output_folder: !ref results/CRDNN// save_folder: !ref /save cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE # Save options -compute_metrics: True +compute_metrics: False save_audios: False # Preprocessing parameters @@ -31,125 +34,57 @@ train_remove_if_longer: 60.0 # Seconds valid_remove_if_longer: 60.0 # Seconds test_remove_if_longer: 60.0 # Seconds sorting: random -use_cache: True # Training parameters num_epochs: 50 -grad_accumulation_factor: 16 -train_batch_size: 1 +grad_accumulation_factor: 1 +train_batch_size: 16 valid_batch_size: 1 test_batch_size: 1 
dataloader_workers: 4 nonfinite_patience: 10 -max_grad_norm: 5.0 -precision: fp32 +max_grad_norm: 0.001 +precision: fp16 ckpt_interval_minutes: 6000 keep_checkpoints: 1 -augment: False -augment_prob: 0.75 # Optimizer parameters -lr: 0.0005 +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.0005)" weight_decay: 0.01 improvement_threshold: 0.0025 annealing_factor: 0.9 patient: 1 -# Discrete SSL parameters -sample_rate: 16000 -num_clusters: 1000 -vocab_size: !ref -SSL_layers: [1, 3, 7, 12, 18, 23] -num_codebooks: !apply:len [!ref ] -ssl_hub: facebook/hubert-large-ll60k -vocoder_hub: speechbrain/hifigan-hubert-l1-3-7-12-18-23-k1000-LibriTTS # Must be consistent with ssl_hub/SSL_layers/num_clusters -kmeans_repo_id: speechbrain/SSL_Quantization -kmeans_dataset: LibriSpeech-100-360-500 -ssl_model_type: hubert +# Codec parameters +codec_type: encodec +sample_rate: 24000 # Should match the tokenizer's sample rate +vocab_size: 1024 +num_codebooks: 2 +bandwidth: 24.0 +kmeans_dataset: LibriSpeech +SSL_layers: [1, 3] #[1, 3, 7, 12, 18, 23] +tokenizer_save_path: /home/luca/Downloads/SQ-Codec/sqcodec # Embedding parameters embedding_dim: 1024 -pretrain_embedding: False # If True, must match the codec's embedding size (1024) freeze_embedding: False # Encoder parameters dropout: 0.1 activation: !name:torch.nn.LeakyReLU rnn_class: !name:speechbrain.nnet.RNN.LSTM -rnn_layers: 4 +rnn_layers: 4 # @orion_step1: --rnn_layers~"uniform(1, 4,discrete=True)" time_pooling_size: 1 rnn_bidirectional: True rnn_neurons: 256 dnn_blocks: 2 dnn_neurons: 256 cnn_blocks: 2 -cnn_channels: (16, 16) +cnn_channels: (12, 12) inter_layer_pooling_size: (2, 2) cnn_kernelsize: (3, 3) -# Augmentation -drop_freq: !new:speechbrain.augment.time_domain.DropFreq - drop_freq_low: 0 # Min frequency band dropout probability - drop_freq_high: 1 # Max frequency band dropout probability - drop_freq_count_low: 1 # Min number of frequency bands to drop - drop_freq_count_high: 3 # Max number of frequency bands to drop - drop_freq_width: 0.05 # Width of frequency bands to drop - -drop_chunk: !new:speechbrain.augment.time_domain.DropChunk - drop_length_low: 1 # Min number of audio chunks to drop - drop_length_high: 5 # Max number of audio chunks to drop - drop_count_low: 1000 # Min length of audio chunks to drop - drop_count_high: 2000 # Max length of audio chunks to drop - -augmentation: !new:speechbrain.augment.augmenter.Augmenter - parallel_augment: False - concat_original: False - repeat_augment: 1 - shuffle_augmentations: False - min_augmentations: 2 - max_augmentations: 2 - augment_prob: !ref - augmentations: [!ref , !ref ] - # Modules -ssl_model: !new:utils.SBWav2Vec2ForwardWrapper - wav2vec2: !apply:speechbrain.utils.hparams.choice - value: !ref - choices: - wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM - source: !ref - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref - hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT - source: !ref - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref - wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 - source: !ref - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref - layer_ids: !ref - -codec_quantizer: !new:speechbrain.lobes.models.huggingface_transformers.discrete_ssl.DiscreteSSL - save_path: !ref - ssl_model: !ref - kmeans_dataset: 
!ref - kmeans_repo_id: !ref - num_clusters: !ref - -codec_vocoder: !apply:speechbrain.inference.vocoders.UnitHIFIGAN.from_hparams - source: !ref - savedir: !apply:os.path.join [!ref , !ref ] - embedding: !new:custom_model.Discrete_EmbeddingLayer num_codebooks: !ref vocab_size: !ref @@ -177,13 +112,12 @@ encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN rnn_bidirectional: !ref dnn_blocks: !ref dnn_neurons: !ref - rnn_re_init: True + rnn_re_init: False use_rnnp: False head: !new:torch.nn.Linear in_features: !ref - # Workaround to bypass HyperPyYAML lack of flexibility - out_features: !apply:train_discrete_ssl.len_ [!ref , !ref ] + out_features: !ref * modules: embedding: !ref @@ -217,6 +151,88 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler annealing_factor: !ref patient: !ref +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !apply:speechbrain.utils.hparams.choice + value: !ref + choices: + dac: !new:utils.tokenizer_interface.DACTokenizer + model_type: 24khz + model_bitrate: 8kbps + load_pretrained: True + tag: latest + encodec: !new:utils.tokenizer_interface.EncodecTokenizer + source: facebook/encodec_24khz + save_path: !ref + sample_rate: !ref + bandwidth: !ref + flat_embeddings: False + freeze: True + renorm_embeddings: False + mimi: !new:utils.tokenizer_interface.MimiTokenizer + source: kyutai/mimi + save_path: !ref + num_codebooks: !ref + sample_rate: !ref + speech_tokenizer: !new:utils.tokenizer_interface.SpeechTokenizerWrapper + source: fnlp/SpeechTokenizer + save_path: !ref + sample_rate: !ref + sqcodec: !new:utils.tokenizer_interface.SQCodecTokenizer + save_path: !ref + checkpoint: ckpt_00190000.pth + config: config.yaml + sample_rate: !ref + wavtokenizer: !new:utils.tokenizer_interface.WavTokenizerWrapper + source: novateur/WavTokenizer-medium-music-audio-75token + save_path: !ref + checkpoint: wavtokenizer_medium_music_audio_320_24k_v2.ckpt + config: wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml + sample_rate: !ref + freeze: True + wavlm: !new:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM + source: microsoft/wavlm-large + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-wavlm-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + hubert: !new:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT + source: facebook/hubert-large-ll60k + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-hubert-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + wav2vec2: !new:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 + source: facebook/wav2vec2-large + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-wav2vec2-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + # Dataloaders train_dataloader_kwargs: batch_size: !ref @@ -252,11 +268,6 @@ wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM 
save_path: !ref sample_rate: !ref -ecapatdnn_sim_computer: !name:metrics.spk_sim.SpkSimECAPATDNN - model_hub: speechbrain/spkrec-ecapa-voxceleb - save_path: !apply:os.path.join [!ref , models--speechbrain--spkrec-ecapa-voxceleb] - sample_rate: !ref - # Counters, checkpointers, loggers, etc. epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter limit: !ref diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_continuous_wavlm.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_continuous_wavlm.yaml similarity index 77% rename from benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_continuous_wavlm.yaml rename to benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_continuous_wavlm.yaml index 6068ac56d..a16be8b9d 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_continuous_wavlm.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_continuous_wavlm.yaml @@ -3,7 +3,7 @@ # Authors: Luca Della Libera 2024 # ########################################################################################### -experiment_name: continuous_wavlm +run_name: continuous_wavlm_crdnn # Seed needs to be set at top of YAML seed: 0 @@ -18,12 +18,12 @@ splits: [trainset_28spk_wav, validset_wav, testset_wav] num_valid_speakers: 2 # Output folders -output_folder: !ref results// +output_folder: !ref results/CRDNN// save_folder: !ref /save cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE # Save options -compute_metrics: True +compute_metrics: False save_audios: False # Preprocessing parameters @@ -35,18 +35,16 @@ use_cache: False # Training parameters num_epochs: 50 -grad_accumulation_factor: 16 -train_batch_size: 1 +grad_accumulation_factor: 1 +train_batch_size: 16 valid_batch_size: 1 test_batch_size: 1 dataloader_workers: 4 nonfinite_patience: 10 -max_grad_norm: 5.0 -precision: fp32 +max_grad_norm: 0.001 +precision: fp16 ckpt_interval_minutes: 6000 keep_checkpoints: 1 -augment: False -augment_prob: 0.75 # Optimizer parameters lr: 0.0005 @@ -74,36 +72,12 @@ rnn_neurons: 256 dnn_blocks: 2 dnn_neurons: 256 cnn_blocks: 2 -cnn_channels: (16, 16) +cnn_channels: (12, 12) inter_layer_pooling_size: (2, 2) cnn_kernelsize: (3, 3) -# Augmentation -drop_freq: !new:speechbrain.augment.time_domain.DropFreq - drop_freq_low: 0 # Min frequency band dropout probability - drop_freq_high: 1 # Max frequency band dropout probability - drop_freq_count_low: 1 # Min number of frequency bands to drop - drop_freq_count_high: 3 # Max number of frequency bands to drop - drop_freq_width: 0.05 # Width of frequency bands to drop - -drop_chunk: !new:speechbrain.augment.time_domain.DropChunk - drop_length_low: 1 # Min number of audio chunks to drop - drop_length_high: 5 # Max number of audio chunks to drop - drop_count_low: 1000 # Min length of audio chunks to drop - drop_count_high: 2000 # Max length of audio chunks to drop - -augmentation: !new:speechbrain.augment.augmenter.Augmenter - parallel_augment: False - concat_original: False - repeat_augment: 1 - shuffle_augmentations: False - min_augmentations: 2 - max_augmentations: 2 - augment_prob: !ref - augmentations: [!ref , !ref ] - # Modules -ssl_model: !new:utils.SBWav2Vec2ForwardWrapper +ssl_model: !new:common.SBWav2Vec2ForwardWrapper wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM source: !ref output_norm: False @@ -138,7 +112,7 @@ encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN rnn_bidirectional: !ref dnn_blocks: !ref dnn_neurons: !ref - rnn_re_init: True + 
rnn_re_init: False use_rnnp: False head: !new:torch.nn.Linear @@ -206,11 +180,6 @@ wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM save_path: !ref sample_rate: !ref -ecapatdnn_sim_computer: !name:metrics.spk_sim.SpkSimECAPATDNN - model_hub: speechbrain/spkrec-ecapa-voxceleb - save_path: !apply:os.path.join [!ref , models--speechbrain--spkrec-ecapa-voxceleb] - sample_rate: !ref - # Counters, checkpointers, loggers, etc. epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter limit: !ref diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_dac.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_dac.yaml new file mode 100644 index 000000000..8d3016ef1 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_dac.yaml @@ -0,0 +1,210 @@ +# ########################################################################################### +# Model: CRDNN with discrete audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +run_name: dac_crdnn + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] + +# Data preparation +data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +testing: True + +# Output folders +output_folder: !ref results/CRDNN// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: False +save_audios: False + +# Preprocessing parameters +train_remove_if_longer: 60.0 # Seconds +valid_remove_if_longer: 60.0 # Seconds +test_remove_if_longer: 60.0 # Seconds +sorting: random + +# Training parameters +num_epochs: 50 +grad_accumulation_factor: 1 +train_batch_size: 16 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 4 +nonfinite_patience: 10 +max_grad_norm: 0.001 +precision: fp16 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 + +# Optimizer parameters +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.0005)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# Codec parameters +sample_rate: 24000 # Should match the tokenizer's sample rate +vocab_size: 1024 +num_codebooks: 2 + +# Embedding parameters +embedding_dim: 1024 +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.LeakyReLU +rnn_class: !name:speechbrain.nnet.RNN.LSTM +rnn_layers: 4 # @orion_step1: --rnn_layers~"uniform(1, 4,discrete=True)" +time_pooling_size: 1 +rnn_bidirectional: True +rnn_neurons: 256 +dnn_blocks: 2 +dnn_neurons: 256 +cnn_blocks: 2 +cnn_channels: (12, 12) +inter_layer_pooling_size: (2, 2) +cnn_kernelsize: (3, 3) + +# Modules +embedding: !new:custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN + input_shape: [null, null, !ref ] + activation: !ref + dropout: !ref + cnn_blocks: !ref + cnn_channels: !ref + cnn_kernelsize: !ref + inter_layer_pooling_size: !ref + time_pooling: True + using_2d_pooling: False + time_pooling_size: !ref + rnn_class: !ref + rnn_layers: !ref + rnn_neurons: !ref + rnn_bidirectional: !ref + dnn_blocks: !ref 
+ dnn_neurons: !ref + rnn_re_init: False + use_rnnp: False + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.DACTokenizer + model_type: 24khz + model_bitrate: 8kbps + load_pretrained: True + tag: latest + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +dnsmos_computer: !name:metrics.dnsmos.DNSMOS + sample_rate: !ref + +dwer_computer: !name:metrics.dwer.DWER + model_hub: openai/whisper-small + save_path: !ref + sample_rate: !ref + +wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM + model_hub: microsoft/wavlm-base-sv + save_path: !ref + sample_rate: !ref + +# Counters, checkpointers, loggers, etc. 
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_speech_tokenizer.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_encodec.yaml similarity index 72% rename from benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_speech_tokenizer.yaml rename to benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_encodec.yaml index b98ce84c9..8d07d34c0 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_speech_tokenizer.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_encodec.yaml @@ -1,9 +1,9 @@ # ########################################################################################### -# Model: CRDNN with SpeechTokenizer audio representations +# Model: CRDNN with discrete audio representations # Authors: Luca Della Libera 2024 # ########################################################################################### -experiment_name: speech_tokenizer +run_name: encodec_crdnn # Seed needs to be set at top of YAML seed: 0 @@ -11,19 +11,22 @@ __set_seed: !apply:torch.manual_seed [!ref ] # Data preparation data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache train_csv: !ref /trainset_28spk_wav.csv valid_csv: !ref /validset_wav.csv test_csv: !ref /testset_wav.csv splits: [trainset_28spk_wav, validset_wav, testset_wav] num_valid_speakers: 2 +testing: True # Output folders -output_folder: !ref results// +output_folder: !ref results/CRDNN// save_folder: !ref /save cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE # Save options -compute_metrics: True +compute_metrics: False save_audios: False # Preprocessing parameters @@ -31,84 +34,53 @@ train_remove_if_longer: 60.0 # Seconds valid_remove_if_longer: 60.0 # Seconds test_remove_if_longer: 60.0 # Seconds sorting: random -use_cache: True # Training parameters num_epochs: 50 -grad_accumulation_factor: 16 -train_batch_size: 1 +grad_accumulation_factor: 1 +train_batch_size: 16 valid_batch_size: 1 test_batch_size: 1 dataloader_workers: 4 nonfinite_patience: 10 -max_grad_norm: 5.0 -precision: fp32 +max_grad_norm: 0.001 +precision: fp16 ckpt_interval_minutes: 6000 keep_checkpoints: 1 -augment: False -augment_prob: 0.75 # Optimizer parameters -lr: 0.0005 +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.0005)" weight_decay: 0.01 improvement_threshold: 0.0025 annealing_factor: 0.9 patient: 1 -# SpeechTokenizer parameters -sample_rate: 16000 +# Codec parameters +sample_rate: 24000 # Should match the tokenizer's sample rate vocab_size: 1024 -num_codebooks: 2 # Must be <= 8 +num_codebooks: 2 +bandwidth: 1.5 # Embedding parameters embedding_dim: 1024 -pretrain_embedding: False # If True, must match the codec's embedding size (1024) freeze_embedding: False # Encoder parameters dropout: 0.1 activation: !name:torch.nn.LeakyReLU rnn_class: !name:speechbrain.nnet.RNN.LSTM -rnn_layers: 4 +rnn_layers: 4 # @orion_step1: --rnn_layers~"uniform(1, 4,discrete=True)" time_pooling_size: 1 rnn_bidirectional: True rnn_neurons: 256 dnn_blocks: 2 dnn_neurons: 256 cnn_blocks: 2 -cnn_channels: (16, 16) +cnn_channels: (12, 12) inter_layer_pooling_size: (2, 2) cnn_kernelsize: (3, 3) 
-# Augmentation -drop_freq: !new:speechbrain.augment.time_domain.DropFreq - drop_freq_low: 0 # Min frequency band dropout probability - drop_freq_high: 1 # Max frequency band dropout probability - drop_freq_count_low: 1 # Min number of frequency bands to drop - drop_freq_count_high: 3 # Max number of frequency bands to drop - drop_freq_width: 0.05 # Width of frequency bands to drop - -drop_chunk: !new:speechbrain.augment.time_domain.DropChunk - drop_length_low: 1 # Min number of audio chunks to drop - drop_length_high: 5 # Max number of audio chunks to drop - drop_count_low: 1000 # Min length of audio chunks to drop - drop_count_high: 2000 # Max length of audio chunks to drop - -augmentation: !new:speechbrain.augment.augmenter.Augmenter - parallel_augment: False - concat_original: False - repeat_augment: 1 - shuffle_augmentations: False - min_augmentations: 2 - max_augmentations: 2 - augment_prob: !ref - augmentations: [!ref , !ref ] - # Modules -codec: !new:speechbrain.lobes.models.discrete.speechtokenizer_interface.SpeechTokenizer_interface - source: fnlp/SpeechTokenizer - save_path: !ref - embedding: !new:custom_model.Discrete_EmbeddingLayer num_codebooks: !ref vocab_size: !ref @@ -136,7 +108,7 @@ encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN rnn_bidirectional: !ref dnn_blocks: !ref dnn_neurons: !ref - rnn_re_init: True + rnn_re_init: False use_rnnp: False head: !new:torch.nn.Linear @@ -175,6 +147,23 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler annealing_factor: !ref patient: !ref +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.EncodecTokenizer + source: facebook/encodec_24khz + save_path: !ref + sample_rate: !ref + bandwidth: !ref + flat_embeddings: False + freeze: True + renorm_embeddings: False + # Dataloaders train_dataloader_kwargs: batch_size: !ref @@ -210,11 +199,6 @@ wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM save_path: !ref sample_rate: !ref -ecapatdnn_sim_computer: !name:metrics.spk_sim.SpkSimECAPATDNN - model_hub: speechbrain/spkrec-ecapa-voxceleb - save_path: !apply:os.path.join [!ref , models--speechbrain--spkrec-ecapa-voxceleb] - sample_rate: !ref - # Counters, checkpointers, loggers, etc. 
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter limit: !ref diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_hubert.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_hubert.yaml new file mode 100644 index 000000000..d3c84071e --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_hubert.yaml @@ -0,0 +1,219 @@ +# ########################################################################################### +# Model: CRDNN with discrete audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +run_name: hubert_crdnn + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] + +# Data preparation +data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +testing: True + +# Output folders +output_folder: !ref results/CRDNN// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: False +save_audios: False + +# Preprocessing parameters +train_remove_if_longer: 60.0 # Seconds +valid_remove_if_longer: 60.0 # Seconds +test_remove_if_longer: 60.0 # Seconds +sorting: random + +# Training parameters +num_epochs: 50 +grad_accumulation_factor: 1 +train_batch_size: 16 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 4 +nonfinite_patience: 10 +max_grad_norm: 0.001 +precision: fp16 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 + +# Optimizer parameters +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.0005)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# Codec parameters +sample_rate: 16000 # Should match the tokenizer's sample rate +vocab_size: 1000 +num_codebooks: 2 +kmeans_dataset: LibriSpeech +SSL_layers: [1, 3] #[1, 3, 7, 12, 18, 23] + +# Embedding parameters +embedding_dim: 1024 +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.LeakyReLU +rnn_class: !name:speechbrain.nnet.RNN.LSTM +rnn_layers: 4 # @orion_step1: --rnn_layers~"uniform(1, 4,discrete=True)" +time_pooling_size: 1 +rnn_bidirectional: True +rnn_neurons: 256 +dnn_blocks: 2 +dnn_neurons: 256 +cnn_blocks: 2 +cnn_channels: (12, 12) +inter_layer_pooling_size: (2, 2) +cnn_kernelsize: (3, 3) + +# Modules +embedding: !new:custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN + input_shape: [null, null, !ref ] + activation: !ref + dropout: !ref + cnn_blocks: !ref + cnn_channels: !ref + cnn_kernelsize: !ref + inter_layer_pooling_size: !ref + time_pooling: True + using_2d_pooling: False + time_pooling_size: !ref + rnn_class: !ref + rnn_layers: !ref + rnn_neurons: !ref + rnn_bidirectional: !ref + dnn_blocks: !ref + dnn_neurons: !ref + rnn_re_init: False + use_rnnp: False + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: 
!name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT + source: facebook/hubert-large-ll60k + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-hubert-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +dnsmos_computer: !name:metrics.dnsmos.DNSMOS + sample_rate: !ref + +dwer_computer: !name:metrics.dwer.DWER + model_hub: openai/whisper-small + save_path: !ref + sample_rate: !ref + +wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM + model_hub: microsoft/wavlm-base-sv + save_path: !ref + sample_rate: !ref + +# Counters, checkpointers, loggers, etc. 
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_mimi.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_mimi.yaml new file mode 100644 index 000000000..63fe4513c --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_mimi.yaml @@ -0,0 +1,210 @@ +# ########################################################################################### +# Model: CRDNN with discrete audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +run_name: mimi_crdnn + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] + +# Data preparation +data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +testing: True + +# Output folders +output_folder: !ref results/CRDNN// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: False +save_audios: False + +# Preprocessing parameters +train_remove_if_longer: 60.0 # Seconds +valid_remove_if_longer: 60.0 # Seconds +test_remove_if_longer: 60.0 # Seconds +sorting: random + +# Training parameters +num_epochs: 50 +grad_accumulation_factor: 1 +train_batch_size: 16 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 4 +nonfinite_patience: 10 +max_grad_norm: 0.001 +precision: fp16 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 + +# Optimizer parameters +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.0005)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# Codec parameters +sample_rate: 24000 # Should match the tokenizer's sample rate +vocab_size: 2048 +num_codebooks: 2 + +# Embedding parameters +embedding_dim: 1024 +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.LeakyReLU +rnn_class: !name:speechbrain.nnet.RNN.LSTM +rnn_layers: 4 # @orion_step1: --rnn_layers~"uniform(1, 4,discrete=True)" +time_pooling_size: 1 +rnn_bidirectional: True +rnn_neurons: 256 +dnn_blocks: 2 +dnn_neurons: 256 +cnn_blocks: 2 +cnn_channels: (12, 12) +inter_layer_pooling_size: (2, 2) +cnn_kernelsize: (3, 3) + +# Modules +embedding: !new:custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN + input_shape: [null, null, !ref ] + activation: !ref + dropout: !ref + cnn_blocks: !ref + cnn_channels: !ref + cnn_kernelsize: !ref + inter_layer_pooling_size: !ref + time_pooling: True + using_2d_pooling: False + time_pooling_size: !ref + rnn_class: !ref + rnn_layers: !ref + rnn_neurons: !ref + rnn_bidirectional: !ref + dnn_blocks: !ref + dnn_neurons: !ref + rnn_re_init: False + use_rnnp: False + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * + +modules: + embedding: !ref + 
attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.MimiTokenizer + source: kyutai/mimi + save_path: !ref + num_codebooks: !ref + sample_rate: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +dnsmos_computer: !name:metrics.dnsmos.DNSMOS + sample_rate: !ref + +dwer_computer: !name:metrics.dwer.DWER + model_hub: openai/whisper-small + save_path: !ref + sample_rate: !ref + +wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM + model_hub: microsoft/wavlm-base-sv + save_path: !ref + sample_rate: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_dac.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_speech_tokenizer.yaml similarity index 68% rename from benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_dac.yaml rename to benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_speech_tokenizer.yaml index 2bbf54de6..8a95811d8 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_dac.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_speech_tokenizer.yaml @@ -1,9 +1,9 @@ # ########################################################################################### -# Model: CRDNN with DAC audio representations +# Model: CRDNN with discrete audio representations # Authors: Luca Della Libera 2024 # ########################################################################################### -experiment_name: dac +run_name: speech_tokenizer_crdnn # Seed needs to be set at top of YAML seed: 0 @@ -11,19 +11,22 @@ __set_seed: !apply:torch.manual_seed [!ref ] # Data preparation data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache train_csv: !ref /trainset_28spk_wav.csv valid_csv: !ref /validset_wav.csv test_csv: !ref /testset_wav.csv splits: [trainset_28spk_wav, validset_wav, testset_wav] num_valid_speakers: 2 +testing: True # Output folders -output_folder: !ref results// +output_folder: !ref results/CRDNN// save_folder: !ref /save 
cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE # Save options -compute_metrics: True +compute_metrics: False save_audios: False # Preprocessing parameters @@ -31,93 +34,52 @@ train_remove_if_longer: 60.0 # Seconds valid_remove_if_longer: 60.0 # Seconds test_remove_if_longer: 60.0 # Seconds sorting: random -use_cache: True # Training parameters num_epochs: 50 -grad_accumulation_factor: 16 -train_batch_size: 1 +grad_accumulation_factor: 1 +train_batch_size: 16 valid_batch_size: 1 test_batch_size: 1 dataloader_workers: 4 nonfinite_patience: 10 -max_grad_norm: 5.0 -precision: fp32 +max_grad_norm: 0.001 +precision: fp16 ckpt_interval_minutes: 6000 keep_checkpoints: 1 -augment: False -augment_prob: 0.75 # Optimizer parameters -lr: 0.0005 +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.0005)" weight_decay: 0.01 improvement_threshold: 0.0025 annealing_factor: 0.9 patient: 1 -# DAC parameters -# sample_rate: [16000, 24000, 44000, 44000] -# vocab_size: [1024, 1024, 1024, 1024] -# max_num_codebooks: [12, 32, 9, 18] -# model_type: [16khz, 24khz, 44khz, 44khz] -# model_bitrate: [8kbps, 8kbps, 8kbps, 16kbps] -sample_rate: 24000 # NOTE: must match DAC's model type +# Codec parameters +sample_rate: 16000 # Should match the tokenizer's sample rate vocab_size: 1024 -num_codebooks: 2 # NOTE: must be smaller or equal to the maximum number of codebooks for the given model type -model_type: 24khz -model_bitrate: 8kbps +num_codebooks: 2 # Embedding parameters embedding_dim: 1024 -pretrain_embedding: False # If True, must match the codec's embedding size (1024) freeze_embedding: False # Encoder parameters dropout: 0.1 activation: !name:torch.nn.LeakyReLU rnn_class: !name:speechbrain.nnet.RNN.LSTM -rnn_layers: 4 +rnn_layers: 4 # @orion_step1: --rnn_layers~"uniform(1, 4,discrete=True)" time_pooling_size: 1 rnn_bidirectional: True rnn_neurons: 256 dnn_blocks: 2 dnn_neurons: 256 cnn_blocks: 2 -cnn_channels: (16, 16) +cnn_channels: (12, 12) inter_layer_pooling_size: (2, 2) cnn_kernelsize: (3, 3) -# Augmentation -drop_freq: !new:speechbrain.augment.time_domain.DropFreq - drop_freq_low: 0 # Min frequency band dropout probability - drop_freq_high: 1 # Max frequency band dropout probability - drop_freq_count_low: 1 # Min number of frequency bands to drop - drop_freq_count_high: 3 # Max number of frequency bands to drop - drop_freq_width: 0.05 # Width of frequency bands to drop - -drop_chunk: !new:speechbrain.augment.time_domain.DropChunk - drop_length_low: 1 # Min number of audio chunks to drop - drop_length_high: 5 # Max number of audio chunks to drop - drop_count_low: 1000 # Min length of audio chunks to drop - drop_count_high: 2000 # Max length of audio chunks to drop - -augmentation: !new:speechbrain.augment.augmenter.Augmenter - parallel_augment: False - concat_original: False - repeat_augment: 1 - shuffle_augmentations: False - min_augmentations: 2 - max_augmentations: 2 - augment_prob: !ref - augmentations: [!ref , !ref ] - # Modules -codec: !new:speechbrain.lobes.models.discrete.dac.DAC - model_type: !ref - model_bitrate: !ref - load_pretrained: True - tag: latest - embedding: !new:custom_model.Discrete_EmbeddingLayer num_codebooks: !ref vocab_size: !ref @@ -145,7 +107,7 @@ encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN rnn_bidirectional: !ref dnn_blocks: !ref dnn_neurons: !ref - rnn_re_init: True + rnn_re_init: False use_rnnp: False head: !new:torch.nn.Linear @@ -184,6 +146,19 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler annealing_factor: !ref patient: !ref +# 
Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.SpeechTokenizerWrapper + source: fnlp/SpeechTokenizer + save_path: !ref + sample_rate: !ref + # Dataloaders train_dataloader_kwargs: batch_size: !ref @@ -219,11 +194,6 @@ wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM save_path: !ref sample_rate: !ref -ecapatdnn_sim_computer: !name:metrics.spk_sim.SpkSimECAPATDNN - model_hub: speechbrain/spkrec-ecapa-voxceleb - save_path: !apply:os.path.join [!ref , models--speechbrain--spkrec-ecapa-voxceleb] - sample_rate: !ref - # Counters, checkpointers, loggers, etc. epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter limit: !ref @@ -236,4 +206,4 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer counter: !ref train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref /train_log.txt + save_file: !ref /train_log.txt \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_sqcodec.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_sqcodec.yaml new file mode 100644 index 000000000..58c4adb17 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_sqcodec.yaml @@ -0,0 +1,211 @@ +# ########################################################################################### +# Model: CRDNN with discrete audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +run_name: sqcodec_crdnn + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] + +# Data preparation +data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +testing: True + +# Output folders +output_folder: !ref results/CRDNN// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: False +save_audios: False + +# Preprocessing parameters +train_remove_if_longer: 60.0 # Seconds +valid_remove_if_longer: 60.0 # Seconds +test_remove_if_longer: 60.0 # Seconds +sorting: random + +# Training parameters +num_epochs: 50 +grad_accumulation_factor: 1 +train_batch_size: 16 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 4 +nonfinite_patience: 10 +max_grad_norm: 0.001 +precision: fp16 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 + +# Optimizer parameters +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.0005)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# Codec parameters +sample_rate: 16000 # Should match the tokenizer's sample rate +vocab_size: 19683 +num_codebooks: 4 +tokenizer_save_path: !PLACEHOLDER + +# Embedding parameters +embedding_dim: 1024 +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.LeakyReLU +rnn_class: !name:speechbrain.nnet.RNN.LSTM +rnn_layers: 4 # @orion_step1: --rnn_layers~"uniform(1, 4,discrete=True)" +time_pooling_size: 1 +rnn_bidirectional: True +rnn_neurons: 256 +dnn_blocks: 2 +dnn_neurons: 256 +cnn_blocks: 2 +cnn_channels: (12, 12) +inter_layer_pooling_size: (2, 2) +cnn_kernelsize: (3, 3) + +# 
Modules +embedding: !new:custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN + input_shape: [null, null, !ref ] + activation: !ref + dropout: !ref + cnn_blocks: !ref + cnn_channels: !ref + cnn_kernelsize: !ref + inter_layer_pooling_size: !ref + time_pooling: True + using_2d_pooling: False + time_pooling_size: !ref + rnn_class: !ref + rnn_layers: !ref + rnn_neurons: !ref + rnn_bidirectional: !ref + dnn_blocks: !ref + dnn_neurons: !ref + rnn_re_init: False + use_rnnp: False + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.SQCodecTokenizer + save_path: !ref + checkpoint: ckpt_00190000.pth + config: config.yaml + sample_rate: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +dnsmos_computer: !name:metrics.dnsmos.DNSMOS + sample_rate: !ref + +dwer_computer: !name:metrics.dwer.DWER + model_hub: openai/whisper-small + save_path: !ref + sample_rate: !ref + +wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM + model_hub: microsoft/wavlm-base-sv + save_path: !ref + sample_rate: !ref + +# Counters, checkpointers, loggers, etc. 
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wav2vec2.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wav2vec2.yaml new file mode 100644 index 000000000..977951f2e --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wav2vec2.yaml @@ -0,0 +1,219 @@ +# ########################################################################################### +# Model: CRDNN with discrete audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +run_name: wav2vec2_crdnn + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] + +# Data preparation +data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +testing: True + +# Output folders +output_folder: !ref results/CRDNN// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: False +save_audios: False + +# Preprocessing parameters +train_remove_if_longer: 60.0 # Seconds +valid_remove_if_longer: 60.0 # Seconds +test_remove_if_longer: 60.0 # Seconds +sorting: random + +# Training parameters +num_epochs: 50 +grad_accumulation_factor: 1 +train_batch_size: 16 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 4 +nonfinite_patience: 10 +max_grad_norm: 0.001 +precision: fp16 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 + +# Optimizer parameters +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.0005)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# Codec parameters +sample_rate: 16000 # Should match the tokenizer's sample rate +vocab_size: 1000 +num_codebooks: 2 +kmeans_dataset: LibriSpeech +SSL_layers: [1, 3] #[1, 3, 7, 12, 18, 23] + +# Embedding parameters +embedding_dim: 1024 +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.LeakyReLU +rnn_class: !name:speechbrain.nnet.RNN.LSTM +rnn_layers: 4 # @orion_step1: --rnn_layers~"uniform(1, 4,discrete=True)" +time_pooling_size: 1 +rnn_bidirectional: True +rnn_neurons: 256 +dnn_blocks: 2 +dnn_neurons: 256 +cnn_blocks: 2 +cnn_channels: (12, 12) +inter_layer_pooling_size: (2, 2) +cnn_kernelsize: (3, 3) + +# Modules +embedding: !new:custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN + input_shape: [null, null, !ref ] + activation: !ref + dropout: !ref + cnn_blocks: !ref + cnn_channels: !ref + cnn_kernelsize: !ref + inter_layer_pooling_size: !ref + time_pooling: True + using_2d_pooling: False + time_pooling_size: !ref + rnn_class: !ref + rnn_layers: !ref + rnn_neurons: !ref + rnn_bidirectional: !ref + dnn_blocks: !ref + dnn_neurons: !ref + rnn_re_init: False + 
use_rnnp: False + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 + source: facebook/wav2vec2-large + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-wav2vec2-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +dnsmos_computer: !name:metrics.dnsmos.DNSMOS + sample_rate: !ref + +dwer_computer: !name:metrics.dwer.DWER + model_hub: openai/whisper-small + save_path: !ref + sample_rate: !ref + +wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM + model_hub: microsoft/wavlm-base-sv + save_path: !ref + sample_rate: !ref + +# Counters, checkpointers, loggers, etc. 
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavlm.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavlm.yaml new file mode 100644 index 000000000..59a19ae53 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavlm.yaml @@ -0,0 +1,219 @@ +# ########################################################################################### +# Model: CRDNN with discrete audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +run_name: wavlm_crdnn + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] + +# Data preparation +data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +testing: True + +# Output folders +output_folder: !ref results/CRDNN// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: False +save_audios: False + +# Preprocessing parameters +train_remove_if_longer: 60.0 # Seconds +valid_remove_if_longer: 60.0 # Seconds +test_remove_if_longer: 60.0 # Seconds +sorting: random + +# Training parameters +num_epochs: 50 +grad_accumulation_factor: 1 +train_batch_size: 16 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 4 +nonfinite_patience: 10 +max_grad_norm: 0.001 +precision: fp16 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 + +# Optimizer parameters +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.0005)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# Codec parameters +sample_rate: 16000 # Should match the tokenizer's sample rate +vocab_size: 1000 +num_codebooks: 2 +kmeans_dataset: LibriSpeech +SSL_layers: [1, 3] #[1, 3, 7, 12, 18, 23] + +# Embedding parameters +embedding_dim: 1024 +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.LeakyReLU +rnn_class: !name:speechbrain.nnet.RNN.LSTM +rnn_layers: 4 # @orion_step1: --rnn_layers~"uniform(1, 4,discrete=True)" +time_pooling_size: 1 +rnn_bidirectional: True +rnn_neurons: 256 +dnn_blocks: 2 +dnn_neurons: 256 +cnn_blocks: 2 +cnn_channels: (12, 12) +inter_layer_pooling_size: (2, 2) +cnn_kernelsize: (3, 3) + +# Modules +embedding: !new:custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN + input_shape: [null, null, !ref ] + activation: !ref + dropout: !ref + cnn_blocks: !ref + cnn_channels: !ref + cnn_kernelsize: !ref + inter_layer_pooling_size: !ref + time_pooling: True + using_2d_pooling: False + time_pooling_size: !ref + rnn_class: !ref + rnn_layers: !ref + rnn_neurons: !ref + rnn_bidirectional: !ref + dnn_blocks: !ref + dnn_neurons: !ref + rnn_re_init: False + use_rnnp: False + +head: !new:torch.nn.Linear + 
in_features: !ref + out_features: !ref * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM + source: microsoft/wavlm-large + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-wavlm-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +dnsmos_computer: !name:metrics.dnsmos.DNSMOS + sample_rate: !ref + +dwer_computer: !name:metrics.dwer.DWER + model_hub: openai/whisper-small + save_path: !ref + sample_rate: !ref + +wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM + model_hub: microsoft/wavlm-base-sv + save_path: !ref + sample_rate: !ref + +# Counters, checkpointers, loggers, etc. 
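+# Note: the checkpointer below stores the model, scheduler, and epoch counter under the save folder;
+# how many checkpoints are retained is presumably governed by `keep_checkpoints` in the training script rather than here.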
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_encodec.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavtokenizer.yaml similarity index 68% rename from benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_encodec.yaml rename to benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavtokenizer.yaml index 310f31132..3e5b50d00 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/hparams/train_encodec.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavtokenizer.yaml @@ -1,9 +1,9 @@ # ########################################################################################### -# Model: CRDNN with EnCodec audio representations +# Model: CRDNN with discrete audio representations # Authors: Luca Della Libera 2024 # ########################################################################################### -experiment_name: encodec +run_name: wavtokenizer_crdnn # Seed needs to be set at top of YAML seed: 0 @@ -11,19 +11,22 @@ __set_seed: !apply:torch.manual_seed [!ref ] # Data preparation data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache train_csv: !ref /trainset_28spk_wav.csv valid_csv: !ref /validset_wav.csv test_csv: !ref /testset_wav.csv splits: [trainset_28spk_wav, validset_wav, testset_wav] num_valid_speakers: 2 +testing: True # Output folders -output_folder: !ref results// +output_folder: !ref results/CRDNN// save_folder: !ref /save cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE # Save options -compute_metrics: True +compute_metrics: False save_audios: False # Preprocessing parameters @@ -31,94 +34,55 @@ train_remove_if_longer: 60.0 # Seconds valid_remove_if_longer: 60.0 # Seconds test_remove_if_longer: 60.0 # Seconds sorting: random -use_cache: True # Training parameters num_epochs: 50 -grad_accumulation_factor: 16 -train_batch_size: 1 +grad_accumulation_factor: 1 +train_batch_size: 16 valid_batch_size: 1 test_batch_size: 1 dataloader_workers: 4 nonfinite_patience: 10 -max_grad_norm: 5.0 -precision: fp32 +max_grad_norm: 0.001 +precision: fp16 ckpt_interval_minutes: 6000 keep_checkpoints: 1 -augment: False -augment_prob: 0.75 # Optimizer parameters -lr: 0.0005 +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.0005)" weight_decay: 0.01 improvement_threshold: 0.0025 annealing_factor: 0.9 patient: 1 -# EnCodec parameters -# sample_rate: [24000, 24000, 24000, 24000] -# vocab_size: [1024, 1024, 1024, 1024] -# num_codebooks: [2, 4, 8, 16, 32] -# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +# Codec parameters +model_hub: novateur/WavTokenizer-medium-music-audio-75token +config: wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml +checkpoint: wavtokenizer_medium_music_audio_320_24k_v2.ckpt sample_rate: 24000 -vocab_size: 1024 -num_codebooks: 2 -bandwidth: !ref * 75 / 100 +num_codebooks: 1 +vocab_size: 4096 # Embedding parameters embedding_dim: 1024 -pretrain_embedding: False # If True, must match the codec's embedding size (128) freeze_embedding: False # Encoder parameters dropout: 0.1 activation: !name:torch.nn.LeakyReLU rnn_class: !name:speechbrain.nnet.RNN.LSTM -rnn_layers: 4 +rnn_layers: 
4 # @orion_step1: --rnn_layers~"uniform(1, 4,discrete=True)" time_pooling_size: 1 rnn_bidirectional: True rnn_neurons: 256 dnn_blocks: 2 dnn_neurons: 256 cnn_blocks: 2 -cnn_channels: (16, 16) +cnn_channels: (12, 12) inter_layer_pooling_size: (2, 2) cnn_kernelsize: (3, 3) -# Augmentation -drop_freq: !new:speechbrain.augment.time_domain.DropFreq - drop_freq_low: 0 # Min frequency band dropout probability - drop_freq_high: 1 # Max frequency band dropout probability - drop_freq_count_low: 1 # Min number of frequency bands to drop - drop_freq_count_high: 3 # Max number of frequency bands to drop - drop_freq_width: 0.05 # Width of frequency bands to drop - -drop_chunk: !new:speechbrain.augment.time_domain.DropChunk - drop_length_low: 1 # Min number of audio chunks to drop - drop_length_high: 5 # Max number of audio chunks to drop - drop_count_low: 1000 # Min length of audio chunks to drop - drop_count_high: 2000 # Max length of audio chunks to drop - -augmentation: !new:speechbrain.augment.augmenter.Augmenter - parallel_augment: False - concat_original: False - repeat_augment: 1 - shuffle_augmentations: False - min_augmentations: 2 - max_augmentations: 2 - augment_prob: !ref - augmentations: [!ref , !ref ] - # Modules -codec: !new:speechbrain.lobes.models.huggingface_transformers.encodec.Encodec - source: facebook/encodec_24khz # Only the 24kHz version supports mono audio - save_path: !ref - sample_rate: !ref - bandwidth: !ref - flat_embeddings: False - freeze: True - renorm_embeddings: False - embedding: !new:custom_model.Discrete_EmbeddingLayer num_codebooks: !ref vocab_size: !ref @@ -146,7 +110,7 @@ encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN rnn_bidirectional: !ref dnn_blocks: !ref dnn_neurons: !ref - rnn_re_init: True + rnn_re_init: False use_rnnp: False head: !new:torch.nn.Linear @@ -185,6 +149,22 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler annealing_factor: !ref patient: !ref +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.WavTokenizerWrapper + source: !ref + save_path: !ref + checkpoint: !ref + config: !ref + sample_rate: !ref + freeze: True + # Dataloaders train_dataloader_kwargs: batch_size: !ref @@ -220,11 +200,6 @@ wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM save_path: !ref sample_rate: !ref -ecapatdnn_sim_computer: !name:metrics.spk_sim.SpkSimECAPATDNN - model_hub: speechbrain/spkrec-ecapa-voxceleb - save_path: !apply:os.path.join [!ref , models--speechbrain--spkrec-ecapa-voxceleb] - sample_rate: !ref - # Counters, checkpointers, loggers, etc. 
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter limit: !ref @@ -237,4 +212,4 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer counter: !ref train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref /train_log.txt + save_file: !ref /train_log.txt \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_discrete_ssl.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train.yaml similarity index 51% rename from benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_discrete_ssl.yaml rename to benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train.yaml index 62f83fdf4..3e37dddaa 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_discrete_ssl.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train.yaml @@ -1,9 +1,9 @@ # ########################################################################################### -# Model: Conformer with discrete SSL audio representations +# Model: Conformer with discrete audio representations # Authors: Luca Della Libera 2024 # ########################################################################################### -experiment_name: discrete_hubert +run_name: encodec # Seed needs to be set at top of YAML seed: 0 @@ -11,19 +11,22 @@ __set_seed: !apply:torch.manual_seed [!ref ] # Data preparation data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache train_csv: !ref /trainset_28spk_wav.csv -valid_csv: !ref /trainset_28spk_wav.csv -test_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv splits: [trainset_28spk_wav, validset_wav, testset_wav] num_valid_speakers: 2 +testing: True # Output folders -output_folder: !ref results// +output_folder: !ref results/Conformer// save_folder: !ref /save cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE # Save options -compute_metrics: True +compute_metrics: False save_audios: False # Preprocessing parameters @@ -31,12 +34,11 @@ train_remove_if_longer: 60.0 # Seconds valid_remove_if_longer: 60.0 # Seconds test_remove_if_longer: 60.0 # Seconds sorting: random -use_cache: True # Training parameters num_epochs: 50 -grad_accumulation_factor: 16 -train_batch_size: 1 +grad_accumulation_factor: 1 +train_batch_size: 16 valid_batch_size: 1 test_batch_size: 1 dataloader_workers: 4 @@ -45,31 +47,26 @@ max_grad_norm: 5.0 precision: fp32 ckpt_interval_minutes: 6000 keep_checkpoints: 1 -augment: False -augment_prob: 0.75 # Optimizer parameters -lr: 0.0005 +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.5)" weight_decay: 0.01 improvement_threshold: 0.0025 annealing_factor: 0.9 patient: 1 -# Discrete SSL parameters -sample_rate: 16000 -num_clusters: 1000 -vocab_size: !ref -SSL_layers: [1, 3, 7, 12, 18, 23] -num_codebooks: !apply:len [!ref ] -ssl_hub: facebook/hubert-large-ll60k -vocoder_hub: speechbrain/hifigan-hubert-l1-3-7-12-18-23-k1000-LibriTTS # Must be consistent with ssl_hub/SSL_layers/num_clusters -kmeans_repo_id: speechbrain/SSL_Quantization -kmeans_dataset: LibriSpeech-100-360-500 -ssl_model_type: hubert +# Codec parameters +codec_type: encodec +sample_rate: 24000 # Should match the tokenizer's sample rate +vocab_size: 1024 +num_codebooks: 2 +bandwidth: 24.0 +kmeans_dataset: LibriSpeech +SSL_layers: [1, 3] #[1, 3, 7, 12, 18, 23] +tokenizer_save_path: /home/luca/Downloads/SQ-Codec/sqcodec # Embedding parameters embedding_dim: 1024 
-pretrain_embedding: False # If True, must match the codec's embedding size (1024) freeze_embedding: False # Encoder parameters @@ -77,74 +74,12 @@ dropout: 0.1 activation: !name:torch.nn.GELU d_model: 256 nhead: 4 -num_layers: 6 +num_layers: 6 # @orion_step1: --num_layers~"uniform(1, 8,discrete=True)" d_ffn: 2048 max_length: 2000 causal: False -# Augmentation -drop_freq: !new:speechbrain.augment.time_domain.DropFreq - drop_freq_low: 0 # Min frequency band dropout probability - drop_freq_high: 1 # Max frequency band dropout probability - drop_freq_count_low: 1 # Min number of frequency bands to drop - drop_freq_count_high: 3 # Max number of frequency bands to drop - drop_freq_width: 0.05 # Width of frequency bands to drop - -drop_chunk: !new:speechbrain.augment.time_domain.DropChunk - drop_length_low: 1 # Min number of audio chunks to drop - drop_length_high: 5 # Max number of audio chunks to drop - drop_count_low: 1000 # Min length of audio chunks to drop - drop_count_high: 2000 # Max length of audio chunks to drop - -augmentation: !new:speechbrain.augment.augmenter.Augmenter - parallel_augment: False - concat_original: False - repeat_augment: 1 - shuffle_augmentations: False - min_augmentations: 2 - max_augmentations: 2 - augment_prob: !ref - augmentations: [!ref , !ref ] - # Modules -ssl_model: !new:utils.SBWav2Vec2ForwardWrapper - wav2vec2: !apply:speechbrain.utils.hparams.choice - value: !ref - choices: - wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM - source: !ref - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref - hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT - source: !ref - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref - wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 - source: !ref - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref - layer_ids: !ref - -codec_quantizer: !new:speechbrain.lobes.models.huggingface_transformers.discrete_ssl.DiscreteSSL - save_path: !ref - ssl_model: !ref - kmeans_dataset: !ref - kmeans_repo_id: !ref - num_clusters: !ref - -codec_vocoder: !apply:speechbrain.inference.vocoders.UnitHIFIGAN.from_hparams - source: !ref - savedir: !apply:os.path.join [!ref , !ref ] - embedding: !new:custom_model.Discrete_EmbeddingLayer num_codebooks: !ref vocab_size: !ref @@ -172,8 +107,7 @@ encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR head: !new:torch.nn.Linear in_features: !ref - # Workaround to bypass HyperPyYAML lack of flexibility - out_features: !apply:train_discrete_ssl.len_ [!ref , !ref ] + out_features: !ref * modules: embedding: !ref @@ -207,6 +141,88 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler annealing_factor: !ref patient: !ref +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !apply:speechbrain.utils.hparams.choice + value: !ref + choices: + dac: !new:utils.tokenizer_interface.DACTokenizer + model_type: 24khz + model_bitrate: 8kbps + load_pretrained: True + tag: latest + encodec: !new:utils.tokenizer_interface.EncodecTokenizer + source: facebook/encodec_24khz + save_path: !ref + sample_rate: !ref + bandwidth: !ref + flat_embeddings: False + freeze: True + renorm_embeddings: False + mimi: 
!new:utils.tokenizer_interface.MimiTokenizer + source: kyutai/mimi + save_path: !ref + num_codebooks: !ref + sample_rate: !ref + speech_tokenizer: !new:utils.tokenizer_interface.SpeechTokenizerWrapper + source: fnlp/SpeechTokenizer + save_path: !ref + sample_rate: !ref + sqcodec: !new:utils.tokenizer_interface.SQCodecTokenizer + save_path: !ref + checkpoint: ckpt_00190000.pth + config: config.yaml + sample_rate: !ref + wavtokenizer: !new:utils.tokenizer_interface.WavTokenizerWrapper + source: novateur/WavTokenizer-medium-music-audio-75token + save_path: !ref + checkpoint: wavtokenizer_medium_music_audio_320_24k_v2.ckpt + config: wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml + sample_rate: !ref + freeze: True + wavlm: !new:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM + source: microsoft/wavlm-large + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-wavlm-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + hubert: !new:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT + source: facebook/hubert-large-ll60k + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-hubert-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + wav2vec2: !new:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 + source: facebook/wav2vec2-large + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-wav2vec2-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + # Dataloaders train_dataloader_kwargs: batch_size: !ref @@ -242,11 +258,6 @@ wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM save_path: !ref sample_rate: !ref -ecapatdnn_sim_computer: !name:metrics.spk_sim.SpkSimECAPATDNN - model_hub: speechbrain/spkrec-ecapa-voxceleb - save_path: !apply:os.path.join [!ref , models--speechbrain--spkrec-ecapa-voxceleb] - sample_rate: !ref - # Counters, checkpointers, loggers, etc. 
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter limit: !ref @@ -259,4 +270,4 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer counter: !ref train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref /train_log.txt + save_file: !ref /train_log.txt \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_continuous_wavlm.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_continuous_wavlm.yaml similarity index 77% rename from benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_continuous_wavlm.yaml rename to benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_continuous_wavlm.yaml index 30d85d5ef..cdb4e35ab 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_continuous_wavlm.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_continuous_wavlm.yaml @@ -3,7 +3,7 @@ # Authors: Luca Della Libera 2024 # ########################################################################################### -experiment_name: continuous_wavlm +run_name: continuous_wavlm_conformer # Seed needs to be set at top of YAML seed: 0 @@ -18,12 +18,12 @@ splits: [trainset_28spk_wav, validset_wav, testset_wav] num_valid_speakers: 2 # Output folders -output_folder: !ref results// +output_folder: !ref results/Conformer// save_folder: !ref /save cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE # Save options -compute_metrics: True +compute_metrics: False save_audios: False # Preprocessing parameters @@ -35,8 +35,8 @@ use_cache: False # Training parameters num_epochs: 50 -grad_accumulation_factor: 16 -train_batch_size: 1 +grad_accumulation_factor: 1 +train_batch_size: 16 valid_batch_size: 1 test_batch_size: 1 dataloader_workers: 4 @@ -45,8 +45,6 @@ max_grad_norm: 5.0 precision: fp32 ckpt_interval_minutes: 6000 keep_checkpoints: 1 -augment: False -augment_prob: 0.75 # Optimizer parameters lr: 0.0005 @@ -73,32 +71,8 @@ d_ffn: 2048 max_length: 2000 causal: False -# Augmentation -drop_freq: !new:speechbrain.augment.time_domain.DropFreq - drop_freq_low: 0 # Min frequency band dropout probability - drop_freq_high: 1 # Max frequency band dropout probability - drop_freq_count_low: 1 # Min number of frequency bands to drop - drop_freq_count_high: 3 # Max number of frequency bands to drop - drop_freq_width: 0.05 # Width of frequency bands to drop - -drop_chunk: !new:speechbrain.augment.time_domain.DropChunk - drop_length_low: 1 # Min number of audio chunks to drop - drop_length_high: 5 # Max number of audio chunks to drop - drop_count_low: 1000 # Min length of audio chunks to drop - drop_count_high: 2000 # Max length of audio chunks to drop - -augmentation: !new:speechbrain.augment.augmenter.Augmenter - parallel_augment: False - concat_original: False - repeat_augment: 1 - shuffle_augmentations: False - min_augmentations: 2 - max_augmentations: 2 - augment_prob: !ref - augmentations: [!ref , !ref ] - # Modules -ssl_model: !new:utils.SBWav2Vec2ForwardWrapper +ssl_model: !new:common.SBWav2Vec2ForwardWrapper wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM source: !ref output_norm: False @@ -196,11 +170,6 @@ wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM save_path: !ref sample_rate: !ref -ecapatdnn_sim_computer: !name:metrics.spk_sim.SpkSimECAPATDNN - model_hub: speechbrain/spkrec-ecapa-voxceleb - save_path: !apply:os.path.join [!ref , models--speechbrain--spkrec-ecapa-voxceleb] - sample_rate: !ref - # 
Counters, checkpointers, loggers, etc. epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter limit: !ref diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_dac.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_dac.yaml new file mode 100644 index 000000000..902e16029 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_dac.yaml @@ -0,0 +1,200 @@ +# ########################################################################################### +# Model: Conformer with discrete audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +run_name: dac_conformer + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] + +# Data preparation +data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +testing: True + +# Output folders +output_folder: !ref results/Conformer// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: False +save_audios: False + +# Preprocessing parameters +train_remove_if_longer: 60.0 # Seconds +valid_remove_if_longer: 60.0 # Seconds +test_remove_if_longer: 60.0 # Seconds +sorting: random + +# Training parameters +num_epochs: 50 +grad_accumulation_factor: 1 +train_batch_size: 16 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 4 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 + +# Optimizer parameters +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.5)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# Codec parameters +sample_rate: 24000 # Should match the tokenizer's sample rate +vocab_size: 1024 +num_codebooks: 2 + +# Embedding parameters +embedding_dim: 1024 +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.GELU +d_model: 256 +nhead: 4 +num_layers: 6 # @orion_step1: --num_layers~"uniform(1, 8,discrete=True)" +d_ffn: 2048 +max_length: 2000 +causal: False + +# Modules +embedding: !new:custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR + input_size: !ref + tgt_vocab: -1 + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: 0 + d_ffn: !ref + dropout: !ref + activation: !ref + max_length: !ref + encoder_module: conformer + normalize_before: True + causal: !ref + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 
!ref + annealing_factor: !ref + patient: !ref + +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.DACTokenizer + model_type: 24khz + model_bitrate: 8kbps + load_pretrained: True + tag: latest + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +dnsmos_computer: !name:metrics.dnsmos.DNSMOS + sample_rate: !ref + +dwer_computer: !name:metrics.dwer.DWER + model_hub: openai/whisper-small + save_path: !ref + sample_rate: !ref + +wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM + model_hub: microsoft/wavlm-base-sv + save_path: !ref + sample_rate: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_speech_tokenizer.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_encodec.yaml similarity index 71% rename from benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_speech_tokenizer.yaml rename to benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_encodec.yaml index 1e0ad02dc..9244de0e7 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_speech_tokenizer.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_encodec.yaml @@ -1,9 +1,9 @@ # ########################################################################################### -# Model: Conformer with SpeechTokenizer audio representations +# Model: Conformer with discrete audio representations # Authors: Luca Della Libera 2024 # ########################################################################################### -experiment_name: speech_tokenizer +run_name: encodec_conformer # Seed needs to be set at top of YAML seed: 0 @@ -11,19 +11,22 @@ __set_seed: !apply:torch.manual_seed [!ref ] # Data preparation data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache train_csv: !ref /trainset_28spk_wav.csv valid_csv: !ref /validset_wav.csv test_csv: !ref /testset_wav.csv splits: [trainset_28spk_wav, validset_wav, testset_wav] num_valid_speakers: 2 +testing: True # Output folders -output_folder: !ref results// +output_folder: !ref results/Conformer// save_folder: !ref /save cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE # Save options -compute_metrics: True +compute_metrics: False save_audios: False # Preprocessing parameters @@ -31,12 +34,11 @@ train_remove_if_longer: 60.0 # Seconds valid_remove_if_longer: 60.0 # Seconds test_remove_if_longer: 60.0 # Seconds sorting: random -use_cache: True # Training parameters num_epochs: 50 -grad_accumulation_factor: 16 -train_batch_size: 1 
+grad_accumulation_factor: 1 +train_batch_size: 16 valid_batch_size: 1 test_batch_size: 1 dataloader_workers: 4 @@ -45,24 +47,22 @@ max_grad_norm: 5.0 precision: fp32 ckpt_interval_minutes: 6000 keep_checkpoints: 1 -augment: False -augment_prob: 0.75 # Optimizer parameters -lr: 0.0005 +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.5)" weight_decay: 0.01 improvement_threshold: 0.0025 annealing_factor: 0.9 patient: 1 -# SpeechTokenizer parameters -sample_rate: 16000 +# Codec parameters +sample_rate: 24000 # Should match the tokenizer's sample rate vocab_size: 1024 -num_codebooks: 2 # Must be <= 8 +num_codebooks: 2 +bandwidth: 1.5 # Embedding parameters embedding_dim: 1024 -pretrain_embedding: False # If True, must match the codec's embedding size (1024) freeze_embedding: False # Encoder parameters @@ -70,40 +70,12 @@ dropout: 0.1 activation: !name:torch.nn.GELU d_model: 256 nhead: 4 -num_layers: 6 +num_layers: 6 # @orion_step1: --num_layers~"uniform(1, 8,discrete=True)" d_ffn: 2048 max_length: 2000 causal: False -# Augmentation -drop_freq: !new:speechbrain.augment.time_domain.DropFreq - drop_freq_low: 0 # Min frequency band dropout probability - drop_freq_high: 1 # Max frequency band dropout probability - drop_freq_count_low: 1 # Min number of frequency bands to drop - drop_freq_count_high: 3 # Max number of frequency bands to drop - drop_freq_width: 0.05 # Width of frequency bands to drop - -drop_chunk: !new:speechbrain.augment.time_domain.DropChunk - drop_length_low: 1 # Min number of audio chunks to drop - drop_length_high: 5 # Max number of audio chunks to drop - drop_count_low: 1000 # Min length of audio chunks to drop - drop_count_high: 2000 # Max length of audio chunks to drop - -augmentation: !new:speechbrain.augment.augmenter.Augmenter - parallel_augment: False - concat_original: False - repeat_augment: 1 - shuffle_augmentations: False - min_augmentations: 2 - max_augmentations: 2 - augment_prob: !ref - augmentations: [!ref , !ref ] - # Modules -codec: !new:speechbrain.lobes.models.discrete.speechtokenizer_interface.SpeechTokenizer_interface - source: fnlp/SpeechTokenizer - save_path: !ref - embedding: !new:custom_model.Discrete_EmbeddingLayer num_codebooks: !ref vocab_size: !ref @@ -165,6 +137,23 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler annealing_factor: !ref patient: !ref +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.EncodecTokenizer + source: facebook/encodec_24khz + save_path: !ref + sample_rate: !ref + bandwidth: !ref + flat_embeddings: False + freeze: True + renorm_embeddings: False + # Dataloaders train_dataloader_kwargs: batch_size: !ref @@ -200,11 +189,6 @@ wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM save_path: !ref sample_rate: !ref -ecapatdnn_sim_computer: !name:metrics.spk_sim.SpkSimECAPATDNN - model_hub: speechbrain/spkrec-ecapa-voxceleb - save_path: !apply:os.path.join [!ref , models--speechbrain--spkrec-ecapa-voxceleb] - sample_rate: !ref - # Counters, checkpointers, loggers, etc. 
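+# Note: FileTrainLogger appends a summary line of train/valid statistics (typically once per epoch) to train_log.txt in the output folder.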
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter limit: !ref diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_hubert.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_hubert.yaml new file mode 100644 index 000000000..96cdda32a --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_hubert.yaml @@ -0,0 +1,209 @@ +# ########################################################################################### +# Model: Conformer with discrete audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +run_name: hubert_conformer + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] + +# Data preparation +data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +testing: True + +# Output folders +output_folder: !ref results/Conformer// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: False +save_audios: False + +# Preprocessing parameters +train_remove_if_longer: 60.0 # Seconds +valid_remove_if_longer: 60.0 # Seconds +test_remove_if_longer: 60.0 # Seconds +sorting: random + +# Training parameters +num_epochs: 50 +grad_accumulation_factor: 1 +train_batch_size: 16 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 4 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 + +# Optimizer parameters +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.5)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# Codec parameters +sample_rate: 16000 # Should match the tokenizer's sample rate +vocab_size: 1000 +num_codebooks: 2 +kmeans_dataset: LibriSpeech +SSL_layers: [1, 3] #[1, 3, 7, 12, 18, 23] + +# Embedding parameters +embedding_dim: 1024 +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.GELU +d_model: 256 +nhead: 4 +num_layers: 6 # @orion_step1: --num_layers~"uniform(1, 8,discrete=True)" +d_ffn: 2048 +max_length: 2000 +causal: False + +# Modules +embedding: !new:custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR + input_size: !ref + tgt_vocab: -1 + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: 0 + d_ffn: !ref + dropout: !ref + activation: !ref + max_length: !ref + encoder_module: conformer + normalize_before: True + causal: !ref + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + 
initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT + source: facebook/hubert-large-ll60k + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-hubert-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +dnsmos_computer: !name:metrics.dnsmos.DNSMOS + sample_rate: !ref + +dwer_computer: !name:metrics.dwer.DWER + model_hub: openai/whisper-small + save_path: !ref + sample_rate: !ref + +wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM + model_hub: microsoft/wavlm-base-sv + save_path: !ref + sample_rate: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_mimi.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_mimi.yaml new file mode 100644 index 000000000..f53e7f768 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_mimi.yaml @@ -0,0 +1,200 @@ +# ########################################################################################### +# Model: Conformer with discrete audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +run_name: mimi_conformer + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] + +# Data preparation +data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +testing: True + +# Output folders +output_folder: !ref results/Conformer// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: False +save_audios: False + +# Preprocessing parameters +train_remove_if_longer: 60.0 # Seconds +valid_remove_if_longer: 60.0 # Seconds +test_remove_if_longer: 60.0 # Seconds +sorting: random + +# Training parameters +num_epochs: 50 +grad_accumulation_factor: 1 +train_batch_size: 16 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 4 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 
+keep_checkpoints: 1 + +# Optimizer parameters +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.5)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# Codec parameters +sample_rate: 24000 # Should match the tokenizer's sample rate +vocab_size: 2048 +num_codebooks: 2 + +# Embedding parameters +embedding_dim: 1024 +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.GELU +d_model: 256 +nhead: 4 +num_layers: 6 # @orion_step1: --num_layers~"uniform(1, 8,discrete=True)" +d_ffn: 2048 +max_length: 2000 +causal: False + +# Modules +embedding: !new:custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR + input_size: !ref + tgt_vocab: -1 + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: 0 + d_ffn: !ref + dropout: !ref + activation: !ref + max_length: !ref + encoder_module: conformer + normalize_before: True + causal: !ref + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.MimiTokenizer + source: kyutai/mimi + save_path: !ref + num_codebooks: !ref + sample_rate: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +dnsmos_computer: !name:metrics.dnsmos.DNSMOS + sample_rate: !ref + +dwer_computer: !name:metrics.dwer.DWER + model_hub: openai/whisper-small + save_path: !ref + sample_rate: !ref + +wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM + model_hub: microsoft/wavlm-base-sv + save_path: !ref + sample_rate: !ref + +# Counters, checkpointers, loggers, etc. 
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_speech_tokenizer.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_speech_tokenizer.yaml new file mode 100644 index 000000000..71d907b0d --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_speech_tokenizer.yaml @@ -0,0 +1,199 @@ +# ########################################################################################### +# Model: Conformer with discrete audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +run_name: speech_tokenizer_conformer + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] + +# Data preparation +data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +testing: True + +# Output folders +output_folder: !ref results/Conformer// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: False +save_audios: False + +# Preprocessing parameters +train_remove_if_longer: 60.0 # Seconds +valid_remove_if_longer: 60.0 # Seconds +test_remove_if_longer: 60.0 # Seconds +sorting: random + +# Training parameters +num_epochs: 50 +grad_accumulation_factor: 1 +train_batch_size: 16 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 4 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 + +# Optimizer parameters +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.5)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# Codec parameters +sample_rate: 16000 # Should match the tokenizer's sample rate +vocab_size: 1024 +num_codebooks: 2 + +# Embedding parameters +embedding_dim: 1024 +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.GELU +d_model: 256 +nhead: 4 +num_layers: 6 # @orion_step1: --num_layers~"uniform(1, 8,discrete=True)" +d_ffn: 2048 +max_length: 2000 +causal: False + +# Modules +embedding: !new:custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR + input_size: !ref + tgt_vocab: -1 + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: 0 + d_ffn: !ref + dropout: !ref + activation: !ref + max_length: !ref + encoder_module: conformer + normalize_before: True + causal: !ref + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + 
label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.SpeechTokenizerWrapper + source: fnlp/SpeechTokenizer + save_path: !ref + sample_rate: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +dnsmos_computer: !name:metrics.dnsmos.DNSMOS + sample_rate: !ref + +dwer_computer: !name:metrics.dwer.DWER + model_hub: openai/whisper-small + save_path: !ref + sample_rate: !ref + +wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM + model_hub: microsoft/wavlm-base-sv + save_path: !ref + sample_rate: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_dac.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_sqcodec.yaml similarity index 67% rename from benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_dac.yaml rename to benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_sqcodec.yaml index 93ecbc6fd..c5d056c21 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_dac.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_sqcodec.yaml @@ -1,9 +1,9 @@ # ########################################################################################### -# Model: Conformer with DAC audio representations +# Model: Conformer with discrete audio representations # Authors: Luca Della Libera 2024 # ########################################################################################### -experiment_name: dac +run_name: sqcodec_conformer # Seed needs to be set at top of YAML seed: 0 @@ -11,19 +11,22 @@ __set_seed: !apply:torch.manual_seed [!ref ] # Data preparation data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache train_csv: !ref /trainset_28spk_wav.csv valid_csv: !ref /validset_wav.csv test_csv: !ref /testset_wav.csv splits: [trainset_28spk_wav, validset_wav, testset_wav] num_valid_speakers: 2 +testing: True # Output folders -output_folder: !ref results// +output_folder: !ref results/Conformer// save_folder: !ref /save cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE # Save options -compute_metrics: True +compute_metrics: False save_audios: False # Preprocessing parameters @@ -31,12 +34,11 @@ 
train_remove_if_longer: 60.0 # Seconds valid_remove_if_longer: 60.0 # Seconds test_remove_if_longer: 60.0 # Seconds sorting: random -use_cache: True # Training parameters num_epochs: 50 -grad_accumulation_factor: 16 -train_batch_size: 1 +grad_accumulation_factor: 1 +train_batch_size: 16 valid_batch_size: 1 test_batch_size: 1 dataloader_workers: 4 @@ -45,31 +47,22 @@ max_grad_norm: 5.0 precision: fp32 ckpt_interval_minutes: 6000 keep_checkpoints: 1 -augment: False -augment_prob: 0.75 # Optimizer parameters -lr: 0.0005 +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.5)" weight_decay: 0.01 improvement_threshold: 0.0025 annealing_factor: 0.9 patient: 1 -# DAC parameters -# sample_rate: [16000, 24000, 44000, 44000] -# vocab_size: [1024, 1024, 1024, 1024] -# max_num_codebooks: [12, 32, 9, 18] -# model_type: [16khz, 24khz, 44khz, 44khz] -# model_bitrate: [8kbps, 8kbps, 8kbps, 16kbps] -sample_rate: 24000 # NOTE: must match DAC's model type -vocab_size: 1024 -num_codebooks: 2 # NOTE: must be smaller or equal to the maximum number of codebooks for the given model type -model_type: 24khz -model_bitrate: 8kbps +# Codec parameters +sample_rate: 16000 # Should match the tokenizer's sample rate +vocab_size: 19683 +num_codebooks: 2 +tokenizer_save_path: !PLACEHOLDER # Embedding parameters embedding_dim: 1024 -pretrain_embedding: False # If True, must match the codec's embedding size (1024) freeze_embedding: False # Encoder parameters @@ -77,42 +70,12 @@ dropout: 0.1 activation: !name:torch.nn.GELU d_model: 256 nhead: 4 -num_layers: 6 +num_layers: 6 # @orion_step1: --num_layers~"uniform(1, 8,discrete=True)" d_ffn: 2048 max_length: 2000 causal: False -# Augmentation -drop_freq: !new:speechbrain.augment.time_domain.DropFreq - drop_freq_low: 0 # Min frequency band dropout probability - drop_freq_high: 1 # Max frequency band dropout probability - drop_freq_count_low: 1 # Min number of frequency bands to drop - drop_freq_count_high: 3 # Max number of frequency bands to drop - drop_freq_width: 0.05 # Width of frequency bands to drop - -drop_chunk: !new:speechbrain.augment.time_domain.DropChunk - drop_length_low: 1 # Min number of audio chunks to drop - drop_length_high: 5 # Max number of audio chunks to drop - drop_count_low: 1000 # Min length of audio chunks to drop - drop_count_high: 2000 # Max length of audio chunks to drop - -augmentation: !new:speechbrain.augment.augmenter.Augmenter - parallel_augment: False - concat_original: False - repeat_augment: 1 - shuffle_augmentations: False - min_augmentations: 2 - max_augmentations: 2 - augment_prob: !ref - augmentations: [!ref , !ref ] - # Modules -codec: !new:speechbrain.lobes.models.discrete.dac.DAC - model_type: !ref - model_bitrate: !ref - load_pretrained: True - tag: latest - embedding: !new:custom_model.Discrete_EmbeddingLayer num_codebooks: !ref vocab_size: !ref @@ -174,6 +137,20 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler annealing_factor: !ref patient: !ref +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.SQCodecTokenizer + save_path: !ref + checkpoint: ckpt_00190000.pth + config: config.yaml + sample_rate: !ref + # Dataloaders train_dataloader_kwargs: batch_size: !ref @@ -209,11 +186,6 @@ wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM save_path: !ref sample_rate: !ref -ecapatdnn_sim_computer: !name:metrics.spk_sim.SpkSimECAPATDNN - model_hub: 
speechbrain/spkrec-ecapa-voxceleb - save_path: !apply:os.path.join [!ref , models--speechbrain--spkrec-ecapa-voxceleb] - sample_rate: !ref - # Counters, checkpointers, loggers, etc. epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter limit: !ref diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wav2vec2.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wav2vec2.yaml new file mode 100644 index 000000000..dcca7e7ea --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wav2vec2.yaml @@ -0,0 +1,209 @@ +# ########################################################################################### +# Model: Conformer with discrete audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +run_name: wav2vec2_conformer + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] + +# Data preparation +data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +testing: True + +# Output folders +output_folder: !ref results/Conformer// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: False +save_audios: False + +# Preprocessing parameters +train_remove_if_longer: 60.0 # Seconds +valid_remove_if_longer: 60.0 # Seconds +test_remove_if_longer: 60.0 # Seconds +sorting: random + +# Training parameters +num_epochs: 50 +grad_accumulation_factor: 1 +train_batch_size: 16 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 4 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 + +# Optimizer parameters +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.5)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# Codec parameters +sample_rate: 16000 # Should match the tokenizer's sample rate +vocab_size: 1000 +num_codebooks: 2 +kmeans_dataset: LibriSpeech +SSL_layers: [1, 3] #[1, 3, 7, 12, 18, 23] + +# Embedding parameters +embedding_dim: 1024 +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.GELU +d_model: 256 +nhead: 4 +num_layers: 6 # @orion_step1: --num_layers~"uniform(1, 8,discrete=True)" +d_ffn: 2048 +max_length: 2000 +causal: False + +# Modules +embedding: !new:custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR + input_size: !ref + tgt_vocab: -1 + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: 0 + d_ffn: !ref + dropout: !ref + activation: !ref + max_length: !ref + encoder_module: conformer + normalize_before: True + causal: !ref + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# 
Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 + source: facebook/wav2vec2-large + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-wav2vec2-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +dnsmos_computer: !name:metrics.dnsmos.DNSMOS + sample_rate: !ref + +dwer_computer: !name:metrics.dwer.DWER + model_hub: openai/whisper-small + save_path: !ref + sample_rate: !ref + +wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM + model_hub: microsoft/wavlm-base-sv + save_path: !ref + sample_rate: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavlm.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavlm.yaml new file mode 100644 index 000000000..02a598a80 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavlm.yaml @@ -0,0 +1,209 @@ +# ########################################################################################### +# Model: Conformer with discrete audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +run_name: wavlm_conformer + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] + +# Data preparation +data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +testing: True + +# Output folders +output_folder: !ref results/Conformer// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: False +save_audios: False + +# Preprocessing parameters +train_remove_if_longer: 60.0 # Seconds +valid_remove_if_longer: 60.0 # Seconds +test_remove_if_longer: 60.0 # Seconds +sorting: random + +# Training parameters +num_epochs: 50 
+grad_accumulation_factor: 1 +train_batch_size: 16 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 4 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 + +# Optimizer parameters +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.5)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# Codec parameters +sample_rate: 16000 # Should match the tokenizer's sample rate +vocab_size: 1000 +num_codebooks: 2 +kmeans_dataset: LibriSpeech +SSL_layers: [1, 3] #[1, 3, 7, 12, 18, 23] + +# Embedding parameters +embedding_dim: 1024 +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.GELU +d_model: 256 +nhead: 4 +num_layers: 6 # @orion_step1: --num_layers~"uniform(1, 8,discrete=True)" +d_ffn: 2048 +max_length: 2000 +causal: False + +# Modules +embedding: !new:custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR + input_size: !ref + tgt_vocab: -1 + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: 0 + d_ffn: !ref + dropout: !ref + activation: !ref + max_length: !ref + encoder_module: conformer + normalize_before: True + causal: !ref + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM + source: microsoft/wavlm-large + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-wavlm-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +dnsmos_computer: !name:metrics.dnsmos.DNSMOS + sample_rate: !ref + +dwer_computer: !name:metrics.dwer.DWER + model_hub: openai/whisper-small + save_path: !ref + sample_rate: !ref + +wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM + model_hub: microsoft/wavlm-base-sv + save_path: !ref + sample_rate: !ref + +# Counters, checkpointers, loggers, etc. 
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_encodec.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavtokenizer.yaml similarity index 68% rename from benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_encodec.yaml rename to benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavtokenizer.yaml index 4fe47d0e7..cc4572519 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/hparams/train_encodec.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavtokenizer.yaml @@ -1,9 +1,9 @@ # ########################################################################################### -# Model: Conformer with EnCodec audio representations +# Model: Conformer with discrete audio representations # Authors: Luca Della Libera 2024 # ########################################################################################### -experiment_name: encodec +run_name: wavtokenizer_conformer # Seed needs to be set at top of YAML seed: 0 @@ -11,19 +11,22 @@ __set_seed: !apply:torch.manual_seed [!ref ] # Data preparation data_folder: !PLACEHOLDER +tokens_folder: !PLACEHOLDER +cached_data_folder: !ref # e.g., path/to/cache train_csv: !ref /trainset_28spk_wav.csv valid_csv: !ref /validset_wav.csv test_csv: !ref /testset_wav.csv splits: [trainset_28spk_wav, validset_wav, testset_wav] num_valid_speakers: 2 +testing: True # Output folders -output_folder: !ref results// +output_folder: !ref results/Conformer// save_folder: !ref /save cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE # Save options -compute_metrics: True +compute_metrics: False save_audios: False # Preprocessing parameters @@ -31,12 +34,11 @@ train_remove_if_longer: 60.0 # Seconds valid_remove_if_longer: 60.0 # Seconds test_remove_if_longer: 60.0 # Seconds sorting: random -use_cache: True # Training parameters num_epochs: 50 -grad_accumulation_factor: 16 -train_batch_size: 1 +grad_accumulation_factor: 1 +train_batch_size: 16 valid_batch_size: 1 test_batch_size: 1 dataloader_workers: 4 @@ -45,29 +47,24 @@ max_grad_norm: 5.0 precision: fp32 ckpt_interval_minutes: 6000 keep_checkpoints: 1 -augment: False -augment_prob: 0.75 # Optimizer parameters -lr: 0.0005 +lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.5)" weight_decay: 0.01 improvement_threshold: 0.0025 annealing_factor: 0.9 patient: 1 -# EnCodec parameters -# sample_rate: [24000, 24000, 24000, 24000] -# vocab_size: [1024, 1024, 1024, 1024] -# num_codebooks: [2, 4, 8, 16, 32] -# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +# Codec parameters +model_hub: novateur/WavTokenizer-medium-music-audio-75token +config: wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml +checkpoint: wavtokenizer_medium_music_audio_320_24k_v2.ckpt sample_rate: 24000 -vocab_size: 1024 -num_codebooks: 2 -bandwidth: !ref * 75 / 100 +num_codebooks: 1 +vocab_size: 4096 # Embedding parameters embedding_dim: 1024 -pretrain_embedding: False # If True, must match the codec's embedding size (128) freeze_embedding: False # Encoder parameters @@ -75,45 +72,12 @@ dropout: 0.1 activation: !name:torch.nn.GELU d_model: 256 nhead: 4 -num_layers: 6 +num_layers: 6 # 
@orion_step1: --num_layers~"uniform(1, 8,discrete=True)" d_ffn: 2048 max_length: 2000 causal: False -# Augmentation -drop_freq: !new:speechbrain.augment.time_domain.DropFreq - drop_freq_low: 0 # Min frequency band dropout probability - drop_freq_high: 1 # Max frequency band dropout probability - drop_freq_count_low: 1 # Min number of frequency bands to drop - drop_freq_count_high: 3 # Max number of frequency bands to drop - drop_freq_width: 0.05 # Width of frequency bands to drop - -drop_chunk: !new:speechbrain.augment.time_domain.DropChunk - drop_length_low: 1 # Min number of audio chunks to drop - drop_length_high: 5 # Max number of audio chunks to drop - drop_count_low: 1000 # Min length of audio chunks to drop - drop_count_high: 2000 # Max length of audio chunks to drop - -augmentation: !new:speechbrain.augment.augmenter.Augmenter - parallel_augment: False - concat_original: False - repeat_augment: 1 - shuffle_augmentations: False - min_augmentations: 2 - max_augmentations: 2 - augment_prob: !ref - augmentations: [!ref , !ref ] - # Modules -codec: !new:speechbrain.lobes.models.huggingface_transformers.encodec.Encodec - source: facebook/encodec_24khz # Only the 24kHz version supports mono audio - save_path: !ref - sample_rate: !ref - bandwidth: !ref - flat_embeddings: False - freeze: True - renorm_embeddings: False - embedding: !new:custom_model.Discrete_EmbeddingLayer num_codebooks: !ref vocab_size: !ref @@ -175,6 +139,22 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler annealing_factor: !ref patient: !ref +# Token loaders +tokens_loader_in: !new:utils.tokens.TokensLoader + data_path: !ref /input + +tokens_loader_out: !new:utils.tokens.TokensLoader + data_path: !ref /output + +# Codec +codec: !name:utils.tokenizer_interface.WavTokenizerWrapper + source: !ref + save_path: !ref + checkpoint: !ref + config: !ref + sample_rate: !ref + freeze: True + # Dataloaders train_dataloader_kwargs: batch_size: !ref @@ -210,11 +190,6 @@ wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM save_path: !ref sample_rate: !ref -ecapatdnn_sim_computer: !name:metrics.spk_sim.SpkSimECAPATDNN - model_hub: speechbrain/spkrec-ecapa-voxceleb - save_path: !apply:os.path.join [!ref , models--speechbrain--spkrec-ecapa-voxceleb] - sample_rate: !ref - # Counters, checkpointers, loggers, etc. 
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter limit: !ref diff --git a/benchmarks/DASB/VoiceBank/enhancement/metrics/spk_sim.py b/benchmarks/DASB/VoiceBank/enhancement/metrics/spk_sim.py index 9bf543371..106732f26 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/metrics/spk_sim.py +++ b/benchmarks/DASB/VoiceBank/enhancement/metrics/spk_sim.py @@ -7,53 +7,15 @@ import torch import torchaudio from speechbrain.dataio.dataio import length_to_mask -from speechbrain.inference.speaker import SpeakerRecognition from speechbrain.utils.metric_stats import MetricStats from transformers import AutoModelForAudioXVector -__all__ = ["SpkSimECAPATDNN", "SpkSimWavLM"] +__all__ = ["SpkSimWavLM"] SAMPLE_RATE = 16000 - -class SpkSimECAPATDNN(MetricStats): - def __init__(self, model_hub, save_path, sample_rate): - self.sample_rate = sample_rate - self.model = SpeakerRecognition.from_hparams( - model_hub, savedir=save_path - ).cpu() - self.clear() - - @torch.no_grad() - def append(self, ids, hyp_audio, ref_audio, lens=None): - assert hyp_audio.shape == ref_audio.shape - assert hyp_audio.ndim == 2 - - # Concatenate - audio = torch.cat([hyp_audio, ref_audio]) - if lens is not None: - lens = torch.cat([lens, lens]) - - # Resample - audio = torchaudio.functional.resample( - audio, self.sample_rate, SAMPLE_RATE - ) - - self.model.device = hyp_audio.device - self.model.to(hyp_audio.device) - self.model.eval() - - # Forward - embs = self.model.encode_batch(audio, lens, normalize=False) - hyp_embs, ref_embs = embs.split([len(hyp_audio), len(ref_audio)]) - scores = self.model.similarity(hyp_embs, ref_embs)[:, 0] - - self.ids += ids - self.scores += scores.cpu().tolist() - - class SpkSimWavLM(MetricStats): def __init__(self, model_hub, save_path, sample_rate): self.sample_rate = sample_rate diff --git a/benchmarks/DASB/VoiceBank/enhancement/model/ __init__.py b/benchmarks/DASB/VoiceBank/enhancement/model/ __init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/benchmarks/DASB/VoiceBank/enhancement/model/custom_model.py b/benchmarks/DASB/VoiceBank/enhancement/model/custom_model.py new file mode 100644 index 000000000..972d35c66 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/model/custom_model.py @@ -0,0 +1,111 @@ +import torch + + +class AttentionMLP(torch.nn.Module): + def __init__(self, input_dim, hidden_dim): + super(AttentionMLP, self).__init__() + self.layers = torch.nn.Sequential( + torch.nn.Linear(input_dim, hidden_dim), + torch.nn.ReLU(), + torch.nn.Linear(hidden_dim, 1, bias=False), + ) + + def forward(self, x): + x = self.layers(x) + att_w = torch.nn.functional.softmax(x, dim=2) + return att_w + + +class Discrete_EmbeddingLayer(torch.nn.Module): + """This class handles embedding layers for discrete tokens. + + Arguments + --------- + num_codebooks: int , + number of codebooks of the tokenizer. + vocab_size : int, + size of the dictionary of embeddings + emb_dim: int , + the size of each embedding vector + pad_index: int (default: 0), + If specified, the entries at padding_idx do not contribute to the gradient. + init: boolean (default: False): + If set to True, init the embedding with the tokenizer embedding otherwise init randomly. + freeze: boolean (default: False) + If True, the embedding is frozen. If False, the model will be trained + alongside with the rest of the pipeline. 
+ + Example + ------- + >>> from speechbrain.lobes.models.huggingface_transformers.encodec import Encodec + >>> model_hub = "facebook/encodec_24khz" + >>> save_path = "savedir" + >>> model = Encodec(model_hub, save_path) + >>> audio = torch.randn(4, 1000) + >>> length = torch.tensor([1.0, .5, .75, 1.0]) + >>> tokens, emb = model.encode(audio, length) + >>> print(tokens.shape) + torch.Size([4, 4, 2]) + >>> emb= Discrete_EmbeddingLayer(2, 1024, 1024) + >>> in_emb = emb(tokens) + >>> print(in_emb.shape) + torch.Size([4, 4, 2, 1024]) + """ + + def __init__( + self, + num_codebooks, + vocab_size, + emb_dim, + init=False, + freeze=False, + hidden_dim=None, + ): + super(Discrete_EmbeddingLayer, self).__init__() + self.vocab_size = vocab_size + self.num_codebooks = ( + len(num_codebooks) + if isinstance(num_codebooks, list) + else num_codebooks + ) + self.freeze = freeze + self.embedding = torch.nn.Embedding( + self.num_codebooks * vocab_size, emb_dim + ).requires_grad_(not self.freeze) + self.init = init + + # Add a linear layer to match dimensions if necessary + if hidden_dim is not None and hidden_dim != emb_dim: + self.proj_layer = torch.nn.Linear(emb_dim, hidden_dim) + else: + self.proj_layer = None + + def init_embedding(self, weights): + self.embedding.weight.data.copy_(weights) + + def forward(self, in_tokens): + """Computes the embedding for discrete tokens. + a sample. + + Arguments + --------- + in_tokens : torch.Tensor + A (Batch x Time x num_codebooks) + audio sample + Returns + ------- + in_embs : torch.Tensor + """ + with torch.set_grad_enabled(not self.freeze): + # Add unique token IDs across diffrent codebooks by adding num_codebooks * vocab_size + in_tokens += torch.arange( + 0, + self.num_codebooks * self.vocab_size, + self.vocab_size, + device=in_tokens.device, + ) + # Forward Pass to embedding and + in_embs = self.embedding(in_tokens) + if self.proj_layer is not None: + in_embs = self.proj_layer(in_embs) + return in_embs diff --git a/benchmarks/DASB/VoiceBank/enhancement/model/sq_codec.py b/benchmarks/DASB/VoiceBank/enhancement/model/sq_codec.py new file mode 100644 index 000000000..916820b96 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/model/sq_codec.py @@ -0,0 +1,1361 @@ +"""This lobe enables the integration of speech codec model (SQ-Codec) with scalar quantization,. + +SQ-Codec effectively maps the complex speech signal into a finite and compact latent space, named scalar latent space. + +Repository: https://github.com/yangdongchao/SimpleSpeech +Paper: https://arxiv.org/abs/2406.02328, https://arxiv.org/abs/2408.13893 + +Authors + * Pooneh Mousavi 2024 +""" + +import logging +import os + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from omegaconf import OmegaConf +from torch.autograd import Function +from torch.nn.utils import remove_weight_norm, weight_norm + + +class SQCodec(nn.Module): + """ + Speech codec model (SQ-Codec) with scalar quantization. It maps the complex speech signal into a finite and compact latent space. + The model consists of an encoder-decoder architecture with optional causal convolutions, downsampling, and upsampling layers. + It uses vector quantization and various convolutional blocks for processing. 
+ + Make sure that you download and extract the SQ-codec.zip in save_path from following Huggingface repo: + - HF repo: https://huggingface.co/Dongchao/UniAudio/blob/main/SQ-Codec.zip + + Repository: https://github.com/yangdongchao/SimpleSpeech + Paper: https://arxiv.org/abs/2406.02328, https://arxiv.org/abs/2408.13893 + + Arguments + --------- + save_path : str, optional + Directory where the model and configuration files are saved (default is None). + config : str, optional + Configuration filename for the model. It is extracted form zip file(default is 'config.yaml'). + checkpoint : str, optional + Model checkpoint filename. It is extracted form zip file( (default is 'ckpt_00190000.pth'). + sample_rate : int, optional + Sample rate for input audio (default is 16000). + dim_codebook : int, optional + Dimension of each codebook (default is 19683). + n_codebook : int, optional + Number of codebooks used (default is 4). + bw : float, optional + Bandwidth parameter (default is 2). + clip_length : int, optional + Maximum clip length for processing (default is 450). + + Example + ------- + >>> save_path = "savedir" + >>> config = "config.yaml" + >>> checkpoint = "ckpt_00190000.pth" + >>> model = SQCodec(save_path, config, checkpoint) + >>> audio = torch.randn(3, 16000) + >>> tokens, emb = model.encode(audio) + >>> tokens.shape + torch.Size([3, 200]) + >>> emb.shape + torch.Size([3, 36, 50]) + >>> rec = model.decode(tokens) + >>> rec.shape + torch.Size([3, 1, 16000]) + """ + + def __init__( + self, + save_path, + config, + checkpoint, + sample_rate=16000, + dim_codebook=19683, + n_codebook=4, + bw=2, + clip_length=450, + ): + super(SQCodec, self).__init__() + self.config_path = os.path.join(save_path, config) + self.ckpt_path = os.path.join(save_path, checkpoint) + if not os.path.exists(self.config_path) and not os.path.exists( + self.ckpt_path + ): + err_msg = ( + "the files %s or %s does not exist." + "(make sure that you download and extract the SQ-codec.zip in save_path from following Huggingface repo:" + " https://huggingface.co/Dongchao/UniAudio/blob/main/SQ-Codec.zip)" + % (self.ckpt_path, self.config_path) + ) + raise FileNotFoundError(err_msg) + self.clip_length = clip_length + + logging.info( + f"Using config {self.config_path} and model {self.ckpt_path}" + ) + + self.scalar_codec = self.build_codec_model(self.config_path) + self.sample_rate = sample_rate + self.dim_codebook = dim_codebook + self.n_codebook = n_codebook + self.bw = bw + self.mask_id = self.dim_codebook * self.n_codebook + + def build_codec_model(self, config): + """ + Loads and builds the scalar codec model from the given configuration. + + Parameters + ---------- + config : str + Path to the configuration file. + + Returns + ------- + ScalarModel + The built scalar codec model loaded with weights from the checkpoint. + """ + exp_model_config = OmegaConf.load(config) + scalar_codec = ScalarModel(**exp_model_config.generator.config) + device = next(iter(scalar_codec.parameters())).device + parameter_dict = torch.load(self.ckpt_path, map_location=device, weights_only=False) + scalar_codec.load_state_dict(parameter_dict["codec_model"]) + return scalar_codec + + def _flatten_codebooks(self, arr, offset_size=None): + """ + Flattens a 3D array (B, N, D) to a 1D array while applying an offset to each codebook if specified. + + Parameters + ---------- + arr : numpy.ndarray + A 3D array of shape (B, N, D). + offset_size : int or None, optional + The offset size to be applied to each codebook slice (default is None). 
+ + Returns + ------- + numpy.ndarray + A 1D array representing the flattened codebooks. + """ + assert ( + len(arr.shape) == 3 + ), "Input array must have 3 dimensions [B, N, D]" + N, B, D = arr.shape + arr = arr.copy() + # if offset_size is not None: + # for n in range(N): + # arr[n, :, :] += offset_size * n + flattened_arr = arr.transpose(1, 2, 0).reshape(B, N * D) + return flattened_arr + + def encode(self, inputs): + """ + Encodes the input audio tensor using the scalar codec and quantizes the output. + + Parameters + ---------- + inputs : torch.Tensor + Input audio tensor of shape (B, T) or (B, 1, T), where B is the batch size + and T is the length of the audio sequence. + + Returns + ------- + tuple + A tuple containing: + - torch.Tensor: The flattened and quantized encoded representation of the input. + - torch.Tensor: Quantized embedding. + """ + if inputs.dim() == 2: + inputs = inputs.unsqueeze(1) + compressed = self.scalar_codec.encode(inputs) + chunks = compressed.chunk(self.n_codebook, dim=1) + codec_ls = [] + for i, chunk in enumerate(chunks): + chunk = chunk.detach().cpu().numpy().astype(np.int32) + 1 + tmp_codec = ternary_matrix_to_decimal(chunk) + codec_ls.append(tmp_codec) + codec_ls = np.array(codec_ls) + flat_codec = self._flatten_codebooks(codec_ls, self.dim_codebook) + flat_codec = torch.from_numpy(flat_codec).to(torch.int32) + return flat_codec.to(inputs.device), compressed.to(inputs.device) + + def decode(self, codes): + """ + Decodes the quantized codes back into an audio tensor. + + Parameters + ---------- + codes : torch.Tensor + Quantized codes with shape (B, T). + + Returns + ------- + torch.Tensor + Reconstructed audio signal. + """ + assert codes.dim() == 2 + B, T = codes.shape + assert ( + T % self.n_codebook == 0 + ), "Length T must be divisible by n_codebook" + codes = codes.view(B, -1, self.n_codebook).permute(2, 0, 1) + # for i in range(self.n_codebook): + # codes[i, :, :] -= i * self.dim_codebook + emb_quant = [] + for i in range(self.n_codebook): + tmp_list = decimal_to_ternary_matrix(codes[i, :, :], D=9) - 1 + emb_quant.append(tmp_list) + emb_quant = torch.cat(emb_quant, dim=1) + out = self.scalar_codec.decode(emb_quant.float().to(codes.device)) + return out.detach().cpu().squeeze(0) + + def reconstruct(self, wav_root): + """ + Processes a given waveform file by encoding and decoding it through the scalar codec. + + Parameters + ---------- + wav_root : str + Path to the waveform file. + + Returns + ------- + torch.Tensor or None + Processed waveform tensor or None if the file is empty. + """ + wav, sr = torchaudio.load(wav_root) + if wav.numel() == 0: + return None + if sr != self.sample_rate: + wav = torchaudio.transforms.Resample(sr, self.sample_rate)(wav) + wav = wav.unsqueeze(1) + emb, emb_quant, x = self.scalar_codec.inference(wav) + return x.detach().cpu().squeeze(0) + + @property + def is_discrete(self): + """Indicates whether the codec works with discrete values.""" + return True + + @property + def codebook_length(self): + """Returns the total length of the codebook.""" + return self.dim_codebook * self.n_codebook + 1 + + def find_length(self, x): + """ + Finds the length of the tokenized version of the input tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + int + The length of the tokenized input. + """ + return self.tokenize(x).shape[0] // self.n_codebook + + +class ScalarModel(nn.Module): + """ + A custom neural network model for encoding and decoding audio signals. 
+ + The model consists of an encoder-decoder architecture with optional + causal convolutions, downsampling, and upsampling layers. It uses + vector quantization and various convolutional blocks for processing. + + + Arguments + --------- + num_bands : int + Number of input bands (or channels). + sample_rate : int + Sample rate of the input signal. + causal : bool + If True, uses causal convolutions for processing. + num_samples : int + Number of samples to process for downsampling or upsampling. + downsample_factors : list of int + List of factors to downsample the input. + downsample_kernel_sizes : list of int + List of kernel sizes for downsampling layers. + upsample_factors : list of int + List of factors to upsample the input. + upsample_kernel_sizes : list of int + List of kernel sizes for upsampling layers. + latent_hidden_dim : int + Dimension of the latent representation. + default_kernel_size : int + Default kernel size for convolutional layers. + delay_kernel_size : int + Kernel size used for the delay convolutional layer. + init_channel : int + Number of initial channels for the encoder and decoder. + res_kernel_size : int + Kernel size used for the residual convolutional blocks. + + Example + ------- + >>> model = ScalarModel(num_bands=1, sample_rate=16000,causal=True,num_samples=2,downsample_factors=[2,4,4,5],downsample_kernel_sizes=[4,8,8,10],upsample_factors=[5,4,4,2],upsample_kernel_sizes=[10,8,8,4],latent_hidden_dim=36,default_kernel_size=7,delay_kernel_size=5,init_channel=48,res_kernel_size=7) # doctest: +SKIP + >>> audio = torch.randn(3, 1, 16000) + >>> quant_emb = model.encode(audio) # doctest: +SKIP + >>> quant_emb.shape + torch.Size([3, 36, 50]) + >>> rec = model.decode(quant_emb) # doctest: +SKIP + >>> rec.shap) # doctest: +SKIP + torch.Size([3, 1, 16000]) + """ + + def __init__( + self, + num_bands, + sample_rate, + causal, + num_samples, + downsample_factors, + downsample_kernel_sizes, + upsample_factors, + upsample_kernel_sizes, + latent_hidden_dim, + default_kernel_size, + delay_kernel_size, + init_channel, + res_kernel_size, + ): + super(ScalarModel, self).__init__() + self.sample_rate = sample_rate + self.encoder = [] + self.decoder = [] + self.vq = lambda x: CustomRoundingFunction.apply(x, "binary") + + # Encoder layers + self.encoder.append( + weight_norm( + Conv1d( + num_bands, + init_channel, + kernel_size=default_kernel_size, + causal=causal, + ) + ) + ) + if num_samples > 1: + # Downsampling layer + self.encoder.append( + PreProcessor( + init_channel, + init_channel, + num_samples, + kernel_size=default_kernel_size, + causal=causal, + ) + ) + for i, down_factor in enumerate(downsample_factors): + self.encoder.append( + ResEncoderBlock( + init_channel * np.power(2, i), + init_channel * np.power(2, i + 1), + down_factor, + downsample_kernel_sizes[i], + res_kernel_size, + causal=causal, + ) + ) + self.encoder.append( + weight_norm( + Conv1d( + init_channel * np.power(2, len(downsample_factors)), + latent_hidden_dim, + kernel_size=default_kernel_size, + causal=causal, + ) + ) + ) + + # Decoder layers + self.decoder.append( + weight_norm( + Conv1d( + latent_hidden_dim, + init_channel * np.power(2, len(upsample_factors)), + kernel_size=delay_kernel_size, + ) + ) + ) + for i, upsample_factor in enumerate(upsample_factors): + self.decoder.append( + ResDecoderBlock( + init_channel * np.power(2, len(upsample_factors) - i), + init_channel * np.power(2, len(upsample_factors) - i - 1), + upsample_factor, + upsample_kernel_sizes[i], + res_kernel_size, + causal=causal, 
+ ) + ) + if num_samples > 1: + self.decoder.append( + PostProcessor( + init_channel, + init_channel, + num_samples, + kernel_size=default_kernel_size, + causal=causal, + ) + ) + self.decoder.append( + weight_norm( + Conv1d( + init_channel, + num_bands, + kernel_size=default_kernel_size, + causal=causal, + ) + ) + ) + + self.encoder = nn.ModuleList(self.encoder) + self.decoder = nn.ModuleList(self.decoder) + + def forward(self, x): + """ + Performs a forward pass through the encoder and decoder. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, channels, length). + + Returns + ------- + torch.Tensor + Reconstructed output tensor. + """ + for i, layer in enumerate(self.encoder): + if i != len(self.encoder) - 1: + x = layer(x) + else: + x = F.tanh(layer(x)) + x = self.vq(x) # Quantization step + for i, layer in enumerate(self.decoder): + x = layer(x) + return x + + def inference(self, x): + """ + Encodes input tensor `x` and decodes the quantized embeddings. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, channels, length). + + Returns + ------- + tuple + A tuple (emb, emb_quant, x), where `emb` is the latent embedding, + `emb_quant` is the quantized embedding, and `x` is the decoded output. + """ + for i, layer in enumerate(self.encoder): + if i != len(self.encoder) - 1: + x = layer(x) + else: + x = F.tanh(layer(x)) + emb = x + emb_quant = self.vq(emb) + x = emb_quant + for i, layer in enumerate(self.decoder): + x = layer(x) + return emb, emb_quant, x + + def encode(self, x): + """ + Encodes the input tensor `x` into a quantized embedding. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, channels, length). + + Returns + ------- + torch.Tensor + Quantized embedding. + """ + for i, layer in enumerate(self.encoder): + if i != len(self.encoder) - 1: + x = layer(x) + else: + x = F.tanh(layer(x)) + emb = x + emb_quant = self.vq(emb) + return emb_quant + + def decode(self, emb_quant): + """ + Decodes the quantized embeddings back into a tensor. + + Parameters + ---------- + emb_quant : torch.Tensor + Quantized embedding tensor. + + Returns + ------- + torch.Tensor + Reconstructed output tensor. + """ + x = emb_quant + for i, layer in enumerate(self.decoder): + x = layer(x) + return x + + +class CustomRoundingFunction(Function): + """ + A customizable rounding function for various rounding operations, including: + - Rounding to the nearest multiple of a specified divisor. + - Rounding to the nearest integer. + - Applying the Heaviside step function. + + Arguments + --------- + mode : str + The mode of the operation. Can be 'round', 'binary', or 'heaviside'. + divisor : float, optional + The divisor for rounding. Only used in 'round' mode. + """ + + @staticmethod + def forward(ctx, input, mode="round", divisor=1.0): + """ + Forward pass for the custom rounding function. + + Arguments + --------- + ctx : context object + Context object used to store information for the backward computation. + input : torch.Tensor + The input tensor to be processed. + mode : str + The mode of the operation ('round', 'binary', 'heaviside'). + divisor : float + The divisor for rounding. Only used in 'round' mode. + + Returns + ------- + torch.Tensor + The processed tensor after applying the operation. 
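+
+        Example
+        -------
+        A minimal illustrative sketch of the three modes (input values are
+        arbitrary and chosen only for demonstration):
+
+        >>> x = torch.tensor([0.3, 0.8])
+        >>> CustomRoundingFunction.apply(x, "round", 2.0)  # nearest multiple of 1/divisor = 0.5
+        tensor([0.5000, 1.0000])
+        >>> CustomRoundingFunction.apply(x, "binary")  # nearest integer
+        tensor([0., 1.])
+        >>> CustomRoundingFunction.apply(x, "heaviside")  # step function
+        tensor([1., 1.])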
+ """ + ctx.mode = mode + ctx.divisor = divisor + + if mode == "round": + return torch.round(divisor * input) / divisor + elif mode == "binary": + return torch.round(input) + elif mode == "heaviside": + values = torch.tensor([0.0]).type_as(input) + return torch.heaviside(input, values) + else: + raise ValueError( + f"Invalid mode '{mode}'. Supported modes: 'round', 'binary', 'heaviside'." + ) + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass for the custom rounding function. + + Arguments + --------- + ctx : context object + Context object containing information saved during the forward pass. + grad_output : torch.Tensor + The gradient of the output with respect to the loss. + + Returns + ------- + torch.Tensor + The gradient of the input with respect to the loss. + """ + # For all modes, the gradient is propagated unchanged. + return grad_output.clone(), None, None + + +class PreProcessor(nn.Module): + """ + A module for preprocessing input data through convolution and pooling operations. + It is used as an initial step before the encoder blocks in the ScalarModel, particularly when the kernel_size for average pooling operation exceeds 1. + + Arguments + --------- + n_in : int + Number of input channels. + n_out : int + Number of output channels. + num_samples : int + Number of samples for pooling. + kernel_size : int, optional + Size of the convolutional kernel (default is 7). + causal : bool, optional + If True, applies causal convolution (default is False). + """ + + def __init__(self, n_in, n_out, num_samples, kernel_size=7, causal=False): + super(PreProcessor, self).__init__() + self.pooling = torch.nn.AvgPool1d(kernel_size=num_samples) + self.conv = Conv1d(n_in, n_out, kernel_size=kernel_size, causal=causal) + self.activation = nn.PReLU() + + def forward(self, x): + """ + Applies convolution, activation, and pooling to the input data. + + Arguments + --------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Processed output tensor. + """ + output = self.activation(self.conv(x)) + output = self.pooling(output) + return output + + +class PostProcessor(nn.Module): + """ + A module for postprocessing data through convolution and reshaping. + It is used as an initial step after the decoder blocks in the ScalarModel, particularly when the kernel_size for average pooling operation exceeds 1. + + Arguments + --------- + n_in : int + Number of input channels. + n_out : int + Number of output channels. + num_samples : int + Number of samples for repetition. + kernel_size : int, optional + Size of the convolutional kernel (default is 7). + causal : bool, optional + If True, applies causal convolution (default is False). + """ + + def __init__(self, n_in, n_out, num_samples, kernel_size=7, causal=False): + super(PostProcessor, self).__init__() + self.num_samples = num_samples + self.conv = Conv1d(n_in, n_out, kernel_size=kernel_size, causal=causal) + self.activation = nn.PReLU() + + def forward(self, x): + """ + Applies reshaping, repetition, and convolution to the input data. + + Arguments + --------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Processed output tensor. + """ + x = torch.transpose(x, 1, 2) + B, T, C = x.size() + x = x.repeat(1, 1, self.num_samples).view(B, -1, C) + x = torch.transpose(x, 1, 2) + output = self.activation(self.conv(x)) + return output + + +class DownsampleLayer(nn.Module): + """ + A downsampling layer that applies convolution, optional pooling, and activation. 
+ + Arguments + --------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int + Size of the convolutional kernel. + stride : int, optional + Stride of the convolution (default is 1). + causal : bool, optional + If True, applies causal convolution (default is False). + activation : nn.Module, optional + Activation function (default is PReLU). + use_weight_norm : bool, optional + If True, applies weight normalization to the convolution (default is True). + pooling : bool, optional + If True, applies an average pooling operation (default is False). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + causal: bool = False, + activation=nn.PReLU(), + use_weight_norm: bool = True, + pooling: bool = False, + ): + super(DownsampleLayer, self).__init__() + self.pooling = pooling + self.stride = stride + self.activation = activation + self.use_weight_norm = use_weight_norm + if pooling: + self.layer = Conv1d( + in_channels, out_channels, kernel_size, causal=causal + ) + self.pooling = nn.AvgPool1d(kernel_size=stride) + else: + self.layer = Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + causal=causal, + ) + if use_weight_norm: + self.layer = weight_norm(self.layer) + + def forward(self, x): + """ + Applies convolution, optional pooling, and activation to the input data. + + Arguments + --------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Processed output tensor. + """ + x = self.layer(x) + x = self.activation(x) if self.activation is not None else x + if self.pooling: + x = self.pooling(x) + return x + + def remove_weight_norm(self): + """ + Removes weight normalization from the convolutional layer. + """ + if self.use_weight_norm: + remove_weight_norm(self.layer) + + +class UpsampleLayer(nn.Module): + """ + An upsampling layer that applies transposed convolution or repetition, with activation. + + Arguments + --------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int + Size of the convolutional kernel. + stride : int, optional + Stride of the transposed convolution (default is 1). + causal : bool, optional + If True, applies causal convolution (default is False). + activation : nn.Module, optional + Activation function (default is PReLU). + use_weight_norm : bool, optional + If True, applies weight normalization to the convolution (default is True). + repeat : bool, optional + If True, applies repetition instead of transposed convolution (default is False). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + causal: bool = False, + activation=nn.PReLU(), + use_weight_norm: bool = True, + repeat: bool = False, + ): + super(UpsampleLayer, self).__init__() + self.repeat = repeat + self.stride = stride + self.activation = activation + self.use_weight_norm = use_weight_norm + if repeat: + self.layer = Conv1d( + in_channels, out_channels, kernel_size, causal=causal + ) + else: + self.layer = ConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + causal=causal, + ) + if use_weight_norm: + self.layer = weight_norm(self.layer) + + def forward(self, x): + """ + Applies upsampling through transposed convolution or repetition, followed by activation. + + Arguments + --------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Processed output tensor. 
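+
+        Example
+        -------
+        An illustrative sketch (shapes only; the layer sizes are arbitrary).
+        With ``repeat=True`` the time axis is upsampled by ``stride`` through
+        repetition rather than a transposed convolution:
+
+        >>> up = UpsampleLayer(4, 4, kernel_size=7, stride=2, repeat=True)
+        >>> x = torch.randn(1, 4, 50)
+        >>> up(x).shape
+        torch.Size([1, 4, 100])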
+ """ + x = self.layer(x) + x = self.activation(x) if self.activation is not None else x + if self.repeat: + x = torch.transpose(x, 1, 2) + B, T, C = x.size() + x = x.repeat(1, 1, self.stride).view(B, -1, C) + x = torch.transpose(x, 1, 2) + return x + + def remove_weight_norm(self): + """ + Removes weight normalization from the convolutional layer. + """ + if self.use_weight_norm: + remove_weight_norm(self.layer) + + +class ResidualUnit(nn.Module): + """ + A residual unit with two convolutional layers and activation functions. + This module is commonly used in the encoder and decoder blocks of the ScalarModel + + Arguments + --------- + n_in : int + Number of input channels. + n_out : int + Number of output channels. + dilation : int + Dilation factor for the first convolutional layer. + res_kernel_size : int, optional + Size of the convolutional kernel for residual connections (default is 7). + causal : bool, optional + If True, applies causal convolution (default is False). + """ + + def __init__(self, n_in, n_out, dilation, res_kernel_size=7, causal=False): + super(ResidualUnit, self).__init__() + self.conv1 = weight_norm( + Conv1d( + n_in, + n_out, + kernel_size=res_kernel_size, + dilation=dilation, + causal=causal, + ) + ) + self.conv2 = weight_norm( + Conv1d(n_in, n_out, kernel_size=1, causal=causal) + ) + self.activation1 = nn.PReLU() + self.activation2 = nn.PReLU() + + def forward(self, x): + """ + Applies two convolutional layers with activations and adds the input for a residual connection. + + Arguments + --------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor with residual connection applied. + """ + output = self.activation1(self.conv1(x)) + output = self.activation2(self.conv2(output)) + return output + x + + +class ResEncoderBlock(nn.Module): + """ + A residual encoder block with multiple residual units and a downsampling layer. + + Arguments + --------- + n_in : int + Number of input channels. + n_out : int + Number of output channels. + stride : int + Stride for the downsampling layer. + down_kernel_size : int + Kernel size for the downsampling layer. + res_kernel_size : int, optional + Size of the convolutional kernel for residual connections (default is 7). + causal : bool, optional + If True, applies causal convolution (default is False). + """ + + def __init__( + self, + n_in, + n_out, + stride, + down_kernel_size, + res_kernel_size=7, + causal=False, + ): + super(ResEncoderBlock, self).__init__() + self.convs = nn.ModuleList( + [ + ResidualUnit( + n_in, + n_out // 2, + dilation=1, + res_kernel_size=res_kernel_size, + causal=causal, + ), + ResidualUnit( + n_out // 2, + n_out // 2, + dilation=3, + res_kernel_size=res_kernel_size, + causal=causal, + ), + ResidualUnit( + n_out // 2, + n_out // 2, + dilation=5, + res_kernel_size=res_kernel_size, + causal=causal, + ), + ResidualUnit( + n_out // 2, + n_out // 2, + dilation=7, + res_kernel_size=res_kernel_size, + causal=causal, + ), + ResidualUnit( + n_out // 2, + n_out // 2, + dilation=9, + res_kernel_size=res_kernel_size, + causal=causal, + ), + ] + ) + self.down_conv = DownsampleLayer( + n_in, n_out, down_kernel_size, stride=stride, causal=causal + ) + + def forward(self, x): + """ + Applies a series of residual units and a downsampling layer. + + Arguments + --------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Processed output tensor. 
+ """ + for conv in self.convs: + x = conv(x) + x = self.down_conv(x) + return x + + +class ResDecoderBlock(nn.Module): + """ + A residual decoder block with upsampling and multiple residual units. + + Arguments + --------- + n_in : int + Number of input channels. + n_out : int + Number of output channels. + stride : int + Stride for the upsampling layer. + up_kernel_size : int + Kernel size for the upsampling layer. + res_kernel_size : int, optional + Size of the convolutional kernel for residual connections (default is 7). + causal : bool, optional + If True, applies causal convolution (default is False). + """ + + def __init__( + self, + n_in, + n_out, + stride, + up_kernel_size, + res_kernel_size=7, + causal=False, + ): + super(ResDecoderBlock, self).__init__() + self.up_conv = UpsampleLayer( + n_in, + n_out, + kernel_size=up_kernel_size, + stride=stride, + causal=causal, + activation=None, + ) + self.convs = nn.ModuleList( + [ + ResidualUnit( + n_out, + n_out, + dilation=1, + res_kernel_size=res_kernel_size, + causal=causal, + ), + ResidualUnit( + n_out, + n_out, + dilation=3, + res_kernel_size=res_kernel_size, + causal=causal, + ), + ResidualUnit( + n_out, + n_out, + dilation=5, + res_kernel_size=res_kernel_size, + causal=causal, + ), + ResidualUnit( + n_out, + n_out, + dilation=7, + res_kernel_size=res_kernel_size, + causal=causal, + ), + ResidualUnit( + n_out, + n_out, + dilation=9, + res_kernel_size=res_kernel_size, + causal=causal, + ), + ] + ) + + def forward(self, x): + """ + Applies upsampling followed by a series of residual units. + + Arguments + --------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Processed output tensor. + """ + x = self.up_conv(x) + for conv in self.convs: + x = conv(x) + return x + + +class Conv1d(nn.Conv1d): + """ + Custom 1D convolution layer with an optional causal mode. + + This class extends PyTorch's `nn.Conv1d` and allows for causal convolutions + by automatically applying the correct amount of padding to ensure that the output + does not depend on future inputs, which is useful for sequential data processing. + + Arguments + --------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int + Size of the convolutional kernel. + stride : int, optional + Stride of the convolution (default is 1). + dilation : int, optional + Dilation factor for the convolution (default is 1). + groups : int, optional + Number of blocked connections from input channels to output channels (default is 1). + padding_mode : str, optional + Padding mode to use ('zeros', 'reflect', 'replicate', or 'circular') (default is 'zeros'). + bias : bool, optional + If True, adds a learnable bias to the output (default is True). + padding : int, optional + Explicit padding value. If not provided, it will be computed automatically. + causal : bool, optional + If True, applies causal convolution where the output depends only on the past and current inputs (default is False). + w_init_gain : str, optional + Gain value used for Xavier initialization (e.g., 'relu', 'tanh', etc.). If provided, applies Xavier uniform initialization to the convolutional weights. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + padding_mode: str = "zeros", + bias: bool = True, + padding=None, + causal: bool = False, + w_init_gain=None, + ): + self.causal = causal + if padding is None: + if causal: + padding = 0 + self.left_padding = dilation * (kernel_size - 1) + else: + padding = get_padding(kernel_size, dilation) + super(Conv1d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + padding_mode=padding_mode, + bias=bias, + ) + if w_init_gain is not None: + torch.nn.init.xavier_uniform_( + self.weight, gain=torch.nn.init.calculate_gain(w_init_gain) + ) + + def forward(self, x): + """ + Applies the forward pass of the convolutional layer. + + Arguments + --------- + x : torch.Tensor + Input tensor of shape (batch_size, channels, sequence_length). + + Returns + ------- + torch.Tensor + The output tensor after applying the convolution operation. + If `causal` is True, the input tensor is padded to ensure that + the output at each timestep only depends on the current and previous inputs. + """ + if self.causal: + x = F.pad(x.unsqueeze(2), (self.left_padding, 0, 0, 0)).squeeze(2) + + return super(Conv1d, self).forward(x) + + +class ConvTranspose1d(nn.ConvTranspose1d): + """ + Custom transposed 1D convolution layer with causal option. + + Arguments + --------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int + Size of the convolutional kernel. + stride : int, optional + Stride of the convolution (default is 1). + output_padding : int, optional + Additional size added to one side of the output (default is 0). + groups : int, optional + Number of blocked connections (default is 1). + bias : bool, optional + If True, adds a learnable bias (default is True). + dilation : int, optional + Dilation factor (default is 1). + padding : int, optional + Explicit padding value (default is None). + padding_mode : str, optional + Padding mode (default is 'zeros'). + causal : bool, optional + If True, applies causal convolution. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + output_padding: int = 0, + groups: int = 1, + bias: bool = True, + dilation: int = 1, + padding=None, + padding_mode: str = "zeros", + causal: bool = False, + ): + if padding is None: + padding = 0 if causal else (kernel_size - stride) // 2 + if causal: + assert ( + padding == 0 + ), "padding is not allowed in causal ConvTranspose1d." + assert ( + kernel_size == 2 * stride + ), "kernel_size must be equal to 2*stride is not allowed in causal ConvTranspose1d." + super(ConvTranspose1d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + padding_mode=padding_mode, + ) + self.causal = causal + self.stride = stride + + def forward(self, x): + """ + Applies the transposed convolution operation. + + Arguments + --------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Transposed convolved output tensor. + """ + x = super(ConvTranspose1d, self).forward(x) + if self.causal: + x = x[:, :, : -self.stride] + return x + + +def decimal_to_ternary_matrix(decimals, D): + """ + Convert a tensor of decimal numbers to a D*T ternary matrix for each batch. 
+ + Arguments + --------- + decimals : torch.Tensor + A 2D tensor of decimal numbers with shape (B, T), where B is the batch size + and T is the number of elements in each batch. + D : int + Number of ternary digits to represent each number (depth). + + Returns + ------- + torch.Tensor + A 3D tensor of shape (B, D, T) where each slice along the first dimension + corresponds to a batch, and each column is represented as a ternary number. + """ + B, T = decimals.shape + ternary_matrix = torch.zeros((B, D, T), dtype=torch.long) + for pos in range(D): + ternary_matrix[:, pos, :] = decimals % 3 # Modulo operation + decimals //= 3 # Floor division for next ternary digit + + return ternary_matrix + + +def ternary_matrix_to_decimal(matrix): + """ + Convert a B*D*N ternary matrix to a 2D array of decimal numbers for each batch. + + Arguments + --------- + matrix : numpy.ndarray + A 3D numpy array of shape (B, D, N), where B is the batch size, D is the number + of ternary digits, and N is the number of ternary numbers in each batch. + + Returns + ------- + numpy.ndarray + A 2D numpy array of shape (B, N), where each value represents the decimal + equivalent of the corresponding ternary number in the input matrix. + """ + ( + B, + D, + N, + ) = ( + matrix.shape + ) # B is the batch size, D is the number of digits, N is the number of ternary numbers + powers_of_three = 3 ** np.arange(D) # [3^0, 3^1, ..., 3^(D-1)] + + # Reshape powers_of_three for broadcasting: [D] -> [1, D, 1] + powers_of_three = powers_of_three[:, np.newaxis] # Shape [D, 1] + + # Compute dot product using broadcasting: matrix * powers_of_three along D axis + decimals = np.sum(matrix * powers_of_three, axis=1) # Sum along the D axis + + return decimals + + +def get_padding(kernel_size, dilation=1): + """ + Computes the padding size for a given kernel size and dilation. + + Arguments + --------- + kernel_size : int + Size of the convolutional kernel. + dilation : int, optional + Dilation factor for convolution (default is 1). + + Returns + ------- + int + Calculated padding size. + """ + return int((kernel_size * dilation - dilation) / 2) diff --git a/benchmarks/DASB/VoiceBank/enhancement/crdnn/train_encodec.py b/benchmarks/DASB/VoiceBank/enhancement/train.py similarity index 73% rename from benchmarks/DASB/VoiceBank/enhancement/crdnn/train_encodec.py rename to benchmarks/DASB/VoiceBank/enhancement/train.py index 82a5470f9..0f7186843 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/crdnn/train_encodec.py +++ b/benchmarks/DASB/VoiceBank/enhancement/train.py @@ -1,9 +1,9 @@ #!/usr/bin/env/python -"""Recipe for training a transformer-based speech enhancement system using EnCodec audio representations. +"""Recipe for training a speech enhancement system using discrete audio representations. 
To run this recipe: -> python train_encodec.py hparams/.yaml +> python train.py hparams/.yaml Authors * Luca Della Libera 2024 @@ -20,56 +20,25 @@ from speechbrain.utils.distributed import if_main_process, run_on_main -_CACHE = {} - - class Enhancement(sb.Brain): - @torch.no_grad() - def sig_to_toks(self, sig, lens): - # sig: [B, T] - self.hparams.codec.to(self.device).eval() - toks, _ = self.hparams.codec.encode(sig, lens) # [B, N, K] - return toks - @torch.no_grad() def toks_to_sig(self, toks): # toks: [B, N, K] self.hparams.codec.to(self.device).eval() - sig = self.hparams.codec.decode(toks)[:, 0] # [B, T] + self.hparams.codec.device = self.device + if hasattr(self.hparams.codec, "codec_vocoder"): + self.hparams.codec.codec_vocoder.device = self.device + kwargs = {} + if hasattr(self.hparams, "SSL_layers"): + kwargs = {"SSL_layers": self.hparams.SSL_layers} + sig = self.hparams.codec.tokens_to_sig(toks, **kwargs) # [B, T] return sig def compute_forward(self, batch, stage): """Forward pass.""" batch = batch.to(self.device) - in_sig, in_lens = batch.in_sig # [B, T] - out_sig, out_lens = batch.out_sig # [B, T] - - # Augment if specified - if stage == sb.Stage.TRAIN and self.hparams.augment: - in_sig, in_lens = self.hparams.augmentation(in_sig, in_lens) - - # Extract tokens (cache them at first epoch if augmentation is disabled) - key = tuple(sorted(batch.id)) - try: - in_toks, out_toks = _CACHE[key] - in_toks = in_toks.to(self.device) - out_toks = out_toks.to(self.device) - except KeyError: - assert (in_lens == out_lens).all() - sig = torch.cat([in_sig, out_sig]) # [B2, T] - lens = torch.cat([in_lens, out_lens]) # [B2, T] - toks = self.sig_to_toks(sig, lens) # [B2, N, K] - in_toks, out_toks = toks.split( - [len(in_sig), len(out_sig)] - ) # [B, N, K], [B, N, K] - out_toks = out_toks.reshape( - len(in_sig), -1, self.hparams.num_codebooks, - ) # [B, N, K] - if self.hparams.use_cache and (not self.hparams.augment): - _CACHE[key] = in_toks.cpu(), out_toks.cpu() - - # Avoid in-place modification from embedding layer - in_toks = in_toks.clone() + in_toks, in_lens = batch.in_toks # [B, N, K] + out_toks, out_lens = batch.out_toks # [B, N, K] # Forward embedding + attention in_embs = self.modules.embedding(in_toks) # [B, N, K, H] @@ -79,7 +48,14 @@ def compute_forward(self, batch, stage): ) # [B, N, H] # Forward encoder - hyp_embs = self.modules.encoder(in_embs) + if hasattr(self.modules.encoder, "encode"): + hyp_embs = self.modules.encoder.encode(in_embs, in_lens) # [B, N, H] + else: + abs_length = (in_embs.shape[1] * in_lens).ceil().long() + for i in range(len(abs_length)): + if abs_length[i] < in_embs.shape[1]: + in_embs[i, abs_length[i]:] = 0 + hyp_embs = self.modules.encoder(in_embs) # [B, N, H] # Forward head log_probs = ( @@ -128,8 +104,8 @@ def compute_objectives(self, predictions, batch, stage): @torch.no_grad() def vocode(self, IDs, in_sig, out_sig, hyp_toks, out_toks, lens): - hyp_sig = self.toks_to_sig(hyp_toks) # [B, T] - rec_sig = self.toks_to_sig(out_toks) # [B, T] + hyp_sig = self.toks_to_sig(hyp_toks).to(self.device) # [B, T] + rec_sig = self.toks_to_sig(out_toks).to(self.device) # [B, T] # Adjust length if out_sig.shape[-1] > hyp_sig.shape[-1]: @@ -145,11 +121,8 @@ def vocode(self, IDs, in_sig, out_sig, hyp_toks, out_toks, lens): rec_sig = rec_sig.narrow(-1, 0, out_sig.shape[-1]) # [B, T_out] self.dnsmos_metric.append(IDs, hyp_sig, lens) - self.rec_dnsmos_metric.append(IDs, rec_sig, lens) - self.ref_dnsmos_metric.append(IDs, out_sig, lens) self.dwer_metric.append(IDs, hyp_sig, 
out_sig, lens) self.wavlm_sim_metric.append(IDs, hyp_sig, out_sig, lens) - self.ecapatdnn_sim_metric.append(IDs, hyp_sig, out_sig, lens) if self.hparams.save_audios: save_folder = os.path.join(self.hparams.output_folder, "audios") @@ -183,11 +156,8 @@ def on_stage_start(self, stage, epoch=None): self.ter_metric = self.hparams.ter_computer() if stage == sb.Stage.TEST and self.hparams.compute_metrics: self.dnsmos_metric = self.hparams.dnsmos_computer() - self.rec_dnsmos_metric = self.hparams.dnsmos_computer() - self.ref_dnsmos_metric = self.hparams.dnsmos_computer() self.dwer_metric = self.hparams.dwer_computer() self.wavlm_sim_metric = self.hparams.wavlm_sim_computer() - self.ecapatdnn_sim_metric = self.hparams.ecapatdnn_sim_computer() def on_stage_end(self, stage, stage_loss, epoch=None): """Gets called at the end of each epoch.""" @@ -218,19 +188,10 @@ def on_stage_end(self, stage, stage_loss, epoch=None): elif stage == sb.Stage.TEST: if self.hparams.compute_metrics: stage_stats["DNSMOS"] = self.dnsmos_metric.summarize("average") - stage_stats["RecDNSMOS"] = self.rec_dnsmos_metric.summarize( - "average" - ) - stage_stats["RefDNSMOS"] = self.ref_dnsmos_metric.summarize( - "average" - ) stage_stats["dWER"] = self.dwer_metric.summarize("error_rate") stage_stats["WavLMSim"] = self.wavlm_sim_metric.summarize( "average" ) - stage_stats[ - "ECAPATDNNSim" - ] = self.ecapatdnn_sim_metric.summarize("average") self.hparams.train_logger.log_stats( stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, test_stats=stage_stats, @@ -276,26 +237,16 @@ def on_stage_end(self, stage, stage_loss, epoch=None): run_on_main(prepare_data, kwargs=prepare_data_kwargs) # Create the datasets objects - from utils import dataio_prepare + from common import dataio_prepare train_data, valid_data, test_data = dataio_prepare( debug=run_opts.get("debug", False), **hparams ) - # Pretrain the specified modules - if "pretrainer" in hparams: - run_on_main(hparams["pretrainer"].collect_files) - run_on_main(hparams["pretrainer"].load_collected) - - # Use pretrained embeddings - if hparams["pretrain_embedding"]: - embs = hparams["codec"].vocabulary.reshape(-1, hparams["embedding_dim"]) - hparams["embedding"].embedding.weight.data.copy_(embs) - # Log number of parameters/buffers - codec_params = sum( - [x.numel() for x in hparams["codec"].state_dict().values()] - ) + #codec_params = sum( + # [x.numel() for x in hparams["codec"].state_dict().values()] + #) model_params = sum( [ x.numel() @@ -305,10 +256,12 @@ def on_stage_end(self, stage, stage_loss, epoch=None): ) hparams["train_logger"].log_stats( stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", + #f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", }, ) + hparams["compute_metrics"] = True + hparams["codec"] = hparams["codec"]() # Trainer initialization brain = Enhancement( diff --git a/benchmarks/DASB/VoiceBank/enhancement/conformer/train_continuous_ssl.py b/benchmarks/DASB/VoiceBank/enhancement/train_continuous_ssl.py similarity index 89% rename from benchmarks/DASB/VoiceBank/enhancement/conformer/train_continuous_ssl.py rename to benchmarks/DASB/VoiceBank/enhancement/train_continuous_ssl.py index 72ba7f6ce..d95d24666 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/conformer/train_continuous_ssl.py +++ b/benchmarks/DASB/VoiceBank/enhancement/train_continuous_ssl.py @@ -1,6 +1,6 @@ #!/usr/bin/env/python -"""Recipe for training a transformer-based speech enhancement 
system using continuous SSL audio representations. +"""Recipe for training a speech enhancement system using continuous SSL audio representations. To run this recipe: > python train_continuous_ssl.py hparams/.yaml @@ -22,7 +22,6 @@ _CACHE = {} - # To use in configuration files def len_(SSL_layers, embedding_dim): return len(SSL_layers) * embedding_dim @@ -72,10 +71,6 @@ def compute_forward(self, batch, stage): in_sig, in_lens = batch.in_sig # [B, T] out_sig, out_lens = batch.out_sig # [B, T] - # Augment if specified - if stage == sb.Stage.TRAIN and self.hparams.augment: - in_sig, in_lens = self.hparams.augmentation(in_sig, in_lens) - # Extract features (cache them at first epoch if augmentation is disabled) key = tuple(sorted(batch.id)) try: @@ -161,11 +156,8 @@ def vocode(self, IDs, in_sig, out_sig, hyp_embs, out_embs, lens): rec_sig = rec_sig.narrow(-1, 0, out_sig.shape[-1]) # [B, T_out] self.dnsmos_metric.append(IDs, hyp_sig, lens) - self.rec_dnsmos_metric.append(IDs, rec_sig, lens) - self.ref_dnsmos_metric.append(IDs, out_sig, lens) self.dwer_metric.append(IDs, hyp_sig, out_sig, lens) self.wavlm_sim_metric.append(IDs, hyp_sig, out_sig, lens) - self.ecapatdnn_sim_metric.append(IDs, hyp_sig, out_sig, lens) if self.hparams.save_audios: save_folder = os.path.join(self.hparams.output_folder, "audios") @@ -197,11 +189,8 @@ def on_stage_start(self, stage, epoch=None): super().on_stage_start(stage, epoch) if stage == sb.Stage.TEST and self.hparams.compute_metrics: self.dnsmos_metric = self.hparams.dnsmos_computer() - self.rec_dnsmos_metric = self.hparams.dnsmos_computer() - self.ref_dnsmos_metric = self.hparams.dnsmos_computer() self.dwer_metric = self.hparams.dwer_computer() self.wavlm_sim_metric = self.hparams.wavlm_sim_computer() - self.ecapatdnn_sim_metric = self.hparams.ecapatdnn_sim_computer() def on_stage_end(self, stage, stage_loss, epoch=None): """Gets called at the end of each epoch.""" @@ -230,19 +219,10 @@ def on_stage_end(self, stage, stage_loss, epoch=None): elif stage == sb.Stage.TEST: if self.hparams.compute_metrics: stage_stats["DNSMOS"] = self.dnsmos_metric.summarize("average") - stage_stats["RecDNSMOS"] = self.rec_dnsmos_metric.summarize( - "average" - ) - stage_stats["RefDNSMOS"] = self.ref_dnsmos_metric.summarize( - "average" - ) stage_stats["dWER"] = self.dwer_metric.summarize("error_rate") stage_stats["WavLMSim"] = self.wavlm_sim_metric.summarize( "average" ) - stage_stats[ - "ECAPATDNNSim" - ] = self.ecapatdnn_sim_metric.summarize("average") self.hparams.train_logger.log_stats( stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, test_stats=stage_stats, @@ -288,17 +268,12 @@ def on_stage_end(self, stage, stage_loss, epoch=None): run_on_main(prepare_data, kwargs=prepare_data_kwargs) # Create the datasets objects - from utils import dataio_prepare + from common import dataio_prepare train_data, valid_data, test_data = dataio_prepare( debug=run_opts.get("debug", False), **hparams ) - # Pretrain the specified modules - if "pretrainer" in hparams: - run_on_main(hparams["pretrainer"].collect_files) - run_on_main(hparams["pretrainer"].load_collected) - # Log number of parameters/buffers ssl_params = sum( [x.numel() for x in hparams["ssl_model"].state_dict().values()] diff --git a/benchmarks/DASB/VoiceBank/enhancement/utils/__init__.py b/benchmarks/DASB/VoiceBank/enhancement/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/benchmarks/DASB/VoiceBank/enhancement/utils/aggregate_results.py 
b/benchmarks/DASB/VoiceBank/enhancement/utils/aggregate_results.py new file mode 100644 index 000000000..0df315b7e --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/utils/aggregate_results.py @@ -0,0 +1,149 @@ +#!/usr/bin/python +""" +Snippet to aggregate the results over multiple runs of the same experiment. +This is useful when we run multiple experiments with different seeds and we +want to compute the average performance. The script also reports the final +metric to Orion (when needed for hyperparameter tuning). + +The script searches for the result files (_results.txt) and computes the mean +and the standard deviation of the given evaluation metrics (e.g., acc or f1). +The results must have an identical format (with only different performance +numbers). + +To run this script: + + > python aggregate_results.py your_result_folder acc + +Author +------ +Pooneh Mousavi 2024 +""" + +import sys +import re +import numpy as np +from orion.client import report_objective +from speechbrain.utils.data_utils import get_all_files + + +def get_prototype(res_file, eval_metric): + """Parses a result file and adds a placeholder where the aggregated metrics + should be printed. It also returns the number of detected metrics. + + Arguments + --------- + res_file: path + Path of the result file to parse. + eval_metric: path + Metric of interest (e.g, acc or f1). + + Returns + --------- + prototype: list + List of the lines of the result file (with as placeholder). + n_metrics: int + Number of metrics to replace in the result files. + """ + prototype = [] + n_metrics = 0 + + # Open the first res file and figure out where the metrics are + with open(res_file) as file_in: + for line in file_in: + if eval_metric in line: + line = line.split(eval_metric)[0] + # The placeholder for the metric is + line = line + eval_metric + " " + n_metrics = n_metrics + 1 + prototype.append(line) + return prototype, n_metrics + + +def get_metrics(res_files, eval_metric): + """Summarizes the metrics of interest in a matrix. + + Arguments + --------- + res_files: list + List of all the result files. + eval_metric: path + Metric of interest (e.g, acc or f1). + + Returns + --------- + metrics: np.array + Matrix (n_metrics, n_files) containing the metrics of interest. + """ + + # Metric initialization + metrics = np.zeros([n_metrics, len(res_files)]) + + # Loop over files + for i in range(len(res_files)): + cnt = 0 + # Metric extraction + with open(res_files[i]) as file_in: + for line in file_in: + if eval_metric in line: + # Use regex to find the test WER value + match = re.search( + rf"{eval_metric}: (\d+\.\d+(?:e[+-]?\d+)?)", line + ) + if match: + value = match.group(1) + value = float(value) + metrics[cnt, i] = value + cnt = cnt + 1 + return metrics + + +def aggregate_metrics(prototype, metrics): + """Prints the aggregated metrics.It replaces the placeholders with + the corresponding metrics. + + Arguments + --------- + prototype: list + List of the lines of the result file (with as placeholder). + metrics: np.array + Matrix (n_metrics, n_files) containing the metrics of interest. 
+ """ + cnt = 0 + for line in prototype: + if eval_metric in line: + values_line = "[" + for i in range(len(res_files)): + values_line = values_line + "%f " % float(metrics[cnt, i]) + values_line = values_line[:-1] + values_line = values_line + "] avg: %f ± %f " % ( + float(metrics[cnt, :].mean()), + float(metrics[cnt, :].std()), + ) + line = line.replace("", values_line) + cnt = cnt + 1 + print(line) + + +if __name__ == "__main__": + output_folder = sys.argv[1] + eval_metric = sys.argv[2] + + # Getting the list of the result files in the output folder + res_files = get_all_files(output_folder, match_and=["train_log.txt"]) + + # Gettin a prototype file + prototype, n_metrics = get_prototype(res_files[0], eval_metric) + + # Extracting the metrics of interest + metrics = get_metrics(res_files, eval_metric) + + # print aggregated metrics + aggregate_metrics(prototype, metrics) + + final_metric = metrics[-1, :].mean() + + # Report final metric to Orion + # Remember: orion expects metrics to be minimized! + if eval_metric == "acc" or eval_metric == "f1": + final_metric = 1 - final_metric + report_objective(final_metric) diff --git a/benchmarks/DASB/VoiceBank/enhancement/utils/audio_tokens.py b/benchmarks/DASB/VoiceBank/enhancement/utils/audio_tokens.py new file mode 100644 index 000000000..9dc4014c4 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/utils/audio_tokens.py @@ -0,0 +1,193 @@ +"""Utilities for discrete audio token models + + +Authors + * Artem Ploujnikov 2023 +""" +import torch +from speechbrain.dataio.batch import PaddedBatch +from speechbrain.utils.data_utils import batch_pad_right +from functools import partial + + +def get_silence_token( + model, + sample_length=100000, + extract_emb=True, + device=None, + model_kwargs=None, +): + """Attempts to find out the silence tokens for a given model, + if applicable + + Arguments + --------- + model : nn.Module + A discrete token model, taking (wav, lengths) as arguments + sample_length : int + The length of the sample + extract_emb : bool + Whether to extract embeddings + device : str | torch.Device + The device to use + model_kwargs : dict + Additional arguments to pass to the model + + Returns + ------- + silence_tokens : torch.Tensor + The token(s) corresponding to silence + + silece_emb : torch.Tensor + The embedding(s) corresponding to silence + + """ + if device is None: + device = next(model.parameters()).device + if model_kwargs is None: + model_kwargs = {} + + audio = torch.zeros(1, sample_length, device=device) + length = torch.ones(1, device=device) + result = model(audio, length, **model_kwargs) + tokens = result[0] + silence_tokens = tokens.squeeze(0).mode(0).values + silence_emb = None + if extract_emb: + if hasattr(model, "embeddings"): + silence_emb = model.embeddings( + silence_tokens[None, None, :] + ).squeeze() + else: + heads = tokens.shape[-1] + embs = result[1] + mode_idx = [ + (tokens[0, :, head] == silence_tokens[head]).nonzero()[0].item() + for head in range(heads) + ] + silence_emb = torch.stack( + [embs[0, idx, head] for head, idx in enumerate(mode_idx)] + ) + return silence_tokens, silence_emb + + +def feature_pad_to(tensor, length, padding=None): + """Pads feature dimensions to the specified length with the specified padding, + assuming a (Batch x Length x Features..) 
tensor + + Arguments + --------- + tensor : torch.Tensor + The tensor to be padded + + length : int + The length to which the tensor will be padded + + padding : torch.Tensor, optional + The padding tensor - if omitted, zero padding + will be used + + Returns + ------- + result : torch.Tensor + The padded tensor + """ + if padding is None: + padding = torch.zeros(tensor.shape[1:]) + padding = padding[None, ...].expand( + (length - tensor.size(0),) + tensor.shape[1:] + ) + return torch.cat([tensor, padding], dim=0) + + +def batch_feature_pad(tensors, padding=None): + """Similar to batch_pad_right but pads with the specified padding, whcih + can be a vector or a tensor + + Arguments + --------- + tensors : list + The list of tensors to be padded + padding : torch.Tensor + The padding tensor + + Returns + ------- + result : torch.Tensor + the padded tensor + """ + lengths_abs = torch.tensor( + [len(item) for item in tensors], device=tensors[0].device + ) + max_length = lengths_abs.max() + data = torch.stack( + [feature_pad_to(item, max_length, padding) for item in tensors] + ) + lengths = lengths_abs / max_length + return data, lengths + + +def token_collate_fn(examples, silence_token, token_keys): + """A customized collation function for audio tokens where + the specified silence token will be used as padding - instead of + zeros + + Arguments + --------- + examples : list + A list of examples + + silence_token : torch.Tensor + The token(s) representing silence + + token_keys : list + The list of keys to which special padding will be applied + + Returns + ------- + result : speechbrain.dataio.batch.PaddedBatch + A padded batch + """ + token_tensor_ids = {id(examples[0][key]) for key in token_keys} + return PaddedBatch( + examples, + padding_func=_silence_padding, + padding_kwargs={ + "silence_token": silence_token, + "token_tensor_ids": token_tensor_ids, + }, + ) + + +def _silence_padding(values, silence_token, token_tensor_ids): + return ( + batch_feature_pad(values, silence_token) + if id(values[0]) in token_tensor_ids + else batch_pad_right(values) + ) + + +def use_silence_padding(dataloader_opts, silence_token, token_keys): + """Overrides the collation function to add silence padding to + audio token features + + Arguments + --------- + dataloder_opts : dict + Dataloader options + silence_token : torch.Tensor + The tensor to be used as silence padding + token_keys : torch.Tensor + The keys to apply silence padding to + + Returns + ------- + dataloader_opts : dict + Updated data loader options + """ + return { + **dataloader_opts, + "collate_fn": partial( + token_collate_fn, silence_token=silence_token, token_keys=token_keys + ), + } diff --git a/benchmarks/DASB/VoiceBank/enhancement/utils/data.py b/benchmarks/DASB/VoiceBank/enhancement/utils/data.py new file mode 100644 index 000000000..6c68358f5 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/utils/data.py @@ -0,0 +1,91 @@ +"""Data utilities + +Authors + * Artem Ploujnikov 2024 +""" + +import torch +from speechbrain.dataio.batch import PaddedData + + +def undo_batch(batch): + """Converts a padded batch or a dicitionary to a list of + dictionaries. 
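For example, a batch with keys ``wav`` and ``id`` and batch size B becomes a list of B per-item dictionaries, each holding a single ``wav`` and ``id`` (the key names here are only illustrative).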
Any instances of PaddedData encountered will + be converted to plain tensors + + Arguments + --------- + batch: dict|speechbrain.dataio.batch.PaddedBatch + the batch + + Returns + ------- + result: dict + a list of dictionaries with each dictionary as a batch + element + """ + if hasattr(batch, "as_dict"): + batch = batch.as_dict() + keys = batch.keys() + return [ + dict(zip(keys, item)) + for item in zip( + *[_unpack_feature(feature) for feature in batch.values()] + ) + ] + + +def _unpack_feature(feature): + """Un-batches a single feature. If a PaddedBatch is provided, it will be converted + to a list of unpadded tensors. Otherwise, it will be returned unmodified + + Arguments + --------- + feature : any + The feature to un-batch + """ + if isinstance(feature, PaddedData): + device = feature.data.device + feature = _undo_padding(feature.data, feature.lengths) + feature = [torch.tensor(item, device=device) for item in feature] + return feature + + +# NOTE: Similar to the function in speechbrain.utils.data_utils +# but it keeps values in tensor form +def _undo_padding(batch, lengths): + """Produces Python lists given a batch of sentences with + their corresponding relative lengths. + + Arguments + --------- + batch : torch.Tensor + Batch of sentences gathered in a batch. + lengths : torch.Tensor + Relative length of each sentence in the batch. + + Returns + ------- + as_list : list + A python list of the corresponding input tensor. + + Example + ------- + >>> batch=torch.rand([4,100]) + >>> lengths=torch.tensor([0.5,0.6,0.7,1.0]) + >>> snt_list=undo_padding(batch, lengths) + >>> len(snt_list) + 4 + """ + batch_max_len = batch.shape[1] + as_list = [] + for seq, seq_length in zip(batch, lengths): + actual_size = int(torch.round(seq_length * batch_max_len)) + seq_true = seq[:actual_size] + as_list.append(seq_true) + return as_list + + +def as_dict(batch): + """Converts a batch to a dictionary""" + return {key: getattr(batch, key) for key in batch._PaddedBatch__keys} diff --git a/benchmarks/DASB/VoiceBank/enhancement/utils/eval.py b/benchmarks/DASB/VoiceBank/enhancement/utils/eval.py new file mode 100644 index 000000000..c0e14f867 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/utils/eval.py @@ -0,0 +1,1028 @@ +""" Specifies the inference interfaces for speech quality +evaluation, used to assess the quality/intelligibility of +text-to-speech systems + +Authors: +* Artem Ploujnikov 2024 +""" + +from speechbrain.inference.interfaces import Pretrained +from speechbrain.inference.ASR import EncoderDecoderASR +from speechbrain.lobes.models.huggingface_transformers import Whisper +from speechbrain.dataio.dataset import FilteredSortedDynamicItemDataset +from speechbrain.decoders.seq2seq import S2SWhisperGreedySearcher +from speechbrain.dataio.batch import PaddedBatch +from speechbrain.utils.metric_stats import ErrorRateStats +from speechbrain.utils.superpowers import run_shell +from collections import namedtuple +from pathlib import Path +import os +import torch +import torchaudio +import re +import string +import logging +import shutil +import shlex +import subprocess + +logger = logging.getLogger(__name__) + +RE_PUNCTUATION = re.compile( + "|".join(re.escape(char) for char in string.punctuation) +) + + +SpeechEvaluationResult = namedtuple( + "SpeechEvaluationResult", ["score", "details"] +) + + +class SpeechEvaluator: + """A base class for speech evaluators + + Arguments + --------- + sample_rate : int + The audio sample rate this evaluator expects + """ + + def __init__(self, 
sample_rate=16000): + self.sample_rate = sample_rate + + def evaluate_file(self, file_name, text=None): + """Evaluates a single file + + Arguments + --------- + file_name : str|pathlib.Path + The file name to evaluate + text : str + The ground truth text, if applicable + + Returns + ------- + result: SpeechEvaluationResult + the evaluation result + """ + wav = self.read_audio(str(file_name)).to(self.device) + result = self.evaluate( + wavs=wav.unsqueeze(0), + length=torch.ones(1).to(self.device), + text=[text], + ) + return SpeechEvaluationResult( + score=result.score.item(), + details={ + key: _unbatchify(value) for key, value in result.details.items() + }, + ) + + def evaluate_files(self, file_names, text=None): + """Evaluates multiple files + + Arguments + --------- + file_names : list + A list of files + + text : list + File transcripts (not required for all evaluators) + + Returns + ------- + result : list + a list of SpeechEvaluationResult instances + """ + if text is None: + text = [None] * len(file_names) + items = [ + {"wav": self.read_audio(str(file_name)), "text": item_text} + for file_name, item_text in zip(file_names, text) + ] + batch = PaddedBatch(items) + return self.evaluate( + wavs=batch.wav.data.to(self.device), + length=batch.wav.lengths.to(self.device), + text=batch.text, + ) + + def read_audio(self, file_name): + """Reads an audio file, resampling if necessary + + Arguments + --------- + file_name : str | path-like + The file path + + Returns + ------- + audio : torch.Tensor + the audio + """ + audio, audio_sample_rate = torchaudio.load(str(file_name)) + return self.resample(audio, audio_sample_rate) + + def evaluate( + self, + wavs, + length, + text=None, + wavs_ref=None, + wavs_length_ref=None, + sample_rate=None, + ): + """Evaluates samples + + Arguments + --------- + wavs : torch.Tensor + the waveforms to evaluate + + length : torch.Tensor + relative lengths (a 1-D tensor) + + text : list + Evaluator-specific metadata + + wavs_ref : torch.Tensor + the reference waveforms + + wavs_length_ref + the reference waveform lengths + + sample_rate: int, optional + The sample rate of the audio. If not provided, + the audio is assumed to be at the same sample + rate as the model + + Returns + ------- + result : list + A list of SpeechEvaluationResult objects, + one for each sample""" + raise NotImplementedError() + + def resample(self, audio, sample_rate=None): + """Resamples the audio, if necessary + + Arguments + --------- + audio : torch.Tensor + the audio to be resampled + sample_rate : int + the sample rate of the audio + + Returns + ------- + audio : torch.Tensor + the target audio, resampled if necessary + """ + if sample_rate is not None and sample_rate != self.sample_rate: + audio = torchaudio.functional.resample( + audio, orig_freq=sample_rate, new_freq=self.sample_rate + ) + return audio + + +def _unbatchify(value): + """Removes the batch dimension from the tensor. If a single + number is returned in any shape, the function converts + the result to a numeric value. 
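For example, a tensor of shape ``[1, 1]`` is returned as a plain Python number.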
Values that are not tensors + are returned unmodified + + Arguments + --------- + value : object + the value + + Returns + ------- + value : object + the value with the batch dimension removed, if applicable + """ + if torch.is_tensor(value): + if value.dim() == 0 or not any(dim > 1 for dim in value.shape): + value = value.item() + else: + value = value.squeeze(0) + return value + + +class SpeechEvaluationRegressionModel(Pretrained): + """A pretrained wrapper for regression-based evaluaton + models""" + + def __call__(self, wavs, length): + return self.mods.model(wavs, length) + + +class RegressionModelSpeechEvaluator(SpeechEvaluator): + """A speech evaluator that uses a regression model + that produces a quality score (e.g. SSL fine-tuning) + for a sample of speech + + Arguments + --------- + source : str + The source model path or HuggingFace hub name + sample_rate : int + The audio sample rate this evaluator expects + """ + + def __init__(self, source, sample_rate=None, *args, **kwargs): + super().__init__(sample_rate=sample_rate) + self.model = SpeechEvaluationRegressionModel.from_hparams( + source, *args, **kwargs + ) + + def evaluate( + self, + wavs, + length, + text=None, + wavs_ref=None, + length_ref=None, + sample_rate=None, + sample_rate_ref=None, + ): + """Evaluates a batch of waveforms + + Arguments + --------- + Arguments + --------- + wavs: torch.Tensor + the waveforms to evaluate + + length: torch.Tensor + relative lengths (a 1-D tensor) + + text : list, optional + Ground truth text + + wavs_ref : torch.Tensor + the reference waveforms + + length_ref : torch.Tensor + the reference waveform lengths + + sample_rate : int, optional + The sample rate of the audio. If not provided, + the audio is assumed to be at the same sample + rate as the model + + sample_rate_ref : int, optional + The sample rate of the reference samples + + Returns + ------- + result : SpeechEvaluationResult + an aggregated speech evaluation result with a score + for each item + """ + wavs = self.resample(wavs, sample_rate) + scores = self.model(wavs, length) + while scores.dim() > 1 and scores.size(-1) == 1: + scores = scores.squeeze(-1) + return SpeechEvaluationResult(score=scores, details={"score": scores}) + + +class ASRSpeechEvaluator(SpeechEvaluator): + """A superclass for ASR speech evaluators""" + + def evaluate( + self, + wavs, + length, + text=None, + wavs_ref=None, + length_ref=None, + sample_rate=None, + sample_rate_ref=None, + ): + """Evaluates samples + + Arguments + --------- + wavs: torch.Tensor + the waveforms to evaluate + + length: torch.Tensor + relative lengths (a 1-D tensor) + + text : list, optional + Ground truth text + + wavs_ref : torch.Tensor + the reference waveforms + + length_ref : torch.Tensor + the reference waveform lengths + + + sample_rate : int, optional + The sample rate of the audio. 
If not provided, + the audio is assumed to be at the same sample + rate as the model + + sample_rate_ref : int, optional + The sample rate of the reference samples + + Returns + ------- + result : SpeechEvaluationResult + an aggregated speech evaluation result with a score + for each item + """ + details = self.evaluate_samples( + wavs=wavs, length=length, text=text, sample_rate=sample_rate + ) + if wavs_ref is not None: + details_ref = self.evaluate_samples( + wavs=wavs_ref, + length=length_ref, + text=text, + sample_rate=sample_rate_ref, + ) + details.update( + {f"{key}_ref": value for key, value in details_ref.items()} + ) + # Redundant: it is the same + del details["target_ref"] + details.update(self.compute_diff_rate(details, device=wavs.device)) + + return SpeechEvaluationResult(score=details["wer"], details=details,) + + def compute_diff_rate(self, details, device): + """Computes the differential token rate + + Arguments + --------- + details : dict + The evaluation details + Keys: + "pred": ASR predictions for the TTS sample + "pred_ref": ASR predictions for the ground + truth + + Returns + ------- + result: dict + A dictionary with the following keys + + dwer : torch.Tensor + The differential Word Error Rate (dWER) + dcer : torch.Tensor + The differential Character Error Rate (dCER) + + """ + ids = range(1, len(details["pred"]) + 1) + wer_metric, cer_metric = init_asr_metrics() + pred = self._replace_blanks(details["pred"]) + pred_ref = self._replace_blanks(details["pred_ref"]) + wer_metric.append(ids, pred, pred_ref) + cer_metric.append(ids, pred, pred_ref) + dwer = torch.tensor( + [score["WER"] for score in wer_metric.scores], device=device + ) + dcer = torch.tensor( + [score["WER"] for score in cer_metric.scores], device=device + ) + return {"dwer": dwer, "dcer": dcer} + + def _replace_blanks(self, preds): + """Replaces blanks with single spaces, preventing an exception + in the case of an unintelligible sample + + Arguments + --------- + """ + return [" " if item == "" else item for item in preds] + + +class EncoderDecoderASRSpeechEvaluator(ASRSpeechEvaluator): + """A speech evaluator implementation based on ASR. 
+ Computes the Word Error Rate (WER), Character Error Rate (CER) + and a few other metrics + + Arguments + --------- + sample_rate : int + The audio sample rate this evaluator expects + """ + + def __init__(self, source, sample_rate=None, *args, **kwargs): + super().__init__(sample_rate=sample_rate) + self.asr = EncoderDecoderASR.from_hparams(source, *args, **kwargs) + self.device = next(self.asr.mods.parameters()).device + + def evaluate_samples(self, wavs, length, text, sample_rate): + wavs = self.resample(wavs, sample_rate) + if text is None: + raise ValueError("This evaluator requires ground-truth text") + predicted_words, scores, log_probs = self.transcribe_batch_with_details( + wavs, length + ) + ids = range(1, len(wavs) + 1) + wer_metric, cer_metric = init_asr_metrics() + wer_metric.append(ids, predicted_words, text) + cer_metric.append(ids, predicted_words, text) + wer = torch.tensor( + [score["WER"] for score in wer_metric.scores], device=wavs.device + ) + cer = torch.tensor( + [score["WER"] for score in cer_metric.scores], device=wavs.device + ) + prob_mean = log_probs.exp().mean(dim=-1) + return { + "wer": wer, + "cer": cer, + "beam_score": scores, + "prob_mean": prob_mean, + "pred": predicted_words, + "target": text, + } + + def transcribe_batch_with_details(self, wavs, wav_lens): + """Transcribes the input audio into a sequence of words + + The waveforms should already be in the model's desired format. + You can call: + ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)`` + to get a correctly converted signal in most cases. + + Arguments + --------- + predicted_words : list + The raw ASR predictions, fully decoded + best_scores : list + The best scores (from beam search) + best_log_probs : list + The best predicted log-probabilities (from beam search) + + + Returns + ------- + predicted_words : list + The predictions + + best_scores : torch.Tensor + The best scores (from beam search) + + best_log_probs : torch.Tensor + The best log-probabilities + + """ + with torch.no_grad(): + wav_lens = wav_lens.to(self.device) + encoder_out = self.asr.encode_batch(wavs, wav_lens) + ( + hyps, + best_lens, + best_scores, + best_log_probs, + ) = self.asr.mods.decoder(encoder_out, wav_lens) + predicted_words = [ + self.asr.tokenizer.decode_ids(token_seq) for token_seq in hyps + ] + return predicted_words, best_scores, best_log_probs + + def to(self, device): + """Transfers this module to the spcieifed device + + Arguments + --------- + device : str | torch.Device + the target device + """ + self.asr = self.asr.to(device) + return self + + +class WhisperASRSpeechEvaluator(ASRSpeechEvaluator): + """A speech evaluator implementation based on Whisper ASR + + Arguments + --------- + source : str + The source directory + savedir : str, optional + The path where Whisper will be saved + sample_rate: int, optional + The audio sample rate + min_decode_ratio : float, optional + The minimum decode ratio + max_decode_ratio : float, optional + The maximum decode ratio + run_opts : dict, optional + Run options for the Whisper model + unbatch : bool, optional + If enabled, which is the default, the implementation + will evaluate samples one by one with a batch size of + 1 and then "reassemble" the original batch. 
This is + sometimes needed because batched inference has been + found to result in decreased performance, primarily + due to masks not being applied to convolutional layers + """ + + def __init__( + self, + source, + savedir=None, + sample_rate=22050, + min_decode_ratio=0.0, + max_decode_ratio=1.0, + run_opts=None, + unbatch=True, + ): + super().__init__(sample_rate=sample_rate) + if run_opts is None: + run_opts = {} + if savedir is None: + savedir = "." + self.model = Whisper( + source, savedir, sample_rate, freeze=True, freeze_encoder=True, + ) + self.model.tokenizer.set_prefix_tokens("english", "transcribe", False) + self.searcher = S2SWhisperGreedySearcher( + self.model, + min_decode_ratio=min_decode_ratio, + max_decode_ratio=max_decode_ratio, + ) + device = run_opts.get("device", next(self.model.parameters()).device) + self.unbatch = unbatch + self.to(device) + + def evaluate_samples(self, wavs, length, text, sample_rate): + """Evaluates a batch of samples + + Arguments + --------- + wavs : torch.Tensor + A batch of waveforms + length : torch.Tensor + Relative lengths + text : list + Text labels corresponding to the waveforms + sample_rate : int + The sample rate of the waveforms + + Returns + ------- + results : dict + The evaluation results + """ + if self.unbatch: + batch_size = len(wavs) + length_abs = (length * wavs.size(1)).int() + results = [ + self._evaluate_samples( + wavs[idx : idx + 1, : length_abs[idx].item()], + torch.ones(1, device=wavs.device), + text[idx : idx + 1], + sample_rate, + ) + for idx in range(batch_size) + ] + result = { + "wer": torch.stack( + [result["wer"] for result in results] + ).squeeze(-1), + "cer": torch.stack( + [result["cer"] for result in results] + ).squeeze(-1), + "pred": [result["pred"][0] for result in results], + "target": text, + } + return result + else: + return self._evaluate_samples(wavs, length, text, sample_rate) + + def _evaluate_samples(self, wavs, length, text, sample_rate): + """Evaluates a batch of samples. This function is meant + to be used internally. evaluate_samples will call + it multiple times if unbatch is enabled. 
+ + Arguments + --------- + wavs : torch.Tensor + A batch of waveforms + length : torch.Tensor + Relative lengths + text : list + Text labels corresponding to the waveforms + sample_rate : int + The sample rate of the waveforms + + Returns + ------- + results : dict + The evaluation results + """ + if text is None: + raise ValueError("This evaluator requires ground-truth text") + wavs = self.resample(wavs, sample_rate) + wavs = self.model.pad_or_trim(wavs) + mels = self.model.log_mel_spectrogram(wavs) + enc_out = self.model.forward_encoder(mels) + predicted_words, _, _, _ = self.searcher(enc_out.detach(), length) + predicted_words = self.model.tokenizer.batch_decode( + predicted_words, skip_special_tokens=True + ) + predicted_words = [self.normalize(text) for text in predicted_words] + ids = range(1, len(wavs) + 1) + wer_metric, cer_metric = init_asr_metrics() + wer_metric.append(ids, predicted_words, text) + cer_metric.append(ids, predicted_words, text) + wer = torch.tensor( + [score["WER"] for score in wer_metric.scores], device=wavs.device + ) + cer = torch.tensor( + [score["WER"] for score in cer_metric.scores], device=wavs.device + ) + return { + "wer": wer, + "cer": cer, + "pred": predicted_words, + "target": text, + } + + def normalize(self, text): + """Performs text normalization (uppercase, remove whitespace, + remove punctuation) + + Arguments + --------- + text : str + Unnormalized text + + Returns + ------- + text : str + Normalized text + """ + text = text.upper() + text = text.strip() + text = RE_PUNCTUATION.sub("", text) + return text + + def to(self, device): + """Transfers this module to the spcieifed device + + Arguments + --------- + device : str | torch.Device + the target device + """ + self.model = self.model.to(device) + return self + + +def itemize(result): + """Converts a single batch result into per-item results + + Arguments + --------- + result: SpeechEvaluationResult + a single batch result + + Returns + ------- + results: list + a list of individual SpeechEvaluationResult instances""" + + return [ + SpeechEvaluationResult( + score=result.score[idx], + details={key: value[idx] for key, value in result.items()}, + ) + for idx in range(len(result.score)) + ] + + +def init_asr_metrics(): + """Initializes the WER and CER metrics + + Returns + ------- + wer_metric : ErrorRateStats + the Word Error Rate (WER) metric + cer_metric : ErrorRateStats + the Character Error Rate (CER) metric""" + wer_metric = ErrorRateStats() + cer_metric = ErrorRateStats(split_tokens=True) + return wer_metric, cer_metric + + +class BulkSpeechEvaluator: + """A base class for a speech evaluator that is invoked for a series of filesystem files + rather than one batch at a time. 
This is useful for implementing wrappers around + command-line tools that would be impractical to run for each batch because of + long initialization time (to load models, etc)""" + + def evaluate_files(self, file_names, text=None, file_names_ref=None): + """Evaluates multiple files + + Arguments + --------- + file_names : list + A list of files + + text : list, optional + File transcripts (not required for all evaluators) + + file_names_ref : list, optional + A list of reference files / ground truths (if applicable) + + Returns + ------- + result : SpeechEvaluationResult + a consolidated evaluation result + """ + raise NotImplementedError() + + +UTMOS_REPO = "https://huggingface.co/spaces/sarulab-speech/UTMOS-demo" + + +class UTMOSSpeechEvaluator(BulkSpeechEvaluator): + """An evaluation wrapper for UTMOS + + Github: https://github.com/sarulab-speech/UTMOS22 + HuggingFace: https://huggingface.co/spaces/sarulab-speech/UTMOS-demo + + Arguments + --------- + model_path : str | path-like + The path where the HuggingFace repository was extracted + output_folder : str | path-like + The folder where results will be output + ckpt_path : str | path-like + The path to the checkpoint to be used + script : str | path-like + The path to the evaluation script, defaults to the bundled + predict.py + python : str | path-like, optional + The path to the Python interpreter to be used, defaults to + "python". Depending on the environment, it might need to be + changed (e.g. to "python3" or an absolute path to the interpreter) + use_python : bool + Whether to launch the script using python. This flag will need to be + set to False in environments where running UTMOS requires a wrapper shell + script (e.g. to initialize a different Python virtual environment from + the one in which SpeechBrain is running) + tmp_folder : str | path-like, optional + The temporary folder where files will be copied for evaluation. If + omitted, it will be set to output_folder. This can be useful on + compute environments that provide fast local storage (e.g. 
certain + compute clusters) + repo : str + The repor + """ + + def __init__( + self, + model_path, + output_folder, + ckpt_path, + script="predict.py", + python="python", + use_python=True, + batch_size=8, + tmp_folder=None, + repo=UTMOS_REPO, + ): + self.output_folder = Path(output_folder) + rand = torch.randint(1, 999999999, (1,)).item() + if tmp_folder is None: + tmp_folder = self.output_folder + else: + tmp_folder = Path(tmp_folder) + self.eval_path = (tmp_folder / f"eval_{rand}").absolute() + self.model_path = Path(model_path).absolute() + script = self.model_path / script + self.script = script + self.ckpt_path = Path(ckpt_path).absolute() + self.batch_size = batch_size + self.python = python + self.use_python = use_python + self.repo = repo + self.install() + + def install(self): + if self.model_path.exists(): + logger.info("UTMOS is already installed in %s", self.model_path) + return + logger.info( + "Attempting to install UTMOS from %s to %s", + self.repo, + self.model_path, + ) + cmd = shlex.join( + [ + "git", + "-C", + str(self.model_path.parent), + "clone", + self.repo, + str(self.model_path.name), + ] + ) + output, err, return_code = run_shell(cmd) + if return_code != 0: + raise CommandError(cmd, output, err, return_code) + logger.info("Repository clone successful, performing an LFS fetch") + cwd = Path.cwd() + try: + os.chdir(self.model_path) + cmd = shlex.join(["git", "lfs", "fetch"]) + output, err, return_code = run_shell(cmd) + if return_code != 0: + raise CommandError(cmd, output, err, return_code) + finally: + os.chdir(cwd) + if not self.ckpt_path.exists(): + raise ValueError("ckpt_path {ckpt_path} does not exist") + + def evaluate_files(self, file_names, text, file_names_ref=None): + """Evaluates multiple files + + Arguments + --------- + file_names : list + A list of files + + text : list + File transcripts (not required for all evaluators) + Not used in this evaluator + + file_names_ref : list, optional + A list of reference files / ground truths (if applicable) + Not used in this evaluator + + Returns + ------- + result : SpeechEvaluationResult + a consolidated evaluation result + """ + current_path = os.getcwd() + try: + self.eval_path.mkdir(parents=True, exist_ok=True) + logger.info("Copying the files to '%s'", self.eval_path) + for file_name in file_names: + target_file_name = self.eval_path / Path(file_name).name + shutil.copy(file_name, target_file_name) + + logger.info("Running evaluation") + result_path = self.eval_path / "result.txt" + os.chdir(self.model_path) + cmd = [ + str(self.script), + "--mode", + "predict_dir", + "--bs", + str(self.batch_size), + "--inp_dir", + str(self.eval_path), + "--out_path", + str(result_path), + "--ckpt_path", + str(self.ckpt_path), + ] + if self.use_python: + cmd = [self.python] + cmd + + output = subprocess.check_output(cmd) + logger.info("Evaluation finished, output: %s", output) + file_names = [path.name for path in self.eval_path.glob("*.wav")] + with open(result_path) as result_path: + scores = [float(line.strip()) for line in result_path] + score_map = dict(zip(file_names, scores)) + scores_ordered = [ + score_map[Path(file_name).name] for file_name in file_names + ] + return SpeechEvaluationResult( + scores_ordered, {"utmos": scores_ordered} + ) + finally: + os.chdir(current_path) + shutil.rmtree(self.eval_path) + + +def vocoder_to_device(vocoder, device): + """A fix for vocoders that do not properly handle + the .to() function and require the device to be set manually + + Arguments + --------- + vocoder : 
torch.nn.Module + a vocoder + device : str | torch.Device + the target device + """ + if hasattr(vocoder, "model") and hasattr(vocoder.model, "device"): + vocoder.model.device = device + elif hasattr(vocoder, "device"): + vocoder.device = device + + +class Tracker: + """A tracker that makes it possible to resume evaluation + + Arguments + --------- + file_name : str | path-like + The path to the tracker file""" + + def __init__(self, file_name): + self.file_name = Path(file_name) + + def mark_processed(self, item_id): + """Marks the specified file as processed + + Arguments + --------- + item_id : str|enumerable + The item ID or a list of IDS + """ + if isinstance(item_id, str): + item_id = [item_id] + with open(self.file_name, "a+") as tracker_file: + for item in item_id: + print(item, file=tracker_file) + + def filter(self, dataset): + """Filters a dataset using the tracker file + + Arguments + --------- + dataset : speechbrain.dataio.dataset.DynamicItemDataset + A dataset + + Returns + ------- + dataset : speechbrain.dataio.dataset.DynamicItemDataset + The dataset, possibly filtered + """ + if self.file_name.exists(): + with open(self.file_name) as tracker_file: + processed_ids = set(line.strip() for line in tracker_file) + remaining_ids = [ + data_id + for data_id in dataset.data_ids + if data_id not in processed_ids + ] + logger.info( + "Tracker %s already exists, %d items already processed, %d items remaining", + self.file_name, + len(processed_ids), + len(remaining_ids), + ) + dataset = FilteredSortedDynamicItemDataset( + dataset, remaining_ids + ) + else: + logger.info( + "Tracker %s does not exist, evaluating from the beginning" + ) + return dataset + + def get_processed(self): + """Retrieves the IDs of items that have been processed + + Returns + ------- + processed_ids : list + The list of file IDs + """ + if self.file_name.exists(): + with open(self.file_name, "r") as tracker_file: + processed_ids = [line.strip() for line in tracker_file] + else: + processed_ids = [] + return processed_ids + + +class CommandError(Exception): + """Thrown when an external command returns an error + + Arguments + --------- + cmd : str + The command that was run + output : str + The captured standard output stream + err : str + The captured standard error stream + return_code : int + The return code""" + + def __init__(self, cmd, output, err, return_code): + super().__init__( + f"Command {cmd} returned code {return_code}\n" + f"Output: {output}\n" + f"Errors: {err}" + ) + self.cmd = cmd + self.output = output diff --git a/benchmarks/DASB/VoiceBank/enhancement/utils/preparation.py b/benchmarks/DASB/VoiceBank/enhancement/utils/preparation.py new file mode 100644 index 000000000..88e29e22c --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/utils/preparation.py @@ -0,0 +1,470 @@ +"""This file contains utilities for preprocessing of features, particularly +using neural models + +Authors + * Artem Ploujnikov 2023 +""" +import torch +import numpy as np +import math +import speechbrain as sb +import concurrent.futures +import logging +import re +import tarfile +import shutil +import os +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from speechbrain.dataio.dataloader import make_dataloader +from speechbrain.dataio.dataset import DynamicItemDataset +from speechbrain.utils.data_pipeline import DataPipeline +from data import undo_batch, as_dict +from tqdm.auto import tqdm + +logger = logging.getLogger(__name__) + +variable_finder = re.compile(r"\$([\w.]+)") + + +class 
FeatureExtractor: + """A utility class for pipeline-based feature extraction + + Arguments + --------- + save_path: str|path-like + the path where the preprocessed features will be saved + + id_key: str + the key within the batch that will be used as an identifier + + save_format: str|callable + the format in which prepared features will be saved + + device: str|torch.Device + the device on which operations will be run + + dataloader_opts: dict + parameters to be passed to the data loader (batch size, etc) + + dynamic_items : list + Configuration for the dynamic items produced when fetching an example. + List of DynamicItems or dicts with the format:: + func: # To be called + takes: # key or list of keys of args this takes + provides: key # key or list of keys that this provides + + """ + + def __init__( + self, + save_path, + src_keys, + id_key="id", + save_format="npy", + device="cpu", + dataloader_opts=None, + dynamic_items=None, + description=None, + async_save=True, + async_save_batch_size=16, + async_save_concurrency=8, + ): + if not dataloader_opts: + dataloader_opts = {} + self.id_key = id_key + self.save_path = save_path + self.src_keys = src_keys + self.id_key = id_key + self.dataloader_opts = dataloader_opts + if callable(save_format): + self.save_fn = save_format + elif save_format in SAVE_FORMATS: + self.save_fn = SAVE_FORMATS[save_format] + else: + raise ValueError(f"Unsupported save_format: {save_format}") + self.device = device + self.pipeline = DataPipeline( + static_data_keys=src_keys, dynamic_items=dynamic_items or [] + ) + self.async_save = async_save + self._async_save_futures = {} + self.async_save_batch_size = async_save_batch_size + self.async_save_concurrency = async_save_concurrency + self.save_executor = None + self.description = description + + def extract(self, dataset, data=None): + """Runs the preprocessing operation + + Arguments + --------- + dataset : dict|speechbrain.dataio.dataset.DynamicItemDataset + the dataset to be saved + data : dict + the raw data dictionary (to update with extra features) + """ + if isinstance(dataset, dict): + dataset = DynamicItemDataset(dataset) + dataset.set_output_keys(self.src_keys + [self.id_key]) + if self.async_save: + self._init_async_save() + try: + dataloader = make_dataloader(dataset, **self.dataloader_opts) + batch_size = self.dataloader_opts.get("batch_size", 1) + batch_count = int(math.ceil(len(dataset) / batch_size)) + for batch in tqdm( + dataloader, total=batch_count, desc=self.description + ): + batch = batch.to(self.device) + self.process_batch(batch, data) + finally: + if self.async_save: + self._finish_async_save() + + def _init_async_save(self): + self.save_executor = ThreadPoolExecutor( + max_workers=self.async_save_concurrency + ) + + def _finish_async_save(self): + try: + self.flush() + finally: + self.save_executor.shutdown() + self.save_executor = None + + def process_batch(self, batch, data): + """Processes a batch of data + + Arguments + --------- + batch: speechbrain.dataio.batch.PaddedBatch + a batch + data : dict + the raw data dictionary (to update with extra features) + """ + batch_dict = as_dict(batch) + ids = batch_dict[self.id_key] + features = self.pipeline.compute_outputs(batch_dict) + + for idx, (item_id, item_features) in enumerate( + zip(ids, undo_batch(features)), start=1 + ): + self._add_inline_features(item_id, item_features, data) + if self.async_save: + future = self.save_executor.submit( + self.save_fn, + item_id, + item_features, + save_path=self.save_path, + ) + 
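# Keep the future keyed by item ID so that flush() can wait on it and report any background save errors.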
self._async_save_futures[item_id] = future + if idx % self.async_save_batch_size == 0: + self.flush() + else: + self.save_fn(item_id, item_features, save_path=self.save_path) + + def flush(self): + """Flushes all futures that have been accumulated""" + concurrent.futures.wait(self._async_save_futures.values()) + for item_id, future in self._async_save_futures.items(): + exc = future.exception() + if exc is not None: + exc_info = (type(exc), exc, exc.__traceback__) + logger.warn( + "Saving extracted features for %s could not be completed: %s", + item_id, + str(exc), + exc_info=exc_info, + ) + self._async_save_futures.clear() + + def _add_inline_features(self, item_id, item_features, data): + item_data = data.get(item_id) if data is not None else None + for key in self.inline_keys: + if item_data is not None: + item_data[key] = item_features[key] + del item_features[key] + return item_features + + def add_dynamic_item(self, func, takes=None, provides=None): + """Adds a dynamic item to be output + + Two calling conventions. For DynamicItem objects, just use: + add_dynamic_item(dynamic_item). + But otherwise, should use: + add_dynamic_item(func, takes, provides). + + See `speechbrain.utils.data_pipeline`. + + Arguments + --------- + func : callable, DynamicItem + If a DynamicItem is given, adds that directly. Otherwise a + DynamicItem is created, and this specifies the callable to use. If + a generator function is given, then create a GeneratorDynamicItem. + Otherwise creates a normal DynamicItem. + takes : list, str + List of keys. When func is called, each key is resolved to + either an entry in the data or the output of another dynamic_item. + The func is then called with these as positional arguments, + in the same order as specified here. + A single arg can be given directly. + provides : str + Unique key or keys that this provides. 
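+
+        Example
+        -------
+        A minimal sketch; the key names and the tokenizer call are
+        illustrative placeholders rather than fixed parts of this recipe::
+
+            extractor.add_dynamic_item(
+                func=lambda sig, length: codec.sig_to_tokens(sig, length),
+                takes=["sig", "length"],
+                provides="tokens",
+            )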
+ """ + self.pipeline.add_dynamic_item(func, takes, provides) + + def set_output_features(self, keys, inline_keys=None): + """Sets the features to be output + + Arguments + --------- + keys : list + Keys to be output / saved + inline_keys : list, optional + The keys to be used inline (added to the data dictionary + rather than saved in flies)""" + self.inline_keys = inline_keys or [] + self.pipeline.set_output_keys(keys + self.inline_keys) + + +def save_pt(item_id, data, save_path): + """Saves the data in the PyTorch format (one file per sample) + + Arguments + --------- + item_id: str + the ID of the item to be saved + + data: dict + the data to be saved + + save_path: path-like + the destination path + """ + file_path = save_path / f"{item_id}.pt" + torch.save(data, file_path) + + +def save_npy(item_id, data, save_path): + """Saves the data in numpy format (one file per sample per feature) + + Arguments + --------- + item_id: str + the ID of the item to be saved + + data: dict + the data to be saved + + save_path: path-like + the destination path + """ + for key, value in data.items(): + file_path = save_path / f"{key}_{item_id}.npy" + np.save(file_path, value.detach().cpu().numpy()) + + +def load_pt(save_path, item_id, features): + """Loads a PyTorch pickled file + + Arguments + --------- + save_path : path-like + The storage path + item_id : object + The item identifier + features : enumerable + Not used + + Returns + ------- + result : object + the contents of the file + """ + file_path = save_path / f"{item_id}.pt" + return torch.load(file_path) + + +def load_npy(save_path, item_id, features): + """Loads a raw NumPy array + + Arguments + --------- + save_path : path-like + The storage path + item_id : object + The item identifier + features : enumerable + The features to be loaded + """ + return { + key: np.load(save_path / f"{key}_{item_id}.npy") for key in features + } + + +SAVE_FORMATS = { + "pt": save_pt, + "npy": save_npy, +} + +LOAD_FORMATS = { + "pt": load_pt, + "npy": load_npy, +} + + +def add_prepared_features( + dataset, save_path, features, id_key="id", save_format="npy" +): + """Adds prepared features to a pipeline + + Arguments + --------- + dataset : speechbrains.dataio.dataset.DynamicItemDataset + a dataset + save_path : str|path-like + the path where prepared features are saved + features : list + the list of features to be added + id_key : str + the ID of the pipeline elements used as the item ID + save_format : str | callable + One of the known formats (pt or npy) or a custom + function to load prepared features for a data sample""" + load_fn = LOAD_FORMATS.get(save_format, save_format) + save_path = Path(save_path) + + @sb.utils.data_pipeline.takes(id_key) + @sb.utils.data_pipeline.provides(*features) + def prepared_features_pipeline(item_id): + """A pipeline function that provides the features defined with + registered loaders + + Arguments + --------- + item_id : object + The item dentifier + + Returns + ------- + result : generator + The features + """ + data = load_fn(save_path, item_id, features) + for feature in features: + yield data[feature] + + dataset.add_dynamic_item(prepared_features_pipeline) + + +DEFAULT_PATTERNS = ["*.csv", "*.json", "features", "*_prepare.pkl"] + + +class Freezer: + """A utility class that helps archive and restore prepared + data. 
This is particularly useful on compute clusters where + preparation needs to be done on non-permanent storage + + Arguments + --------- + save_path : str|path-like + the path where prepared data is saved + archive_path : str|path-like + the path to the archive + patterns : enumerable + a list of glob patterns with prepared files + """ + + def __init__(self, save_path, archive_path, patterns=None): + self.save_path = Path(save_path) + self.archive_path = Path(archive_path) if archive_path else None + self.patterns = patterns or DEFAULT_PATTERNS + + def freeze(self): + """Archives pretrained files""" + if self.archive_path is None: + logger.info("Prepared data archiving is unavailable") + return + if self.archive_path.exists(): + logger.info( + "The prepared dataset has already been archived in %s", + self.archive_path, + ) + return + file_names = self.get_files() + logger.info( + "Archiving %d files from the prepared dataset in %s", + len(file_names), + self.archive_path, + ) + mode = self._get_archive_mode("w") + tmp_archive_path = self.save_path / self.archive_path.name + logger.info("Creating a temporary archive: %s", tmp_archive_path) + with tarfile.open(tmp_archive_path, mode) as tar_file: + for file_name in file_names: + tar_file.add( + name=file_name, + arcname=file_name.relative_to(self.save_path), + ) + logger.info("Copying %s to %s", tmp_archive_path, self.archive_path) + shutil.copy(tmp_archive_path, self.archive_path) + logger.info("Done copying, removing %s", tmp_archive_path) + os.remove(tmp_archive_path) + + def _get_archive_mode(self, mode): + """Adds a suffix to the archive mode""" + if self.archive_path.name.endswith(".gz"): + mode = f"{mode}:gz" + return mode + + def unfreeze(self): + """Unarchives pretrained files into save_path + + Returns + ------- + result: bool + True if the archive exists and has been unpacked, + False otherwise.""" + if self.archive_path is None: + logger.info("Prepared dataset freezing is disabled") + result = False + elif self.archive_path.exists(): + logger.info( + "Unpacking prepared dataset %s into %s", + self.archive_path, + self.save_path, + ) + mode = self._get_archive_mode("r") + with tarfile.open(self.archive_path, mode) as tar_file: + tar_file.extractall(self.save_path) + logger.info("Prepared dataset unpacked") + result = True + else: + logger.info( + "No frozen prepared dataset exists in %s", self.archive_path + ) + result = False + return result + + def get_files(self): + """Returns the list of prepared files available + to be archived + + Returns + ------- + result: list + A list of file names""" + return [ + file_name + for pattern in self.patterns + for file_name in self.save_path.glob(pattern) + ] + + def __enter__(self): + self.unfreeze() + + def __exit__(self, exc_type, exc_value, traceback): + self.freeze() diff --git a/benchmarks/DASB/VoiceBank/enhancement/utils/tokenizer_interface.py b/benchmarks/DASB/VoiceBank/enhancement/utils/tokenizer_interface.py new file mode 100644 index 000000000..be73fda74 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/utils/tokenizer_interface.py @@ -0,0 +1,515 @@ +""" +Unified interface for tokenizers, standardizing the output shape of encode and decode functions. + +This class reshapes the outputs of various tokenizers to ensure consistency, simplifying integration with recipes and workflows. 
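+
+Example
+-------
+A minimal sketch of the intended usage (the EnCodec checkpoint, the save
+directory, and the ``signal``/``lengths`` tensors below are illustrative)::
+
+    tokenizer = EncodecTokenizer("facebook/encodec_24khz", "savedir")
+    tokens = tokenizer.sig_to_tokens(signal, lengths)  # [B, N, K]
+    rec = tokenizer.tokens_to_sig(tokens)  # [B, T]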
+ +Authors +--------- +* Pooneh Mousavi, 2024 +""" +import sys +import os +import torch +from abc import ABC, abstractmethod +from speechbrain.lobes.models.huggingface_transformers.encodec import Encodec +from speechbrain.lobes.models.huggingface_transformers.discrete_ssl import ( + DiscreteSSL, +) +from speechbrain.lobes.models.discrete.dac import DAC +from speechbrain.lobes.models.discrete.speechtokenizer import SpeechTokenizer +from speechbrain.lobes.models.discrete.wavtokenizer import WavTokenizer +from speechbrain.lobes.models.huggingface_transformers.mimi import Mimi + +base_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..") +) # noqa: E402 +sys.path.append(base_dir) # noqa: E402 + +from model.sq_codec import SQCodec # noqa: E402 + + +class BaseTokenizer(ABC): + """ + Abstract base class for tokenizers that encode signals into discrete tokens + and decode tokens back into signals. + + This class defines the essential methods that any tokenizer must implement, + including encoding, decoding, and retrieving pretrained embeddings. + + Naming Convenstion + ------------------ + B : int + Batch size. + T : int + Sequence length in the time domain. + N : int + Sequence length in the token domain. + C : int + Vocabulary size, assuming each codebook has the same number of tokens. + K : int + Number of codebooks. + """ + + def __init__(self): + """ + Initialize the BaseTokenizer. + + This is a base constructor that other tokenizers can extend. + """ + super().__init__() + + @abstractmethod + @torch.no_grad() + def sig_to_tokens(self, signal, lengths=None, num_codebooks=None, **kwargs): + """ + Encode a signal into discrete tokens. + + Arguments + --------- + signal : torch.Tensor + Input signal with shape [B, T]. + lengths : torch.Tensor + Lengths of each sequence in the batch, with shape [B]. + num_codebooks : int, optional + Number of codebooks to use for encoding. If None, all codebooks are used (default: None). + If specified as an int, the tokens will be truncated to include only the first `num_codebooks` codebooks. If specified as a list, + the tokens will include only the codebooks at the specified indices. + **kwargs : dict + Additional arguments for the tokenizer. + + Returns + ------- + tokens : torch.Tensor + Discretized tokens with shape [B, N, K]. + """ + pass + + @abstractmethod + @torch.no_grad() + def tokens_to_sig(self, tokens, **kwargs): + """ + Decode discrete tokens back into a signal. + + Arguments + --------- + tokens : torch.Tensor + Input tokens with shape [B, N, K]. + **kwargs : dict + Additional arguments for the tokenizer. + + Returns + ------- + signal : torch.Tensor + Reconstructed signal with shape [B, T]. + """ + pass + + @abstractmethod + @torch.no_grad() + def get_pretrained_embeddings(self, vocab_size, num_codebooks, **kwargs): + """ + Retrieve pretrained embeddings for the tokenizer. + + Arguments + --------- + vocab_size : int + Number of tokens in each codebook. + num_codebooks : int + Number of codebooks. + **kwargs : dict + Additional arguments for embedding retrieval. + + Returns + ------- + embeddings : torch.Tensor + Pretrained embedding weights with shape [K * C, H], where H is the embedding dimension. + """ + pass + + +class EncodecTokenizer(Encodec, BaseTokenizer): + """This is a wrapper for the Encodec implemented in the SpeechBrain main repository. 
+ + Source paper: + https://arxiv.org/abs/2210.13438 + Example + ------- + >>> model_hub = "facebook/encodec_24khz" + >>> save_path = "savedir" + >>> model = EncodecTokenizer(model_hub, save_path) + >>> emb=model.get_pretrained_embeddings() + >>> emb.shape + torch.Size([2048, 128]) + >>> audio = torch.randn(4, 1000) + >>> length = torch.tensor([1.0, .5, .75, 1.0]) + >>> tokens= model.sig_to_tokens(audio, length) + >>> tokens.shape + torch.Size([4, 4, 2]) + >>> rec = model.tokens_to_sig(tokens, lenght=length) + >>> rec.shape + torch.Size([4, 1280] + """ + + def __init__(self, *args, **kwargs): + Encodec.__init__(self, *args, **kwargs) + BaseTokenizer.__init__(self) + + @torch.no_grad() + def sig_to_tokens(self, signal, lengths=None, num_codebooks=None, **kwargs): + self.eval() + tokens, _ = self.encode(signal, lengths) + if num_codebooks: + if tokens.shape[-1] < num_codebooks: + raise ValueError( + f"Model only outputs {tokens.shape[-1]} codebooks, but {num_codebooks} requested" + ) + tokens = tokens[..., :num_codebooks] + return tokens + + @torch.no_grad() + def tokens_to_sig(self, tokens, **kwargs): + self.eval() + signal = self.decode(tokens)[:, 0] + return signal + + @torch.no_grad() + def get_pretrained_embeddings( + self, vocab_size=None, num_codebooks=None, **kwargs + ): + embeddings = self.vocabulary + return embeddings.reshape(-1, embeddings.shape[-1]) + + +class DACTokenizer(DAC, BaseTokenizer): + """This is a wrapper for the DAC implemented in the SpeechBrain main repository. + + Source paper: + http://arxiv.org/abs/2306.06546 + Example + ------- + >>> model = DACTokenizer(load_pretrained=True, model_type="24KHz", model_bitrate="8kbps", tag="latest") + >>> audio = torch.randn(4, 16000) + >>> emb=model.get_pretrained_embeddings(vocab_size=1024, num_codebooks=8) + >>> emb.shape + torch.Size([8192, 1024]) + >>> tokens= model.sig_to_tokens(audio) + >>> tokens.shape + torch.Size([4, 50, 32]) + >>> rec = model.tokens_to_sig(tokens) + >>> rec.shape + torch.Size([4, 15992]) + """ + + def __init__(self, *args, **kwargs): + DAC.__init__(self, *args, **kwargs) + BaseTokenizer.__init__(self) + + @torch.no_grad() + def sig_to_tokens(self, signal, lengths=None, num_codebooks=None, **kwargs): + self.eval() + tokens, _ = self(signal[:, None], n_quantizers=num_codebooks) + return tokens.movedim(-1, -2) + + @torch.no_grad() + def tokens_to_sig(self, tokens, **kwargs): + self.eval() + quantized_feats, _, _ = self.quantizer.from_codes( + tokens.movedim(-1, -2) + ) + return self.decode(quantized_feats)[:, 0] + + @torch.no_grad() + def get_pretrained_embeddings( + self, vocab_size=None, num_codebooks=None, **kwargs + ): + toks = torch.arange(vocab_size).to(next(self.parameters()).device) + toks = toks[:, None, None].expand(-1, num_codebooks, -1).clone() + self.eval() + z_q, z_p, _ = self.quantizer.from_codes(toks) + z_ps = z_p.split(z_p.shape[1] // toks.shape[1], dim=1) + z_qs = [ + self.quantizer.quantizers[i].out_proj(z_p_i) + for i, z_p_i in enumerate(z_ps) + ] + return torch.cat(z_qs)[:, :, 0] + + +class SpeechTokenizerWrapper(SpeechTokenizer, BaseTokenizer): + """This is a wrapper for the SpeechTokenizer implemented in the SpeechBrain main repository. 
+ + Source paper: + https://arxiv.org/abs/2308.16692 + Example + ------- + >>> audio = torch.rand([10, 600]) + >>> model_hub = "fnlp/SpeechTokenizer" + >>> save_path = "savedir" + >>> model = SpeechTokenizerWrapper(model_hub, save_path) + >>> emb=model.get_pretrained_embeddings(vocab_size=1024, num_codebooks=8) + >>> emb.shape + torch.Size([8192, 1024]) + >>> tokens= model.sig_to_tokens(audio) + >>> tokens.shape + torch.Size([10, 2, 8]) + >>> rec = model.tokens_to_sig(tokens) + >>> rec.shape + torch.Size([10, 640]) + """ + + def __init__(self, *args, **kwargs): + SpeechTokenizer.__init__(self, *args, **kwargs) + BaseTokenizer.__init__(self) + + @torch.no_grad() + def sig_to_tokens(self, signal, lengths=None, num_codebooks=None, **kwargs): + self.eval() + tokens = self(signal) + if num_codebooks: + if len(tokens) < num_codebooks: + raise ValueError( + f"Model only outputs {len(tokens)} codebooks, but {num_codebooks} requested" + ) + tokens = tokens[:num_codebooks] + return tokens.movedim(-3, -1) + + @torch.no_grad() + def tokens_to_sig(self, tokens, **kwargs): + self.eval() + return self.decode(tokens.movedim(-1, -3)) + + @torch.no_grad() + def get_pretrained_embeddings( + self, vocab_size=None, num_codebooks=None, **kwargs + ): + toks = torch.arange(vocab_size).to(next(self.parameters()).device) + toks = toks[None, :, None].expand(num_codebooks, -1, -1).clone() + self.eval() + embs = [ + self.model.quantizer.vq.layers[i].decode(indices) + for i, indices in enumerate(toks) + ] + return torch.cat(embs)[:, :, 0] + + +class DiscreteSSLTokenizer(DiscreteSSL, BaseTokenizer): + """This is a wrapper for the Encodec implemented in the SpeechBrain main repository. + + Source paper: + https://arxiv.org/abs/2210.13438 + Example + ------- + >>> from speechbrain.lobes.models.huggingface_transformers.wavlm import (WavLM) + >>> inputs = torch.rand([3, 2000]) + >>> model_hub = "microsoft/wavlm-large" + >>> save_path = "savedir" + >>> ssl_layer_num = [7,23] + >>> deduplicate =[False, True] + >>> bpe_tokenizers=[None, None] + >>> vocoder_repo_id = "speechbrain/hifigan-wavlm-k1000-LibriTTS" + >>> kmeans_dataset = "LibriSpeech" + >>> num_clusters = 1000 + >>> ssl_model = WavLM(model_hub, save_path,output_all_hiddens=True) + >>> model = DiscreteSSLTokenizer(save_path, ssl_model, vocoder_repo_id=vocoder_repo_id, kmeans_dataset=kmeans_dataset,num_clusters=num_clusters) + >>> emb=model.get_pretrained_embeddings(num_codebooks=ssl_layer_num) + >>> emb.shape + torch.Size([2000, 1024]) + >>> tokens= model.sig_to_tokens(inputs,num_codebooks=ssl_layer_num, deduplicates=deduplicate, bpe_tokenizers=bpe_tokenizers) + >>> tokens.shape + torch.Size([3, 6, 2]) + >>> sig = model.tokens_to_sig(tokens, SSL_layers=ssl_layer_num) + >>> sig.shape + torch.Size([3, 1920]) + """ + + def __init__(self, *args, **kwargs): + DiscreteSSL.__init__(self, *args, **kwargs) + BaseTokenizer.__init__(self) + + @torch.no_grad() + def sig_to_tokens(self, signal, lengths=None, num_codebooks=None, **kwargs): + self.eval() + tokens, _, _ = self.encode( + signal, lengths, SSL_layers=num_codebooks, **kwargs + ) + return tokens + + @torch.no_grad() + def tokens_to_sig(self, tokens, **kwargs): + self.eval() + return self.decode(tokens, **kwargs).squeeze(1) + + @torch.no_grad() + def get_pretrained_embeddings( + self, vocab_size=None, num_codebooks=None, **kwargs + ): + embs = [] + for layer_num, vocabulary in zip( + self.ssl_layer_ids, self.vocabularies, + ): + if layer_num not in num_codebooks: + continue + embs.append(torch.as_tensor(vocabulary, 
dtype=torch.float32)) + embs = torch.cat(embs) + return embs + + +class MimiTokenizer(Mimi, BaseTokenizer): + """This is a wrapper for the Mimi implemented in the SpeechBrain main repository. + + Source paper: + https://kyutai.org/Moshi.pdf + Example + ------- + >>> model_hub = "kyutai/mimi" + >>> save_path = "savedir" + >>> model = MimiTokenizer(model_hub, save_path) + >>> emb=model.get_pretrained_embeddings() + >>> emb.shape + torch.Size([16384, 256]) + >>> audio = torch.randn(4, 48000) + >>> length = torch.tensor([1.0, .5, .75, 1.0]) + >>> tokens = model.sig_to_tokens(audio, length) + >>> tokens.shape + torch.Size([4, 25, 8]) + >>> rec = model.tokens_to_sig(tokens, length=length) + >>> rec.shape + torch.Size([4, 48000]) + """ + + def __init__(self, *args, **kwargs): + Mimi.__init__(self, *args, **kwargs) + BaseTokenizer.__init__(self) + + @torch.no_grad() + def sig_to_tokens(self, signal, lengths=None, num_codebooks=None, **kwargs): + self.eval() + tokens, _ = self.encode(signal, lengths) + if num_codebooks: + if tokens.shape[1] < num_codebooks: + raise ValueError( + f"Model only outputs {tokens.shape[1]} codebooks, but {num_codebooks} requested" + ) + tokens = tokens[:, :num_codebooks, :] + return tokens.movedim(-1, -2) + + @torch.no_grad() + def tokens_to_sig(self, tokens, **kwargs): + self.eval() + signal = self.decode(tokens.movedim(-1, -2), **kwargs)[:, 0] + return signal + + @torch.no_grad() + def get_pretrained_embeddings( + self, vocab_size=None, num_codebooks=None, **kwargs + ): + return self.embeddings.view(-1, self.embeddings.size(-1)) + + +class WavTokenizerWrapper(WavTokenizer, BaseTokenizer): + """This is a wrapper for the WavTokenizer implemented in the SpeechBrain main repository. + + Source paper: + https://arxiv.org/abs/2408.16532 + + Example + ------- + >>> model_hub = "novateur/WavTokenizer" + >>> save_path = "savedir" + >>> config="wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml" + >>> checkpoint="WavTokenizer_small_600_24k_4096.ckpt" + >>> model = WavTokenizerWrapper(model_hub, save_path,config=config,checkpoint=checkpoint) + >>> emb=model.get_pretrained_embeddings() + >>> emb.shape + torch.Size([4096, 512]) + >>> audio = torch.randn(4, 48000) + >>> length = torch.tensor([1.0, .5, .75, 1.0]) + >>> tokens= model.sig_to_tokens(audio, length) + >>> tokens.shape + torch.Size([4, 80, 1]) + >>> rec = model.tokens_to_sig(tokens) + >>> rec.shape + torch.Size([4, 48000]) + """ + + def __init__(self, *args, **kwargs): + WavTokenizer.__init__(self, *args, **kwargs) + BaseTokenizer.__init__(self) + + @torch.no_grad() + def sig_to_tokens(self, signal, lengths=None, num_codebooks=None, **kwargs): + self.eval() + tokens, _ = self.encode(signal) + if num_codebooks: + if tokens.shape[1] < num_codebooks: + raise ValueError( + f"Model only outputs {tokens.shape[1]} codebooks, but {num_codebooks} requested" + ) + tokens = tokens[:, :num_codebooks, :] + + return tokens.movedim(-2, -1) + + @torch.no_grad() + def tokens_to_sig(self, tokens, **kwargs): + self.eval() + signal = self.decode(tokens.movedim(-1, -2)) + return signal + + @torch.no_grad() + def get_pretrained_embeddings( + self, vocab_size=None, num_codebooks=None, **kwargs + ): + return self.embeddings + + +class SQCodecTokenizer(SQCodec, BaseTokenizer): + """This is a wrapper for the SQCoced implemented in the model folder. 
+ + Source paper: + https://arxiv.org/abs/2406.02328, https://arxiv.org/abs/2408.13893 + + + Make sure that you download and extract the SQ-codec.zip in save_path from following Huggingface repo: + - HF repo: https://huggingface.co/Dongchao/UniAudio/blob/main/SQ-Codec.zip + + Example + ------- + >>> save_path = "savedir" + >>> config = "config.yaml" + >>> checkpoint = "ckpt_00190000.pth" + >>> model = SQCodecTokenizer(save_path, config, checkpoint) + >>> audio = torch.randn(3, 48000) + >>> tokens = model.sig_to_tokens(audio) + >>> tokens.shape + torch.Size([3, 150, 4]) + >>> rec = model.tokens_to_sig(tokens) + >>> rec.shape + torch.Size([3, 48000] + """ + + def __init__(self, *args, **kwargs): + SQCodec.__init__(self, *args, **kwargs) + BaseTokenizer.__init__(self) + + @torch.no_grad() + def sig_to_tokens(self, signal, lengths=None, num_codebooks=None, **kwargs): + self.eval() + tokens, _ = self.encode(signal) + return tokens.view(tokens.shape[0], -1, self.n_codebook) + + @torch.no_grad() + def tokens_to_sig(self, tokens, **kwargs): + self.eval() + signal = self.decode(tokens.view(tokens.shape[0], -1), **kwargs) + return signal.squeeze(1) + + @torch.no_grad() + def get_pretrained_embeddings( + self, vocab_size=None, num_codebooks=None, **kwargs + ): + """ + This method is not implemented for SQCodec, as it uses scalar quantization + and does not have any trainable quantizer or embedding. + """ + raise ValueError( + "SQCodec does not have any trainable quantizer or embedding since it uses scalar quantization." + ) diff --git a/benchmarks/DASB/VoiceBank/enhancement/utils/tokens.py b/benchmarks/DASB/VoiceBank/enhancement/utils/tokens.py new file mode 100644 index 000000000..03ea5049c --- /dev/null +++ b/benchmarks/DASB/VoiceBank/enhancement/utils/tokens.py @@ -0,0 +1,411 @@ +""" +Unified interface for token extraction and pretrained embeddings handling for speech tokenizers. + +Authors +--------- +* Jarod Duret, 2024 +""" + +import math +import logging +import pathlib as pl +import kaldiio +import torch +import torchaudio +import numpy as np +from tqdm.auto import tqdm +import speechbrain as sb +from speechbrain.dataio.dataloader import make_dataloader +from speechbrain.dataio.dataset import DynamicItemDataset +from speechbrain.dataio.dataio import load_pkl, save_pkl + + +logger = logging.getLogger(__name__) +OPT_FILE = "opt_extract.pkl" + + +def get_device(use_cuda): + logger.info("=" * 30) + logger.info(f"USE_CUDA SET TO: {use_cuda}") + logger.info(f"CUDA AVAILABLE?: {torch.cuda.is_available()}") + logger.info("=" * 30) + use_cuda = use_cuda and torch.cuda.is_available() + return torch.device("cuda" if use_cuda else "cpu") + + +class TokensExtractor: + """ + Extracts tokens from audio data using a tokenizer and saves them to a specified format. + + Arguments + --------- + tokenizer : torch.nn.Module + The tokenizer model to use for token extraction. + sample_rate : int + The sample rate of the audio data. + src_key : str, optional + The key in the dataset that contains the audio data (default: "wav"). + id_key : str, optional + The key in the dataset that contains unique identifiers (default: "id"). + save_format : str, optional + The format to save the tokens ('numpy', 'pickle', 'soundfile_flac') (default: "numpy"). + use_cuda : bool, optional + Whether to use CUDA for computation (default: True). + dataloader_opts : dict, optional + Options for the data loader (default: None). + + Raises + ------ + ValueError + If an unsupported save_format is provided. 
+ ValueError + If the tokenizer's sample rate does not match the provided sample_rate. + """ + + def __init__( + self, + tokenizer, + sample_rate, + src_key="wav", + id_key="id", + save_format="numpy", + use_cuda=True, + dataloader_opts=None, + ): + self.id_key = id_key + self.src_key = src_key + + self.device = get_device(use_cuda) + self.tokenizer = tokenizer.to(self.device) + self.sample_rate = sample_rate + + if tokenizer.sample_rate != self.sample_rate: + raise ValueError( + f"Sample rate mismatch: {self.sample_rate} != {tokenizer.sample_rate}" + ) + + if save_format not in ["numpy", "pickle", "soundfile_flac"]: + raise ValueError(f"Unsupported save_format: {save_format}") + self.save_format = save_format + + if not dataloader_opts: + dataloader_opts = {} + self.dataloader_opts = dataloader_opts + self.pipelines = self._make_pipelines() + + def extract_tokens( + self, dataset, num_codebooks, save_path, save_name="tokens" + ): + """ + Extracts tokens from the dataset and saves them to the specified format. + + Arguments + --------- + dataset : speechbrain.dataio.dataset.DynamicItemDataset or dict + The dataset from which to extract tokens. Can be a DynamicItemDataset or a dictionary. + num_codebooks: int + The number of codebooks to retrieve from the tokens. + save_path: str + The path where tokens will be saved. + save_name: str + The name of the .scp and .ark files. + """ + conf = { + "sample_rate": self.sample_rate, + "save_folder": save_path, + "dataset_length": len(dataset), + } + + save_path = pl.Path(save_path).absolute() + save_path.mkdir(parents=True, exist_ok=True) + + # Check if the extraction is already done (if so, skip it) + if _skip(save_path, save_name, conf): + logger.info("Skipping extraction, completed in previous run.") + return + + self.wspecifier = ( + f"ark,scp,t:{save_path}/{save_name}.ark,{save_path}/{save_name}.scp" + ) + self.writer = kaldiio.WriteHelper( + self.wspecifier, write_function="numpy" + ) + + if isinstance(dataset, dict): + dataset = DynamicItemDataset(dataset) + dataset.set_output_keys([self.src_key, self.id_key, "sig"]) + for pipeline in self.pipelines: + dataset.add_dynamic_item(pipeline) + + dataloader = make_dataloader(dataset, **self.dataloader_opts) + batch_size = self.dataloader_opts.get("batch_size", 1) + batch_count = int(math.ceil(len(dataset) / batch_size)) + for batch in tqdm(dataloader, total=batch_count): + batch = batch.to(self.device) + x, x_lengths = batch["sig"] + ids = batch[self.id_key] + batch_tokens = self.tokenizer.sig_to_tokens( + x, x_lengths, num_codebooks=num_codebooks + ) + batch_tokens = sb.utils.data_utils.undo_padding( + batch_tokens, x_lengths + ) + self.process_batch(batch_tokens, ids) + + logger.info("Extraction completed.") + + save_opt = save_path / OPT_FILE + save_pkl(conf, save_opt.as_posix()) + + def process_batch(self, batch, ids): + """ + Processes a batch of tokens and writes them to the output files. + + Arguments + --------- + batch : list + A list of tokens for each item in the batch. + ids : list + A list of unique identifiers corresponding to each item in the batch. + """ + for tokens, utt_id in zip(batch, ids): + tokens = np.array(tokens) + self.writer(utt_id, tokens) + + def _make_pipelines(self): + """ + Creates the data processing pipeline for audio data. + + The pipeline reads audio files, resamples them to the desired sample rate, and provides + the processed signal under the key "sig". + + Returns + ------- + pipeline : list + A list containing the audio processing pipeline function. 
+ """ + + @sb.utils.data_pipeline.takes(self.src_key) + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + info = torchaudio.info(wav) + sig = sb.dataio.dataio.read_audio(wav) + sig = torchaudio.transforms.Resample( + info.sample_rate, self.sample_rate, + )(sig) + return sig + + return [audio_pipeline] + + def save_pretrained_embeddings( + self, + save_path, + save_name="embeddings", + vocab_size=None, + num_codebooks=None, + ): + """ + Saves the pretrained embeddings of the tokenizer to a specified directory. + + This method retrieves the pretrained embeddings from the tokenizer, + converts them to a NumPy array, and saves them as a `.npy` file. + + Parameters + ---------- + save_path : str or pathlib.Path + The directory where the pretrained embeddings will be saved. + If the directory does not exist, it will be created. + save_name : str, optional + The base name of the saved embeddings file (default is "embeddings"). + The embeddings will be saved as `.npy` in the specified directory. + """ + save_path = pl.Path(save_path).absolute() + save_path.mkdir(parents=True, exist_ok=True) + + embeddings = self.tokenizer.get_pretrained_embeddings( + vocab_size, num_codebooks + ) + embeddings = embeddings.cpu().numpy() + np.save(save_path / save_name, embeddings) + + def __del__(self): + """ + Close the writer. + """ + self.writer.close() + + +def _skip(save_path, save_name, conf): + """ + Detects if the dataset extraction has been already done. + If the extraction has been done, we can skip it. + + Arguments + --------- + save_path : str + The path to the directory containing extracted tokens. + save_name : str + The base name of the saved tokens file. + conf : dict + Configuration to match against saved config. + + Returns + ------- + bool + if True, the preparation phase can be skipped. + if False, it must be done. + """ + skip = True + + # Checking ark,scp files + for ext in [".ark", ".scp"]: + save_file = save_path / f"{save_name}{ext}" + if not save_file.exists: + skip = False + + # Checking saved options + save_opt = save_path / OPT_FILE + if skip is True: + if save_opt.exists(): + opts_old = load_pkl(save_opt.as_posix()) + if opts_old == conf: + skip = True + else: + skip = False + else: + skip = False + return skip + + +class TokensLoader: + """ + A loader class for retrieving tokens corresponding to utterance IDs. + + Arguments + --------- + data_path: str + The path to the data directory containing the token files. + save_name: str, optional + The base name of the tokens files (default: "tokens"). + """ + + def __init__( + self, data_path, save_name="tokens", + ): + self.data_path = pl.Path(data_path) + if not self.data_path.exists(): + raise ValueError( + f"Data folder not found: {self.data_path.as_posix()}" + ) + self.tokens = self._load(data_path, save_name) + + def tokens_by_uttid(self, utt_id, num_codebooks=None): + """ + Retrieves the tokens corresponding to a given utterance ID. + + Arguments + --------- + utt_id : str + The utterance ID to retrieve tokens for. + num_codebooks : int or list, optional + The number of codebooks to retrieve from the tokens. If specified as an int, the tokens + will be truncated to include only the first `num_codebooks` codebooks. If specified as a list, + the tokens will include only the codebooks at the specified indices. If not specified, all codebooks are returned. + + Returns + ------- + result : torch.LongTensor [T, N_Q] + The tokens associated with the utterance ID, possibly truncated to `num_codebooks` codebooks. 
+
+        Raises
+        ------
+        KeyError
+            If the utterance ID is not found in the tokens.
+        ValueError
+            If `num_codebooks` is invalid or exceeds the number of available codebooks.
+        """
+        if utt_id not in self.tokens:
+            raise KeyError(f"Utterance ID '{utt_id}' not found in tokens.")
+        tokens_path = self.tokens[utt_id]
+        tokens = kaldiio.load_mat(tokens_path)
+        tokens = torch.from_numpy(tokens).long()
+
+        if num_codebooks is not None:
+            if isinstance(num_codebooks, int):
+                if num_codebooks <= 0:
+                    raise ValueError(
+                        f"Invalid num_codebooks value: {num_codebooks}. It must be a positive integer."
+                    )
+                if num_codebooks > tokens.size(-1):
+                    raise ValueError(
+                        f"Invalid number of codebooks: {num_codebooks}. "
+                        f"Available codebooks: {tokens.size(-1)}."
+                    )
+                tokens = tokens[:, :num_codebooks]
+            elif isinstance(num_codebooks, list):
+                if not all(
+                    isinstance(idx, int) and 0 <= idx < tokens.size(-1)
+                    for idx in num_codebooks
+                ):
+                    raise ValueError(
+                        f"Invalid indices in num_codebooks list: {num_codebooks}. "
+                        f"All indices must be integers within the range [0, {tokens.size(-1) - 1}]."
+                    )
+                tokens = tokens[:, num_codebooks]
+            else:
+                raise ValueError("num_codebooks must be an int or a list.")
+
+        return tokens
+
+    def _load(self, data_path, save_name):
+        """
+        Loads the mapping from utterance IDs to token file paths.
+
+        Arguments
+        ---------
+        data_path: str
+            The path to the data directory containing the token files.
+        save_name: str
+            The base name of the tokens files.
+
+        Returns
+        -------
+        utt2toks: dict
+            A dictionary mapping utterance IDs to their corresponding token file paths.
+        """
+        scp_path = f"{data_path}/{save_name}.scp"
+        with open(scp_path, "r") as f:
+            utt2toks = {
+                line.strip().split(None, 1)[0]: line.strip().split(None, 1)[1]
+                for line in f
+                if line.strip()
+            }
+        return utt2toks
+
+    def load_pretrained_embeddings(self, data_path, save_name="embeddings"):
+        """
+        Loads pretrained embeddings from a specified path.
+
+        Arguments
+        ---------
+        data_path : str
+            The directory where the embeddings are saved.
+        save_name : str, optional
+            The name of the embeddings file (default: "embeddings").
+
+        Returns
+        -------
+        embeddings : torch.Tensor
+            The loaded embeddings as a PyTorch tensor.
+
+        Raises
+        ------
+        FileNotFoundError
+            If the embeddings file does not exist at the specified path.
+        """
+        data_path = pl.Path(data_path).absolute()
+        if not data_path.exists():
+            raise FileNotFoundError(
+                f"Data folder not found: {data_path.as_posix()}"
+            )
+        embeddings = np.load(data_path / f"{save_name}.npy")
+        embeddings = torch.from_numpy(embeddings)
+        return embeddings
diff --git a/benchmarks/DASB/VoiceBank/enhancement/voicebank_prepare.py b/benchmarks/DASB/VoiceBank/enhancement/voicebank_prepare.py
new file mode 120000
index 000000000..fe458de38
--- /dev/null
+++ b/benchmarks/DASB/VoiceBank/enhancement/voicebank_prepare.py
@@ -0,0 +1 @@
+../voicebank_prepare.py
\ No newline at end of file
diff --git a/benchmarks/DASB/VoiceBank/extraction/extract.py b/benchmarks/DASB/VoiceBank/extraction/extract.py
new file mode 100644
index 000000000..8f88ad070
--- /dev/null
+++ b/benchmarks/DASB/VoiceBank/extraction/extract.py
@@ -0,0 +1,105 @@
+#!/usr/bin/env python3
+
+"""Recipe for extracting discrete tokens with VoiceBank.
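+
+Typical invocation (the hparams file and the data path below are
+illustrative)::
+
+    python extract.py hparams/encodec.yaml --data_folder /path/to/VoiceBank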
+ +Authors + * Jarod Duret 2024 + * Luca Della Libera 2024 +""" + +import os +import sys +import logging +import pathlib as pl +import speechbrain as sb +from speechbrain.dataio.dataset import DynamicItemDataset +from speechbrain.utils.distributed import run_on_main +from hyperpyyaml import load_hyperpyyaml + +base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +sys.path.append(base_dir) + +print(base_dir) + +logger = logging.getLogger(__name__) + + +if __name__ == "__main__": + # CLI: + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Dataset prep (parsing voicebank) + from voicebank_prepare import prepare_voicebank # noqa + + # multi-gpu (ddp) save data preparation + os.makedirs(hparams["save_folder"], exist_ok=True) + run_on_main( + prepare_voicebank, + kwargs={ + "data_folder": hparams["data_folder"], + "save_folder": hparams["save_folder"], + "splits": hparams["splits"], + "num_valid_speakers": hparams["num_valid_speakers"], + }, + ) + + tokens_extractor_in = hparams["tokens_extractor_in"] + tokens_extractor_out = hparams["tokens_extractor_out"] + data_folder = hparams["data_folder"] + + datasets = [] + for csv_path in [hparams["train_csv"], hparams["valid_csv"], hparams["test_csv"]]: + name = pl.Path(csv_path).stem + dataset = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=csv_path, replacements={"DATA_ROOT": data_folder}, + ) + datasets.append(dataset) + + merged_data = { + key: value + for dataset in datasets + for key, value in dataset.data.items() + } + merged_dataset = DynamicItemDataset(merged_data) + + save_folder = pl.Path(hparams["save_folder"]) + logger.info("Extracting dataset input tokens ...") + tokens_extractor_in.extract_tokens( + merged_dataset, + hparams["num_codebooks"], + (save_folder / "input").as_posix(), + ) + + if hparams["save_embedding"]: + save_folder = pl.Path(hparams["save_folder"]) + logger.info(f"Saving embeddings ...") + tokens_extractor_in.save_pretrained_embeddings( + (save_folder / "embeddings" / "input").as_posix(), + vocab_size=hparams["vocab_size"], + num_codebooks=hparams["num_codebooks"], + ) + + logger.info("Extracting dataset output tokens ...") + tokens_extractor_out.extract_tokens( + merged_dataset, + hparams["num_codebooks"], + (save_folder / "output").as_posix(), + ) + + if hparams["save_embedding"]: + save_folder = pl.Path(hparams["save_folder"]) + logger.info(f"Saving embeddings ...") + tokens_extractor_out.save_pretrained_embeddings( + (save_folder / "embeddings" / "output").as_posix(), + vocab_size=hparams["vocab_size"], + num_codebooks=hparams["num_codebooks"], + ) diff --git a/benchmarks/DASB/VoiceBank/extraction/hparams/dac.yaml b/benchmarks/DASB/VoiceBank/extraction/hparams/dac.yaml new file mode 100644 index 000000000..3189fc514 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/extraction/hparams/dac.yaml @@ -0,0 +1,70 @@ +# ############################################################################ +# Audio Tokenizer: DAC +# Extraction: VoiceBank +# Authors: Jarod Duret 2024, Luca Della Libera 2024 +# ############################################################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] 
+output_folder: !ref results/dac +save_folder: !ref /save +train_log: !ref /extraction_log.txt + +# Data files +data_folder: !PLACEHOLDER +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +cached_data_folder: !ref # e.g., path/to/cache + +batch_size: 1 +num_workers: 8 +src_key: noisy_wav +tgt_key: clean_wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +####################### Model parameters ########################### +# Tokenizer parameters +# DAC parameters +# model_type: [16khz, 24khz, 44khz, 44khz] +# vocab_size: [1024, 1024, 1024, 1024] +# model_bitrate: [8kbps, 8kbps, 8kbps, 16kbps] +# max_num_codebooks: [12, 32, 9, 18] +# embedding_dim: [1024, 1024, 1024, 128] +model_type: 24khz +vocab_size: 1024 +model_bitrate: 8kbps +num_codebooks: 32 +sample_rate: 24000 +# Feature parameters +encoder_dim: 1024 +save_embedding: False + +tokenizer: !new:utils.tokenizer_interface.DACTokenizer + model_type: !ref + model_bitrate: !ref + load_pretrained: True + tag: latest + +tokens_extractor_in: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref + +tokens_extractor_out: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/VoiceBank/extraction/hparams/discrete_ssl.yaml b/benchmarks/DASB/VoiceBank/extraction/hparams/discrete_ssl.yaml new file mode 100644 index 000000000..9c2821c43 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/extraction/hparams/discrete_ssl.yaml @@ -0,0 +1,112 @@ +# ############################################################################ +# Audio Tokenizer: WavLM +# Extraction: VoiceBank +# Authors: Jarod Duret 2024, Luca Della Libera 2024 +# ############################################################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/wavlm +save_folder: !ref /save +train_log: !ref /extraction_log.txt + +# Data files +data_folder: !PLACEHOLDER +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +cached_data_folder: !ref # e.g., path/to/cache + +batch_size: 1 +num_workers: 8 +src_key: noisy_wav +tgt_key: clean_wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +### Configuration for discrete SSL model +# | SSL Model | HF Encoder | K-Means Dataset | K-Means Size | SSL Layers | Vocoder Model | +# |------------|----------------------------------------|-----------------|--------------|----------------------|---------------------------------------------| +# | WavLM | microsoft/wavlm-large | LibriSpeech960 | 1000 | 1, 3, 7, 12, 18, 23 | speechbrain/hifigan-wavlm-k1000-LibriTTS | +# | HuBERT | facebook/hubert-large-ll60k | LibriSpeech960 | 1000 | 1, 3, 7, 12, 18, 23 | speechbrain/hifigan-hubert-k1000-LibriTTS | +# | Wav2Vec2 | facebook/wav2vec2-large | LibriSpeech960 | 1000 | 1, 3, 7, 12, 18, 23 | speechbrain/hifigan-wav2vec2-k1000-LibriTTS | + + +# ssl_model_type: HuBERT, WavLM, Wav2Vec2 +# ssl_hub: facebook/hubert-large-ll60k, microsoft/wavlm-large, 
facebook/wav2vec2-large +ssl_model_type: WavLM +ssl_folder: !ref /ssl_checkpoint +kmeans_cache_dir: !ref /kmeans_checkpoint +kmeans_dataset: LibriSpeech +vocab_size: 1000 +save_embedding: False + +### Config for Tokenizer +# Layer number should be among the supported layers for discrete SSL models(kmenas model should be available for that layer) +num_codebooks: [1, 3, 7, 12, 18, 23] +deduplicate: [False, False, False, False, False, False] +bpe_tokenizer_path: [null, null, null, null, null, null] +sample_rate: 16000 +encoder_dim: 1024 + +tokenizer: !apply:speechbrain.utils.hparams.choice + value: !ref + choices: + WavLM: !new:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM + source: microsoft/wavlm-large + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-wavlm-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + HuBERT: !new:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT + source: facebook/hubert-large-ll60k + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-hubert-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + Wav2Vec2: !new:utils.tokenizer_interface.DiscreteSSLTokenizer + save_path: !ref + ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 + source: facebook/wav2vec2-large + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref + vocoder_repo_id: speechbrain/hifigan-wav2vec2-k1000-LibriTTS + kmeans_dataset: !ref + num_clusters: !ref + +tokens_extractor_in: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref + +tokens_extractor_out: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/VoiceBank/extraction/hparams/encodec.yaml b/benchmarks/DASB/VoiceBank/extraction/hparams/encodec.yaml new file mode 100644 index 000000000..728c33392 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/extraction/hparams/encodec.yaml @@ -0,0 +1,67 @@ +# ############################################################################ +# Audio Tokenizer: EnCodec +# Extraction: VoiceBank +# Authors: Jarod Duret 2024, Luca Della Libera 2024 +# ############################################################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/encodec +save_folder: !ref /save +train_log: !ref /extraction_log.txt + +# Data files +data_folder: !PLACEHOLDER +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +cached_data_folder: !ref # e.g., path/to/cache + +batch_size: 1 +num_workers: 8 +src_key: noisy_wav +tgt_key: clean_wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# bandwidth: [1.5, 3.0, 6.0, 
12.0, 24.0] +# num_codebooks: [2, 4, 8, 16, 32] +bandwidth: 24.0 +num_codebooks: 32 +vocab_size: 1024 +sample_rate: 24000 +save_embedding: False + +tokenizer: !new:utils.tokenizer_interface.EncodecTokenizer + source: facebook/encodec_24khz # Only the 24kHz version supports mono audio + save_path: !ref + sample_rate: !ref + bandwidth: !ref + flat_embeddings: False + freeze: True + renorm_embeddings: False + +tokens_extractor_in: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref + +tokens_extractor_out: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/VoiceBank/extraction/hparams/mimi.yaml b/benchmarks/DASB/VoiceBank/extraction/hparams/mimi.yaml new file mode 100644 index 000000000..5480fd2cb --- /dev/null +++ b/benchmarks/DASB/VoiceBank/extraction/hparams/mimi.yaml @@ -0,0 +1,61 @@ +# ############################################################################ +# Audio Tokenizer: Mimi +# Extraction: VoiceBank +# Authors: Jarod Duret 2024, Luca Della Libera 2024 +# ############################################################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/mimi +save_folder: !ref /save +train_log: !ref /extraction_log.txt + +# Data files +data_folder: !PLACEHOLDER +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +cached_data_folder: !ref # e.g., path/to/cache + +batch_size: 1 +num_workers: 8 +src_key: noisy_wav +tgt_key: clean_wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +####################### Model parameters ########################### +# Tokenizer parameters +model_hub: kyutai/mimi +vocab_size: 2048 +num_codebooks: 32 +sample_rate: 24000 +save_embedding: False + +tokenizer: !new:utils.tokenizer_interface.MimiTokenizer + source: !ref + save_path: !ref + num_codebooks: !ref + sample_rate: !ref + +tokens_extractor_in: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref + +tokens_extractor_out: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/VoiceBank/extraction/hparams/speech_tokenizer.yaml b/benchmarks/DASB/VoiceBank/extraction/hparams/speech_tokenizer.yaml new file mode 100644 index 000000000..7762b40c2 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/extraction/hparams/speech_tokenizer.yaml @@ -0,0 +1,59 @@ +# ############################################################################ +# Audio Tokenizer: Speech Tokenizer +# Extraction: VoiceBank +# Authors: Jarod Duret 2024, Luca Della Libera 2024 +# ############################################################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/speech_tokenizer +save_folder: !ref /save +train_log: !ref /extraction_log.txt + +# Data files +data_folder: !PLACEHOLDER +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv 
+splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +cached_data_folder: !ref # e.g., path/to/cache + +batch_size: 1 +num_workers: 8 +src_key: noisy_wav +tgt_key: clean_wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +vocab_size: 1024 +num_codebooks: 8 +sample_rate: 16000 +encoder_dim: 1024 +freeze_embedding: False +save_embedding: False + +tokenizer: !new:utils.tokenizer_interface.SpeechTokenizerWrapper + source: fnlp/SpeechTokenizer # Only the 24kHz version supports mono audio + save_path: !ref + sample_rate: !ref + +tokens_extractor_in: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref + +tokens_extractor_out: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/VoiceBank/extraction/hparams/sqcodec.yaml b/benchmarks/DASB/VoiceBank/extraction/hparams/sqcodec.yaml new file mode 100644 index 000000000..27b28f012 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/extraction/hparams/sqcodec.yaml @@ -0,0 +1,61 @@ +# ############################################################################ +# Audio Tokenizer: SQCodec +# Extraction: VoiceBank +# Authors: Jarod Duret 2024, Luca Della Libera 2024 +# ############################################################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/sqcodec +save_folder: !ref /save +train_log: !ref /extraction_log.txt + +# Data files +data_folder: !PLACEHOLDER +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +cached_data_folder: !ref # e.g., path/to/cache + +batch_size: 1 +num_workers: 8 +src_key: noisy_wav +tgt_key: clean_wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +# EnCodec parameters +config: config.yaml +checkpoint: ckpt_00190000.pth +sample_rate: 16000 +save_embedding: False +num_codebooks: 4 +tokenizer_save_path: !PLACEHOLDER + +tokenizer: !new:utils.tokenizer_interface.SQCodecTokenizer + save_path: !ref + checkpoint: !ref + config: !ref + sample_rate: !ref + +tokens_extractor_in: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref + +tokens_extractor_out: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/VoiceBank/extraction/hparams/wavtokenizer.yaml b/benchmarks/DASB/VoiceBank/extraction/hparams/wavtokenizer.yaml new file mode 100644 index 000000000..2ffdce696 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/extraction/hparams/wavtokenizer.yaml @@ -0,0 +1,64 @@ +# ############################################################################ +# Audio Tokenizer: WavTokenizer +# Extraction: VoiceBank +# Authors: Jarod Duret 2024, Luca Della Libera 2024 +# ############################################################################ + +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1986 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref results/wavtokenizer +save_folder: !ref /save +train_log: !ref 
/extraction_log.txt + +# Data files +data_folder: !PLACEHOLDER +train_csv: !ref /trainset_28spk_wav.csv +valid_csv: !ref /validset_wav.csv +test_csv: !ref /testset_wav.csv +splits: [trainset_28spk_wav, validset_wav, testset_wav] +num_valid_speakers: 2 +cached_data_folder: !ref # e.g., path/to/cache + +batch_size: 1 +num_workers: 8 +src_key: noisy_wav +tgt_key: clean_wav +id_key: id + +# Dataloader options +dataloader_opts: + batch_size: !ref + shuffle: True + num_workers: !ref + +# EnCodec parameters +model_hub: novateur/WavTokenizer-medium-music-audio-75token +config: wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml +checkpoint: wavtokenizer_medium_music_audio_320_24k_v2.ckpt +sample_rate: 24000 +save_embedding: False +num_codebooks: 1 +vocab_size: 4096 + +tokenizer: !new:utils.tokenizer_interface.WavTokenizerWrapper + source: !ref + save_path: !ref + checkpoint: !ref + config: !ref + sample_rate: !ref + freeze: True + +tokens_extractor_in: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref + +tokens_extractor_out: !new:utils.tokens.TokensExtractor + tokenizer: !ref + sample_rate: !ref + src_key: !ref + id_key: !ref + dataloader_opts: !ref diff --git a/benchmarks/DASB/VoiceBank/extraction/voicebank_prepare.py b/benchmarks/DASB/VoiceBank/extraction/voicebank_prepare.py new file mode 120000 index 000000000..fe458de38 --- /dev/null +++ b/benchmarks/DASB/VoiceBank/extraction/voicebank_prepare.py @@ -0,0 +1 @@ +../voicebank_prepare.py \ No newline at end of file From fc788badeee1cde94ac28fddf4af6753cc79b527 Mon Sep 17 00:00:00 2001 From: Luca Della Libera Date: Thu, 10 Jul 2025 10:48:25 -0400 Subject: [PATCH 3/5] precommit --- .../DASB/VoiceBank/enhancement/common.py | 9 +- .../enhancement/hparams/CRDNN/train.yaml | 283 ------------------ .../enhancement/hparams/CRDNN/train_dac.yaml | 15 +- .../hparams/CRDNN/train_encodec.yaml | 5 +- .../hparams/CRDNN/train_hubert.yaml | 17 +- .../enhancement/hparams/CRDNN/train_mimi.yaml | 7 +- .../hparams/CRDNN/train_speech_tokenizer.yaml | 7 +- .../hparams/CRDNN/train_sqcodec.yaml | 7 +- .../hparams/CRDNN/train_wav2vec2.yaml | 17 +- .../hparams/CRDNN/train_wavlm.yaml | 17 +- .../hparams/CRDNN/train_wavtokenizer.yaml | 7 +- .../enhancement/hparams/Conformer/train.yaml | 273 ----------------- .../hparams/Conformer/train_dac.yaml | 15 +- .../hparams/Conformer/train_encodec.yaml | 5 +- .../hparams/Conformer/train_hubert.yaml | 17 +- .../hparams/Conformer/train_mimi.yaml | 7 +- .../Conformer/train_speech_tokenizer.yaml | 7 +- .../hparams/Conformer/train_sqcodec.yaml | 4 +- .../hparams/Conformer/train_wav2vec2.yaml | 17 +- .../hparams/Conformer/train_wavlm.yaml | 17 +- .../hparams/Conformer/train_wavtokenizer.yaml | 5 +- .../VoiceBank/enhancement/metrics/spk_sim.py | 1 + .../VoiceBank/enhancement/model/sq_codec.py | 4 +- .../DASB/VoiceBank/enhancement/train.py | 14 +- .../enhancement/train_continuous_ssl.py | 1 + .../DASB/VoiceBank/extraction/extract.py | 6 +- 26 files changed, 129 insertions(+), 655 deletions(-) delete mode 100644 benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train.yaml delete mode 100644 benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train.yaml diff --git a/benchmarks/DASB/VoiceBank/enhancement/common.py b/benchmarks/DASB/VoiceBank/enhancement/common.py index 227b839fe..5e8952ff2 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/common.py +++ b/benchmarks/DASB/VoiceBank/enhancement/common.py @@ -7,7 +7,6 @@ import 
speechbrain as sb import torch import torchaudio -from speechbrain.dataio.dataio import merge_csvs from transformers.models.hubert.modeling_hubert import ( HubertEncoderStableLayerNorm, ) @@ -142,10 +141,14 @@ def dataio_prepare( num_codebooks = hparams["num_codebooks"] def toks_pipeline(id): - in_toks = tokens_loader_in.tokens_by_uttid(id, num_codebooks=num_codebooks) + in_toks = tokens_loader_in.tokens_by_uttid( + id, num_codebooks=num_codebooks + ) yield in_toks - out_toks = tokens_loader_out.tokens_by_uttid(id, num_codebooks=num_codebooks) + out_toks = tokens_loader_out.tokens_by_uttid( + id, num_codebooks=num_codebooks + ) yield out_toks sb.dataio.dataset.add_dynamic_item( diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train.yaml deleted file mode 100644 index d7b663ad6..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train.yaml +++ /dev/null @@ -1,283 +0,0 @@ -# ########################################################################################### -# Model: CRDNN with discrete audio representations -# Authors: Luca Della Libera 2024 -# ########################################################################################### - -run_name: encodec - -# Seed needs to be set at top of YAML -seed: 0 -__set_seed: !apply:torch.manual_seed [!ref ] - -# Data preparation -data_folder: !PLACEHOLDER -tokens_folder: !PLACEHOLDER -cached_data_folder: !ref # e.g., path/to/cache -train_csv: !ref /trainset_28spk_wav.csv -valid_csv: !ref /validset_wav.csv -test_csv: !ref /testset_wav.csv -splits: [trainset_28spk_wav, validset_wav, testset_wav] -num_valid_speakers: 2 -testing: True - -# Output folders -output_folder: !ref results/CRDNN// -save_folder: !ref /save -cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE - -# Save options -compute_metrics: False -save_audios: False - -# Preprocessing parameters -train_remove_if_longer: 60.0 # Seconds -valid_remove_if_longer: 60.0 # Seconds -test_remove_if_longer: 60.0 # Seconds -sorting: random - -# Training parameters -num_epochs: 50 -grad_accumulation_factor: 1 -train_batch_size: 16 -valid_batch_size: 1 -test_batch_size: 1 -dataloader_workers: 4 -nonfinite_patience: 10 -max_grad_norm: 0.001 -precision: fp16 -ckpt_interval_minutes: 6000 -keep_checkpoints: 1 - -# Optimizer parameters -lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.0005)" -weight_decay: 0.01 -improvement_threshold: 0.0025 -annealing_factor: 0.9 -patient: 1 - -# Codec parameters -codec_type: encodec -sample_rate: 24000 # Should match the tokenizer's sample rate -vocab_size: 1024 -num_codebooks: 2 -bandwidth: 24.0 -kmeans_dataset: LibriSpeech -SSL_layers: [1, 3] #[1, 3, 7, 12, 18, 23] -tokenizer_save_path: /home/luca/Downloads/SQ-Codec/sqcodec - -# Embedding parameters -embedding_dim: 1024 -freeze_embedding: False - -# Encoder parameters -dropout: 0.1 -activation: !name:torch.nn.LeakyReLU -rnn_class: !name:speechbrain.nnet.RNN.LSTM -rnn_layers: 4 # @orion_step1: --rnn_layers~"uniform(1, 4,discrete=True)" -time_pooling_size: 1 -rnn_bidirectional: True -rnn_neurons: 256 -dnn_blocks: 2 -dnn_neurons: 256 -cnn_blocks: 2 -cnn_channels: (12, 12) -inter_layer_pooling_size: (2, 2) -cnn_kernelsize: (3, 3) - -# Modules -embedding: !new:custom_model.Discrete_EmbeddingLayer - num_codebooks: !ref - vocab_size: !ref - emb_dim: !ref - freeze: !ref - -attention_mlp: !new:custom_model.AttentionMLP - input_dim: !ref - hidden_dim: !ref - -encoder: 
!new:speechbrain.lobes.models.CRDNN.CRDNN - input_shape: [null, null, !ref ] - activation: !ref - dropout: !ref - cnn_blocks: !ref - cnn_channels: !ref - cnn_kernelsize: !ref - inter_layer_pooling_size: !ref - time_pooling: True - using_2d_pooling: False - time_pooling_size: !ref - rnn_class: !ref - rnn_layers: !ref - rnn_neurons: !ref - rnn_bidirectional: !ref - dnn_blocks: !ref - dnn_neurons: !ref - rnn_re_init: False - use_rnnp: False - -head: !new:torch.nn.Linear - in_features: !ref - out_features: !ref * - -modules: - embedding: !ref - attention_mlp: !ref - encoder: !ref - head: !ref - -model: !new:torch.nn.ModuleList - [[!ref , - !ref , - !ref , - !ref ]] - -# Loss functions -ce_loss: !name:speechbrain.nnet.losses.nll_loss - label_smoothing: 0.0 - allowed_len_diff: 0 - reduction: mean - -# Optimizers -opt_class: !name:torch.optim.AdamW - lr: !ref - betas: (0.9, 0.98) - eps: 1.e-8 - weight_decay: !ref - -# Schedulers -scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler - initial_value: !ref - improvement_threshold: !ref - annealing_factor: !ref - patient: !ref - -# Token loaders -tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input - -tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output - -# Codec -codec: !apply:speechbrain.utils.hparams.choice - value: !ref - choices: - dac: !new:utils.tokenizer_interface.DACTokenizer - model_type: 24khz - model_bitrate: 8kbps - load_pretrained: True - tag: latest - encodec: !new:utils.tokenizer_interface.EncodecTokenizer - source: facebook/encodec_24khz - save_path: !ref - sample_rate: !ref - bandwidth: !ref - flat_embeddings: False - freeze: True - renorm_embeddings: False - mimi: !new:utils.tokenizer_interface.MimiTokenizer - source: kyutai/mimi - save_path: !ref - num_codebooks: !ref - sample_rate: !ref - speech_tokenizer: !new:utils.tokenizer_interface.SpeechTokenizerWrapper - source: fnlp/SpeechTokenizer - save_path: !ref - sample_rate: !ref - sqcodec: !new:utils.tokenizer_interface.SQCodecTokenizer - save_path: !ref - checkpoint: ckpt_00190000.pth - config: config.yaml - sample_rate: !ref - wavtokenizer: !new:utils.tokenizer_interface.WavTokenizerWrapper - source: novateur/WavTokenizer-medium-music-audio-75token - save_path: !ref - checkpoint: wavtokenizer_medium_music_audio_320_24k_v2.ckpt - config: wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml - sample_rate: !ref - freeze: True - wavlm: !new:utils.tokenizer_interface.DiscreteSSLTokenizer - save_path: !ref - ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM - source: microsoft/wavlm-large - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref - vocoder_repo_id: speechbrain/hifigan-wavlm-k1000-LibriTTS - kmeans_dataset: !ref - num_clusters: !ref - hubert: !new:utils.tokenizer_interface.DiscreteSSLTokenizer - save_path: !ref - ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT - source: facebook/hubert-large-ll60k - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref - vocoder_repo_id: speechbrain/hifigan-hubert-k1000-LibriTTS - kmeans_dataset: !ref - num_clusters: !ref - wav2vec2: !new:utils.tokenizer_interface.DiscreteSSLTokenizer - save_path: !ref - ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 - source: facebook/wav2vec2-large - output_norm: False - freeze: True - freeze_feature_extractor: 
True - output_all_hiddens: True - save_path: !ref - vocoder_repo_id: speechbrain/hifigan-wav2vec2-k1000-LibriTTS - kmeans_dataset: !ref - num_clusters: !ref - -# Dataloaders -train_dataloader_kwargs: - batch_size: !ref - num_workers: !ref - pin_memory: True - shuffle: !apply:str.__eq__ [!ref , random] - -valid_dataloader_kwargs: - batch_size: !ref - num_workers: !ref - pin_memory: True - -test_dataloader_kwargs: - batch_size: !ref - num_workers: !ref - pin_memory: True - -# Performance metrics -ter_computer: !name:speechbrain.utils.metric_stats.MetricStats - metric: !name:speechbrain.nnet.losses.classification_error - reduction: batch - -dnsmos_computer: !name:metrics.dnsmos.DNSMOS - sample_rate: !ref - -dwer_computer: !name:metrics.dwer.DWER - model_hub: openai/whisper-small - save_path: !ref - sample_rate: !ref - -wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM - model_hub: microsoft/wavlm-base-sv - save_path: !ref - sample_rate: !ref - -# Counters, checkpointers, loggers, etc. -epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter - limit: !ref - -checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - scheduler: !ref - counter: !ref - -train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_dac.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_dac.yaml index 8d3016ef1..7766388d3 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_dac.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_dac.yaml @@ -148,17 +148,18 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.DACTokenizer - model_type: 24khz - model_bitrate: 8kbps - load_pretrained: True - tag: latest + model_type: 24khz + model_bitrate: 8kbps + load_pretrained: True + tag: latest # Dataloaders train_dataloader_kwargs: @@ -207,4 +208,4 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer counter: !ref train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref /train_log.txt \ No newline at end of file + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_encodec.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_encodec.yaml index 8d07d34c0..40f8f30f8 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_encodec.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_encodec.yaml @@ -149,10 +149,11 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.EncodecTokenizer diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_hubert.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_hubert.yaml index d3c84071e..fb93e4309 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_hubert.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_hubert.yaml @@ -150,21 +150,22 @@ 
scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.DiscreteSSLTokenizer save_path: !ref ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT - source: facebook/hubert-large-ll60k - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref + source: facebook/hubert-large-ll60k + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref vocoder_repo_id: speechbrain/hifigan-hubert-k1000-LibriTTS kmeans_dataset: !ref num_clusters: !ref diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_mimi.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_mimi.yaml index 63fe4513c..35ec7dee0 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_mimi.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_mimi.yaml @@ -148,10 +148,11 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.MimiTokenizer @@ -207,4 +208,4 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer counter: !ref train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref /train_log.txt \ No newline at end of file + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_speech_tokenizer.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_speech_tokenizer.yaml index 8a95811d8..c6ec1ff0a 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_speech_tokenizer.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_speech_tokenizer.yaml @@ -148,10 +148,11 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.SpeechTokenizerWrapper @@ -206,4 +207,4 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer counter: !ref train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref /train_log.txt \ No newline at end of file + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_sqcodec.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_sqcodec.yaml index 58c4adb17..441805f43 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_sqcodec.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_sqcodec.yaml @@ -149,10 +149,11 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.SQCodecTokenizer @@ -208,4 +209,4 @@ checkpointer: 
!new:speechbrain.utils.checkpoints.Checkpointer counter: !ref train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref /train_log.txt \ No newline at end of file + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wav2vec2.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wav2vec2.yaml index 977951f2e..5320d0f29 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wav2vec2.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wav2vec2.yaml @@ -150,21 +150,22 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.DiscreteSSLTokenizer save_path: !ref ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 - source: facebook/wav2vec2-large - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref + source: facebook/wav2vec2-large + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref vocoder_repo_id: speechbrain/hifigan-wav2vec2-k1000-LibriTTS kmeans_dataset: !ref num_clusters: !ref diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavlm.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavlm.yaml index 59a19ae53..232c7ea8e 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavlm.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavlm.yaml @@ -150,21 +150,22 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.DiscreteSSLTokenizer save_path: !ref ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM - source: microsoft/wavlm-large - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref + source: microsoft/wavlm-large + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref vocoder_repo_id: speechbrain/hifigan-wavlm-k1000-LibriTTS kmeans_dataset: !ref num_clusters: !ref diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavtokenizer.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavtokenizer.yaml index 3e5b50d00..917790412 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavtokenizer.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/CRDNN/train_wavtokenizer.yaml @@ -151,10 +151,11 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.WavTokenizerWrapper @@ -212,4 +213,4 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer counter: !ref train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref /train_log.txt 
\ No newline at end of file + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train.yaml deleted file mode 100644 index 3e37dddaa..000000000 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train.yaml +++ /dev/null @@ -1,273 +0,0 @@ -# ########################################################################################### -# Model: Conformer with discrete audio representations -# Authors: Luca Della Libera 2024 -# ########################################################################################### - -run_name: encodec - -# Seed needs to be set at top of YAML -seed: 0 -__set_seed: !apply:torch.manual_seed [!ref ] - -# Data preparation -data_folder: !PLACEHOLDER -tokens_folder: !PLACEHOLDER -cached_data_folder: !ref # e.g., path/to/cache -train_csv: !ref /trainset_28spk_wav.csv -valid_csv: !ref /validset_wav.csv -test_csv: !ref /testset_wav.csv -splits: [trainset_28spk_wav, validset_wav, testset_wav] -num_valid_speakers: 2 -testing: True - -# Output folders -output_folder: !ref results/Conformer// -save_folder: !ref /save -cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE - -# Save options -compute_metrics: False -save_audios: False - -# Preprocessing parameters -train_remove_if_longer: 60.0 # Seconds -valid_remove_if_longer: 60.0 # Seconds -test_remove_if_longer: 60.0 # Seconds -sorting: random - -# Training parameters -num_epochs: 50 -grad_accumulation_factor: 1 -train_batch_size: 16 -valid_batch_size: 1 -test_batch_size: 1 -dataloader_workers: 4 -nonfinite_patience: 10 -max_grad_norm: 5.0 -precision: fp32 -ckpt_interval_minutes: 6000 -keep_checkpoints: 1 - -# Optimizer parameters -lr: 0.0005 # @orion_step1: --lr~"loguniform(0.00001,0.5)" -weight_decay: 0.01 -improvement_threshold: 0.0025 -annealing_factor: 0.9 -patient: 1 - -# Codec parameters -codec_type: encodec -sample_rate: 24000 # Should match the tokenizer's sample rate -vocab_size: 1024 -num_codebooks: 2 -bandwidth: 24.0 -kmeans_dataset: LibriSpeech -SSL_layers: [1, 3] #[1, 3, 7, 12, 18, 23] -tokenizer_save_path: /home/luca/Downloads/SQ-Codec/sqcodec - -# Embedding parameters -embedding_dim: 1024 -freeze_embedding: False - -# Encoder parameters -dropout: 0.1 -activation: !name:torch.nn.GELU -d_model: 256 -nhead: 4 -num_layers: 6 # @orion_step1: --num_layers~"uniform(1, 8,discrete=True)" -d_ffn: 2048 -max_length: 2000 -causal: False - -# Modules -embedding: !new:custom_model.Discrete_EmbeddingLayer - num_codebooks: !ref - vocab_size: !ref - emb_dim: !ref - freeze: !ref - -attention_mlp: !new:custom_model.AttentionMLP - input_dim: !ref - hidden_dim: !ref - -encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR - input_size: !ref - tgt_vocab: -1 - d_model: !ref - nhead: !ref - num_encoder_layers: !ref - num_decoder_layers: 0 - d_ffn: !ref - dropout: !ref - activation: !ref - max_length: !ref - encoder_module: conformer - normalize_before: True - causal: !ref - -head: !new:torch.nn.Linear - in_features: !ref - out_features: !ref * - -modules: - embedding: !ref - attention_mlp: !ref - encoder: !ref - head: !ref - -model: !new:torch.nn.ModuleList - [[!ref , - !ref , - !ref , - !ref ]] - -# Loss functions -ce_loss: !name:speechbrain.nnet.losses.nll_loss - label_smoothing: 0.0 - allowed_len_diff: 0 - reduction: mean - -# Optimizers -opt_class: !name:torch.optim.AdamW - lr: !ref - betas: (0.9, 0.98) - eps: 1.e-8 - weight_decay: !ref - -# Schedulers 
-scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler - initial_value: !ref - improvement_threshold: !ref - annealing_factor: !ref - patient: !ref - -# Token loaders -tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input - -tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output - -# Codec -codec: !apply:speechbrain.utils.hparams.choice - value: !ref - choices: - dac: !new:utils.tokenizer_interface.DACTokenizer - model_type: 24khz - model_bitrate: 8kbps - load_pretrained: True - tag: latest - encodec: !new:utils.tokenizer_interface.EncodecTokenizer - source: facebook/encodec_24khz - save_path: !ref - sample_rate: !ref - bandwidth: !ref - flat_embeddings: False - freeze: True - renorm_embeddings: False - mimi: !new:utils.tokenizer_interface.MimiTokenizer - source: kyutai/mimi - save_path: !ref - num_codebooks: !ref - sample_rate: !ref - speech_tokenizer: !new:utils.tokenizer_interface.SpeechTokenizerWrapper - source: fnlp/SpeechTokenizer - save_path: !ref - sample_rate: !ref - sqcodec: !new:utils.tokenizer_interface.SQCodecTokenizer - save_path: !ref - checkpoint: ckpt_00190000.pth - config: config.yaml - sample_rate: !ref - wavtokenizer: !new:utils.tokenizer_interface.WavTokenizerWrapper - source: novateur/WavTokenizer-medium-music-audio-75token - save_path: !ref - checkpoint: wavtokenizer_medium_music_audio_320_24k_v2.ckpt - config: wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml - sample_rate: !ref - freeze: True - wavlm: !new:utils.tokenizer_interface.DiscreteSSLTokenizer - save_path: !ref - ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM - source: microsoft/wavlm-large - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref - vocoder_repo_id: speechbrain/hifigan-wavlm-k1000-LibriTTS - kmeans_dataset: !ref - num_clusters: !ref - hubert: !new:utils.tokenizer_interface.DiscreteSSLTokenizer - save_path: !ref - ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT - source: facebook/hubert-large-ll60k - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref - vocoder_repo_id: speechbrain/hifigan-hubert-k1000-LibriTTS - kmeans_dataset: !ref - num_clusters: !ref - wav2vec2: !new:utils.tokenizer_interface.DiscreteSSLTokenizer - save_path: !ref - ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 - source: facebook/wav2vec2-large - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref - vocoder_repo_id: speechbrain/hifigan-wav2vec2-k1000-LibriTTS - kmeans_dataset: !ref - num_clusters: !ref - -# Dataloaders -train_dataloader_kwargs: - batch_size: !ref - num_workers: !ref - pin_memory: True - shuffle: !apply:str.__eq__ [!ref , random] - -valid_dataloader_kwargs: - batch_size: !ref - num_workers: !ref - pin_memory: True - -test_dataloader_kwargs: - batch_size: !ref - num_workers: !ref - pin_memory: True - -# Performance metrics -ter_computer: !name:speechbrain.utils.metric_stats.MetricStats - metric: !name:speechbrain.nnet.losses.classification_error - reduction: batch - -dnsmos_computer: !name:metrics.dnsmos.DNSMOS - sample_rate: !ref - -dwer_computer: !name:metrics.dwer.DWER - model_hub: openai/whisper-small - save_path: !ref - sample_rate: !ref - -wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM - model_hub: microsoft/wavlm-base-sv - save_path: 
!ref - sample_rate: !ref - -# Counters, checkpointers, loggers, etc. -epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter - limit: !ref - -checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer - checkpoints_dir: !ref - recoverables: - model: !ref - scheduler: !ref - counter: !ref - -train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref /train_log.txt \ No newline at end of file diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_dac.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_dac.yaml index 902e16029..fe6a52aa0 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_dac.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_dac.yaml @@ -138,17 +138,18 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.DACTokenizer - model_type: 24khz - model_bitrate: 8kbps - load_pretrained: True - tag: latest + model_type: 24khz + model_bitrate: 8kbps + load_pretrained: True + tag: latest # Dataloaders train_dataloader_kwargs: @@ -197,4 +198,4 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer counter: !ref train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref /train_log.txt \ No newline at end of file + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_encodec.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_encodec.yaml index 9244de0e7..f47b1fc36 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_encodec.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_encodec.yaml @@ -139,10 +139,11 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.EncodecTokenizer diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_hubert.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_hubert.yaml index 96cdda32a..cf600f406 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_hubert.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_hubert.yaml @@ -140,21 +140,22 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.DiscreteSSLTokenizer save_path: !ref ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT - source: facebook/hubert-large-ll60k - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref + source: facebook/hubert-large-ll60k + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref vocoder_repo_id: speechbrain/hifigan-hubert-k1000-LibriTTS kmeans_dataset: !ref num_clusters: !ref diff --git 
a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_mimi.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_mimi.yaml index f53e7f768..cfcf8fab9 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_mimi.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_mimi.yaml @@ -138,10 +138,11 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.MimiTokenizer @@ -197,4 +198,4 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer counter: !ref train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref /train_log.txt \ No newline at end of file + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_speech_tokenizer.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_speech_tokenizer.yaml index 71d907b0d..efa5985c1 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_speech_tokenizer.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_speech_tokenizer.yaml @@ -138,10 +138,11 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.SpeechTokenizerWrapper @@ -196,4 +197,4 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer counter: !ref train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger - save_file: !ref /train_log.txt \ No newline at end of file + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_sqcodec.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_sqcodec.yaml index c5d056c21..5e0fbdcb8 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_sqcodec.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_sqcodec.yaml @@ -139,10 +139,10 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output # Codec codec: !name:utils.tokenizer_interface.SQCodecTokenizer diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wav2vec2.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wav2vec2.yaml index dcca7e7ea..1304bbe11 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wav2vec2.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wav2vec2.yaml @@ -140,21 +140,22 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.DiscreteSSLTokenizer save_path: !ref ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2 - source: 
facebook/wav2vec2-large - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref + source: facebook/wav2vec2-large + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref vocoder_repo_id: speechbrain/hifigan-wav2vec2-k1000-LibriTTS kmeans_dataset: !ref num_clusters: !ref diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavlm.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavlm.yaml index 02a598a80..a1115a221 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavlm.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavlm.yaml @@ -140,21 +140,22 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.DiscreteSSLTokenizer save_path: !ref ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM - source: microsoft/wavlm-large - output_norm: False - freeze: True - freeze_feature_extractor: True - output_all_hiddens: True - save_path: !ref + source: microsoft/wavlm-large + output_norm: False + freeze: True + freeze_feature_extractor: True + output_all_hiddens: True + save_path: !ref vocoder_repo_id: speechbrain/hifigan-wavlm-k1000-LibriTTS kmeans_dataset: !ref num_clusters: !ref diff --git a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavtokenizer.yaml b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavtokenizer.yaml index cc4572519..fffa623b9 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavtokenizer.yaml +++ b/benchmarks/DASB/VoiceBank/enhancement/hparams/Conformer/train_wavtokenizer.yaml @@ -141,10 +141,11 @@ scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler # Token loaders tokens_loader_in: !new:utils.tokens.TokensLoader - data_path: !ref /input + data_path: !ref /input tokens_loader_out: !new:utils.tokens.TokensLoader - data_path: !ref /output + data_path: !ref /output + # Codec codec: !name:utils.tokenizer_interface.WavTokenizerWrapper diff --git a/benchmarks/DASB/VoiceBank/enhancement/metrics/spk_sim.py b/benchmarks/DASB/VoiceBank/enhancement/metrics/spk_sim.py index 106732f26..a4e5649c8 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/metrics/spk_sim.py +++ b/benchmarks/DASB/VoiceBank/enhancement/metrics/spk_sim.py @@ -16,6 +16,7 @@ SAMPLE_RATE = 16000 + class SpkSimWavLM(MetricStats): def __init__(self, model_hub, save_path, sample_rate): self.sample_rate = sample_rate diff --git a/benchmarks/DASB/VoiceBank/enhancement/model/sq_codec.py b/benchmarks/DASB/VoiceBank/enhancement/model/sq_codec.py index 916820b96..ef1283b05 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/model/sq_codec.py +++ b/benchmarks/DASB/VoiceBank/enhancement/model/sq_codec.py @@ -124,7 +124,9 @@ def build_codec_model(self, config): exp_model_config = OmegaConf.load(config) scalar_codec = ScalarModel(**exp_model_config.generator.config) device = next(iter(scalar_codec.parameters())).device - parameter_dict = torch.load(self.ckpt_path, map_location=device, weights_only=False) + parameter_dict = torch.load( + self.ckpt_path, map_location=device, weights_only=False + ) scalar_codec.load_state_dict(parameter_dict["codec_model"]) return scalar_codec 
diff --git a/benchmarks/DASB/VoiceBank/enhancement/train.py b/benchmarks/DASB/VoiceBank/enhancement/train.py
index 0f7186843..2440461fa 100644
--- a/benchmarks/DASB/VoiceBank/enhancement/train.py
+++ b/benchmarks/DASB/VoiceBank/enhancement/train.py
@@ -26,7 +26,7 @@ def toks_to_sig(self, toks):
         # toks: [B, N, K]
         self.hparams.codec.to(self.device).eval()
         self.hparams.codec.device = self.device
-        if hasattr(self.hparams.codec, "codec_vocoder"):
+        if hasattr(self.hparams.codec, "codec_vocoder"):
             self.hparams.codec.codec_vocoder.device = self.device
         kwargs = {}
         if hasattr(self.hparams, "SSL_layers"):
@@ -49,12 +49,14 @@ def compute_forward(self, batch, stage):
 
         # Forward encoder
         if hasattr(self.modules.encoder, "encode"):
-            hyp_embs = self.modules.encoder.encode(in_embs, in_lens)  # [B, N, H]
+            hyp_embs = self.modules.encoder.encode(
+                in_embs, in_lens
+            )  # [B, N, H]
         else:
             abs_length = (in_embs.shape[1] * in_lens).ceil().long()
             for i in range(len(abs_length)):
                 if abs_length[i] < in_embs.shape[1]:
-                    in_embs[i, abs_length[i]:] = 0
+                    in_embs[i, abs_length[i] :] = 0
             hyp_embs = self.modules.encoder(in_embs)  # [B, N, H]
 
         # Forward head
@@ -244,9 +246,9 @@ def on_stage_end(self, stage, stage_loss, epoch=None):
     )
 
     # Log number of parameters/buffers
-    #codec_params = sum(
+    # codec_params = sum(
    #     [x.numel() for x in hparams["codec"].state_dict().values()]
-    #)
+    # )
     model_params = sum(
         [
             x.numel()
@@ -256,7 +258,7 @@ def on_stage_end(self, stage, stage_loss, epoch=None):
     )
     hparams["train_logger"].log_stats(
         stats_meta={
-            #f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}",
+            # f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}",
             "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}",
         },
     )
diff --git a/benchmarks/DASB/VoiceBank/enhancement/train_continuous_ssl.py b/benchmarks/DASB/VoiceBank/enhancement/train_continuous_ssl.py
index d95d24666..5acd2b666 100644
--- a/benchmarks/DASB/VoiceBank/enhancement/train_continuous_ssl.py
+++ b/benchmarks/DASB/VoiceBank/enhancement/train_continuous_ssl.py
@@ -22,6 +22,7 @@
 
 _CACHE = {}
 
+
 # To use in configuration files
 def len_(SSL_layers, embedding_dim):
     return len(SSL_layers) * embedding_dim
diff --git a/benchmarks/DASB/VoiceBank/extraction/extract.py b/benchmarks/DASB/VoiceBank/extraction/extract.py
index 8f88ad070..346fae442 100644
--- a/benchmarks/DASB/VoiceBank/extraction/extract.py
+++ b/benchmarks/DASB/VoiceBank/extraction/extract.py
@@ -57,7 +57,11 @@
     data_folder = hparams["data_folder"]
 
     datasets = []
-    for csv_path in [hparams["train_csv"], hparams["valid_csv"], hparams["test_csv"]]:
+    for csv_path in [
+        hparams["train_csv"],
+        hparams["valid_csv"],
+        hparams["test_csv"],
+    ]:
         name = pl.Path(csv_path).stem
         dataset = sb.dataio.dataset.DynamicItemDataset.from_csv(
             csv_path=csv_path, replacements={"DATA_ROOT": data_folder},

From eedebb740f8895de1b680cd31891d416b929705f Mon Sep 17 00:00:00 2001
From: Luca Della Libera
Date: Wed, 23 Jul 2025 15:23:35 -0400
Subject: [PATCH 4/5] Update pre-commit

---
 .github/workflows/pre-commit.yml | 4 ++--
 .pre-commit-config.yaml | 4 ++--
 benchmarks/CL_MASR/analyze_logs.py | 2 +-
 benchmarks/CL_MASR/common_voice_prepare.py | 4 ++--
 benchmarks/CL_MASR/wavlm/pretrain.py | 2 +-
 benchmarks/CL_MASR/wavlm/train_agem.py | 2 +-
 benchmarks/CL_MASR/wavlm/train_der.py | 2 +-
 benchmarks/CL_MASR/wavlm/train_er.py | 2 +-
 benchmarks/CL_MASR/wavlm/train_ewc.py | 2 +-
 benchmarks/CL_MASR/wavlm/train_ft.py | 2 +-
 benchmarks/CL_MASR/wavlm/train_joint.py | 4 ++--
 benchmarks/CL_MASR/wavlm/train_l2p.py | 2 +-
benchmarks/CL_MASR/wavlm/train_lwf.py | 2 +- benchmarks/CL_MASR/wavlm/train_mas.py | 2 +- benchmarks/CL_MASR/wavlm/train_pb.py | 2 +- benchmarks/CL_MASR/wavlm/train_pnn.py | 2 +- benchmarks/CL_MASR/whisper/model.py | 10 +++------- benchmarks/CL_MASR/whisper/train_agem.py | 2 +- benchmarks/CL_MASR/whisper/train_der.py | 2 +- benchmarks/CL_MASR/whisper/train_er.py | 2 +- benchmarks/CL_MASR/whisper/train_ewc.py | 2 +- benchmarks/CL_MASR/whisper/train_ft.py | 2 +- benchmarks/CL_MASR/whisper/train_joint.py | 4 ++-- benchmarks/CL_MASR/whisper/train_l2p.py | 2 +- benchmarks/CL_MASR/whisper/train_lwf.py | 2 +- benchmarks/CL_MASR/whisper/train_mas.py | 2 +- benchmarks/CL_MASR/whisper/train_pb.py | 2 +- benchmarks/CL_MASR/whisper/train_pnn.py | 2 +- benchmarks/DASB/IEMOCAP/iemocap_prepare.py | 2 +- .../separation/conformer/train_continuous_ssl.py | 2 +- .../DASB/Libri2Mix/separation/conformer/train_dac.py | 2 +- .../separation/conformer/train_discrete_ssl.py | 2 +- .../Libri2Mix/separation/conformer/train_encodec.py | 2 +- .../separation/conformer/train_speech_tokenizer.py | 2 +- .../Libri2Mix/separation/crdnn/train_continuous_ssl.py | 2 +- .../DASB/Libri2Mix/separation/crdnn/train_dac.py | 2 +- .../Libri2Mix/separation/crdnn/train_discrete_ssl.py | 2 +- .../DASB/Libri2Mix/separation/crdnn/train_encodec.py | 2 +- .../separation/crdnn/train_speech_tokenizer.py | 2 +- benchmarks/DASB/LibriSpeech/ASR-on-the-fly/train.py | 2 +- benchmarks/DASB/LibriSpeech/extraction/extract.py | 2 +- benchmarks/DASB/VoiceBank/enhancement/train.py | 2 +- .../DASB/VoiceBank/enhancement/train_continuous_ssl.py | 2 +- benchmarks/DASB/VoiceBank/extraction/extract.py | 4 ++-- benchmarks/MOABB/utils/parse_results.py | 2 +- benchmarks/MOABB/utils/prepare.py | 4 ++-- benchmarks/MP3S/IEMOCAP/ecapa_tdnn/iemocap_prepare.py | 2 +- tests/utils/recipe_tests.py | 4 ++-- 48 files changed, 58 insertions(+), 62 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 6724b2764..03d60e521 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -12,5 +12,5 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 with: - python-version: '3.8' - - uses: pre-commit/action@v2.0.3 + python-version: '3.12' + - uses: pre-commit/action@v3.0.1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 483db5ecc..01d9da4bc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.5.0 # Use the ref you want to point at + rev: v5.0.0 # Use the ref you want to point at hooks: - id: trailing-whitespace types: [file, text] @@ -21,7 +21,7 @@ repos: types: [python] additional_dependencies: ['click==8.0.4'] - repo: https://github.com/PyCQA/flake8 - rev: 3.7.9 + rev: 7.0.0 hooks: - id: flake8 types: [python] diff --git a/benchmarks/CL_MASR/analyze_logs.py b/benchmarks/CL_MASR/analyze_logs.py index 78061c49a..baa45420d 100644 --- a/benchmarks/CL_MASR/analyze_logs.py +++ b/benchmarks/CL_MASR/analyze_logs.py @@ -851,7 +851,7 @@ def hex_to_rgb(hex_color: "str") -> "Tuple": f"{name.lower().replace(' ', '_')}.{args.format}", ), xlabel=None, - ylabel=f"{name} (\%)" + ylabel=f"{name} (\\%)" if args.usetex else f"{name} (%)", # noqa: W605 xticks=["base"] + [f"L{i}" for i in range(1, 1 + len(new_locales))], diff --git a/benchmarks/CL_MASR/common_voice_prepare.py b/benchmarks/CL_MASR/common_voice_prepare.py index e6f6fdb10..f882da633 100644 --- 
a/benchmarks/CL_MASR/common_voice_prepare.py +++ b/benchmarks/CL_MASR/common_voice_prepare.py @@ -111,7 +111,7 @@ def prepare_common_voice( _LOGGER.info( "----------------------------------------------------------------------", ) - _LOGGER.info(f"Merging TSV files...") + _LOGGER.info("Merging TSV files...") for split, max_duration in zip(_SPLITS, max_durations): tsv_files = [ os.path.join(data_folder, locale, f"{split}_with_duration.tsv") @@ -126,7 +126,7 @@ def prepare_common_voice( _LOGGER.info( "----------------------------------------------------------------------", ) - _LOGGER.info(f"Creating data manifest CSV files...") + _LOGGER.info("Creating data manifest CSV files...") for split in _SPLITS: preprocess_tsv_file( os.path.join(data_folder, f"{split}_with_duration.tsv"), diff --git a/benchmarks/CL_MASR/wavlm/pretrain.py b/benchmarks/CL_MASR/wavlm/pretrain.py index b65be5573..868f9836c 100644 --- a/benchmarks/CL_MASR/wavlm/pretrain.py +++ b/benchmarks/CL_MASR/wavlm/pretrain.py @@ -341,7 +341,7 @@ def train(hparams, run_opts): # Testing test( - hparams, run_opts, hparams["locales"], f"wer_test.txt", + hparams, run_opts, hparams["locales"], "wer_test.txt", ) diff --git a/benchmarks/CL_MASR/wavlm/train_agem.py b/benchmarks/CL_MASR/wavlm/train_agem.py index 9d94875c7..73da11756 100644 --- a/benchmarks/CL_MASR/wavlm/train_agem.py +++ b/benchmarks/CL_MASR/wavlm/train_agem.py @@ -431,7 +431,7 @@ def train(hparams, run_opts): # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/wavlm/train_der.py b/benchmarks/CL_MASR/wavlm/train_der.py index 987e94b8a..47f0219eb 100644 --- a/benchmarks/CL_MASR/wavlm/train_der.py +++ b/benchmarks/CL_MASR/wavlm/train_der.py @@ -363,7 +363,7 @@ def train(hparams, run_opts): # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) replay_buffer = [] diff --git a/benchmarks/CL_MASR/wavlm/train_er.py b/benchmarks/CL_MASR/wavlm/train_er.py index 6fe01ae43..ac1462e41 100644 --- a/benchmarks/CL_MASR/wavlm/train_er.py +++ b/benchmarks/CL_MASR/wavlm/train_er.py @@ -306,7 +306,7 @@ def train(hparams, run_opts): # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/wavlm/train_ewc.py b/benchmarks/CL_MASR/wavlm/train_ewc.py index d02ad7c68..dc9df2971 100644 --- a/benchmarks/CL_MASR/wavlm/train_ewc.py +++ b/benchmarks/CL_MASR/wavlm/train_ewc.py @@ -417,7 +417,7 @@ def train(hparams, run_opts): # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/wavlm/train_ft.py b/benchmarks/CL_MASR/wavlm/train_ft.py index 3f8f7aaf4..325053267 100644 --- a/benchmarks/CL_MASR/wavlm/train_ft.py +++ b/benchmarks/CL_MASR/wavlm/train_ft.py @@ -305,7 +305,7 @@ def train(hparams, run_opts): # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/wavlm/train_joint.py b/benchmarks/CL_MASR/wavlm/train_joint.py index 52d4b9c94..471e4b761 100644 --- 
a/benchmarks/CL_MASR/wavlm/train_joint.py +++ b/benchmarks/CL_MASR/wavlm/train_joint.py @@ -305,7 +305,7 @@ def train(hparams, run_opts): # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales @@ -358,7 +358,7 @@ def train(hparams, run_opts): hparams, run_opts, hparams["base_locales"] + hparams["new_locales"], - f"wer_test_after.txt", + "wer_test_after.txt", ) diff --git a/benchmarks/CL_MASR/wavlm/train_l2p.py b/benchmarks/CL_MASR/wavlm/train_l2p.py index 49114dcb2..0f36f3425 100644 --- a/benchmarks/CL_MASR/wavlm/train_l2p.py +++ b/benchmarks/CL_MASR/wavlm/train_l2p.py @@ -374,7 +374,7 @@ def train(hparams, run_opts): # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/wavlm/train_lwf.py b/benchmarks/CL_MASR/wavlm/train_lwf.py index fd29e613d..95cae6329 100644 --- a/benchmarks/CL_MASR/wavlm/train_lwf.py +++ b/benchmarks/CL_MASR/wavlm/train_lwf.py @@ -335,7 +335,7 @@ def train(hparams, run_opts): # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/wavlm/train_mas.py b/benchmarks/CL_MASR/wavlm/train_mas.py index a5f97465e..b8f203f00 100644 --- a/benchmarks/CL_MASR/wavlm/train_mas.py +++ b/benchmarks/CL_MASR/wavlm/train_mas.py @@ -421,7 +421,7 @@ def train(hparams, run_opts): # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/wavlm/train_pb.py b/benchmarks/CL_MASR/wavlm/train_pb.py index 052d70b73..a098748ac 100644 --- a/benchmarks/CL_MASR/wavlm/train_pb.py +++ b/benchmarks/CL_MASR/wavlm/train_pb.py @@ -398,7 +398,7 @@ def train(hparams, run_opts): # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/wavlm/train_pnn.py b/benchmarks/CL_MASR/wavlm/train_pnn.py index 3641fa54b..90d89e672 100644 --- a/benchmarks/CL_MASR/wavlm/train_pnn.py +++ b/benchmarks/CL_MASR/wavlm/train_pnn.py @@ -309,7 +309,7 @@ def train(hparams, run_opts): # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/whisper/model.py b/benchmarks/CL_MASR/whisper/model.py index 8c7dc26e6..10f51e414 100644 --- a/benchmarks/CL_MASR/whisper/model.py +++ b/benchmarks/CL_MASR/whisper/model.py @@ -277,7 +277,7 @@ def generate( if forced_decoder_locale is None: # Compute most likely language token IDs all_lang_tokens = [ - f"<|{l}|>" for l in self.tokenizer.supported_languages + f"<|{lang}|>" for lang in self.tokenizer.supported_languages ] all_lang_tokens_ids = self.tokenizer.convert_tokens_to_ids( all_lang_tokens @@ -382,9 +382,7 @@ def _greedy_search( # B* alive_mask_unchanged = gen_token_ids != endoftext_id if not alive_mask_unchanged.all(): - alive_mask[ - alive_mask == True - ] = alive_mask_unchanged # noqa: E712 + alive_mask[alive_mask] = alive_mask_unchanged # noqa: E712 if not alive_mask.any(): break # B* x S x F @@ -566,9 
+564,7 @@ def _beam_search( # B* alive_mask_unchanged = end_idxes < beam_size if not alive_mask_unchanged.all(): - alive_mask[ - alive_mask == True - ] = alive_mask_unchanged # noqa: E712 + alive_mask[alive_mask] = alive_mask_unchanged # noqa: E712 if not alive_mask.any(): break # N x B* x S x F diff --git a/benchmarks/CL_MASR/whisper/train_agem.py b/benchmarks/CL_MASR/whisper/train_agem.py index 70454a829..2c351a59a 100644 --- a/benchmarks/CL_MASR/whisper/train_agem.py +++ b/benchmarks/CL_MASR/whisper/train_agem.py @@ -452,7 +452,7 @@ def train(hparams, run_opts): """ # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/whisper/train_der.py b/benchmarks/CL_MASR/whisper/train_der.py index 86ab58048..42294b916 100644 --- a/benchmarks/CL_MASR/whisper/train_der.py +++ b/benchmarks/CL_MASR/whisper/train_der.py @@ -409,7 +409,7 @@ def train(hparams, run_opts): """ # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) replay_buffer = [] diff --git a/benchmarks/CL_MASR/whisper/train_er.py b/benchmarks/CL_MASR/whisper/train_er.py index 2783c1f85..bf31c3ec7 100644 --- a/benchmarks/CL_MASR/whisper/train_er.py +++ b/benchmarks/CL_MASR/whisper/train_er.py @@ -332,7 +332,7 @@ def train(hparams, run_opts): """ # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/whisper/train_ewc.py b/benchmarks/CL_MASR/whisper/train_ewc.py index 44b1607c1..16bb63c10 100644 --- a/benchmarks/CL_MASR/whisper/train_ewc.py +++ b/benchmarks/CL_MASR/whisper/train_ewc.py @@ -454,7 +454,7 @@ def train(hparams, run_opts): """ # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/whisper/train_ft.py b/benchmarks/CL_MASR/whisper/train_ft.py index cf404404d..13c9ec9d9 100644 --- a/benchmarks/CL_MASR/whisper/train_ft.py +++ b/benchmarks/CL_MASR/whisper/train_ft.py @@ -331,7 +331,7 @@ def train(hparams, run_opts): """ # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/whisper/train_joint.py b/benchmarks/CL_MASR/whisper/train_joint.py index ea0cb2743..3c67c6cf6 100644 --- a/benchmarks/CL_MASR/whisper/train_joint.py +++ b/benchmarks/CL_MASR/whisper/train_joint.py @@ -332,7 +332,7 @@ def train(hparams, run_opts): """ # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales @@ -413,7 +413,7 @@ def train(hparams, run_opts): hparams, run_opts, hparams["base_locales"] + hparams["new_locales"], - f"wer_test_after.txt", + "wer_test_after.txt", ) diff --git a/benchmarks/CL_MASR/whisper/train_l2p.py b/benchmarks/CL_MASR/whisper/train_l2p.py index d2ce451d0..c392e8bd4 100644 --- a/benchmarks/CL_MASR/whisper/train_l2p.py +++ b/benchmarks/CL_MASR/whisper/train_l2p.py @@ -403,7 +403,7 @@ def train(hparams, run_opts): """ # Testing test( - hparams, run_opts, hparams["base_locales"], 
f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/whisper/train_lwf.py b/benchmarks/CL_MASR/whisper/train_lwf.py index d69d4ab3a..0851886f3 100644 --- a/benchmarks/CL_MASR/whisper/train_lwf.py +++ b/benchmarks/CL_MASR/whisper/train_lwf.py @@ -368,7 +368,7 @@ def train(hparams, run_opts): """ # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/whisper/train_mas.py b/benchmarks/CL_MASR/whisper/train_mas.py index 1b0a56dfd..1de188d1a 100644 --- a/benchmarks/CL_MASR/whisper/train_mas.py +++ b/benchmarks/CL_MASR/whisper/train_mas.py @@ -456,7 +456,7 @@ def train(hparams, run_opts): """ # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/whisper/train_pb.py b/benchmarks/CL_MASR/whisper/train_pb.py index b032fb6cd..f5a2945e1 100644 --- a/benchmarks/CL_MASR/whisper/train_pb.py +++ b/benchmarks/CL_MASR/whisper/train_pb.py @@ -422,7 +422,7 @@ def train(hparams, run_opts): # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/CL_MASR/whisper/train_pnn.py b/benchmarks/CL_MASR/whisper/train_pnn.py index c610935e8..290a279fe 100644 --- a/benchmarks/CL_MASR/whisper/train_pnn.py +++ b/benchmarks/CL_MASR/whisper/train_pnn.py @@ -334,7 +334,7 @@ def train(hparams, run_opts): """ # Testing test( - hparams, run_opts, hparams["base_locales"], f"wer_test_before.txt", + hparams, run_opts, hparams["base_locales"], "wer_test_before.txt", ) # Train on new locales diff --git a/benchmarks/DASB/IEMOCAP/iemocap_prepare.py b/benchmarks/DASB/IEMOCAP/iemocap_prepare.py index d42fcff19..0a6c469d0 100644 --- a/benchmarks/DASB/IEMOCAP/iemocap_prepare.py +++ b/benchmarks/DASB/IEMOCAP/iemocap_prepare.py @@ -271,7 +271,7 @@ def load_utterInfo(inputFile): # [START_TIME - END_TIME] TURN_NAME EMOTION [V, A, D] # [V, A, D] means [Valence, Arousal, Dominance] pattern = re.compile( - "[\[]*[0-9]*[.][0-9]*[ -]*[0-9]*[.][0-9]*[\]][\t][a-z0-9_]*[\t][a-z]{3}[\t][\[][0-9]*[.][0-9]*[, ]+[0-9]*[.][0-9]*[, ]+[0-9]*[.][0-9]*[\]]", + "[\[]*[0-9]*[.][0-9]*[ -]*[0-9]*[.][0-9]*[\]][\t][a-z0-9_]*[\t][a-z]{3}[\t][\[][0-9]*[.][0-9]*[, ]+[0-9]*[.][0-9]*[, ]+[0-9]*[.][0-9]*[\]]", # noqa re.IGNORECASE, ) # noqa with open(inputFile, "r") as myfile: diff --git a/benchmarks/DASB/Libri2Mix/separation/conformer/train_continuous_ssl.py b/benchmarks/DASB/Libri2Mix/separation/conformer/train_continuous_ssl.py index a6765f483..280f92c74 100644 --- a/benchmarks/DASB/Libri2Mix/separation/conformer/train_continuous_ssl.py +++ b/benchmarks/DASB/Libri2Mix/separation/conformer/train_continuous_ssl.py @@ -372,7 +372,7 @@ def on_stage_end(self, stage, stage_loss, epoch=None): ) hparams["train_logger"].log_stats( stats_meta={ - f"SSL parameters/buffers (M)": f"{ssl_params / 1e6:.2f}", + "SSL parameters/buffers (M)": f"{ssl_params / 1e6:.2f}", "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", }, ) diff --git a/benchmarks/DASB/Libri2Mix/separation/conformer/train_dac.py b/benchmarks/DASB/Libri2Mix/separation/conformer/train_dac.py index eed798594..0b9eb26eb 100644 --- 
a/benchmarks/DASB/Libri2Mix/separation/conformer/train_dac.py +++ b/benchmarks/DASB/Libri2Mix/separation/conformer/train_dac.py @@ -128,7 +128,7 @@ def toks_to_sig(self, toks): ) hparams["train_logger"].log_stats( stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", + "Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", }, ) diff --git a/benchmarks/DASB/Libri2Mix/separation/conformer/train_discrete_ssl.py b/benchmarks/DASB/Libri2Mix/separation/conformer/train_discrete_ssl.py index a288450c3..0f4e085f5 100644 --- a/benchmarks/DASB/Libri2Mix/separation/conformer/train_discrete_ssl.py +++ b/benchmarks/DASB/Libri2Mix/separation/conformer/train_discrete_ssl.py @@ -155,7 +155,7 @@ def toks_to_sig(self, toks): ) hparams["train_logger"].log_stats( stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", + "Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", }, ) diff --git a/benchmarks/DASB/Libri2Mix/separation/conformer/train_encodec.py b/benchmarks/DASB/Libri2Mix/separation/conformer/train_encodec.py index b4e9f2a1e..8621d8a6d 100644 --- a/benchmarks/DASB/Libri2Mix/separation/conformer/train_encodec.py +++ b/benchmarks/DASB/Libri2Mix/separation/conformer/train_encodec.py @@ -376,7 +376,7 @@ def on_stage_end(self, stage, stage_loss, epoch=None): ) hparams["train_logger"].log_stats( stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", + "Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", }, ) diff --git a/benchmarks/DASB/Libri2Mix/separation/conformer/train_speech_tokenizer.py b/benchmarks/DASB/Libri2Mix/separation/conformer/train_speech_tokenizer.py index 926376b56..7691a247c 100644 --- a/benchmarks/DASB/Libri2Mix/separation/conformer/train_speech_tokenizer.py +++ b/benchmarks/DASB/Libri2Mix/separation/conformer/train_speech_tokenizer.py @@ -121,7 +121,7 @@ def toks_to_sig(self, toks): ) hparams["train_logger"].log_stats( stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", + "Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", }, ) diff --git a/benchmarks/DASB/Libri2Mix/separation/crdnn/train_continuous_ssl.py b/benchmarks/DASB/Libri2Mix/separation/crdnn/train_continuous_ssl.py index 469280af0..7c80c71ec 100644 --- a/benchmarks/DASB/Libri2Mix/separation/crdnn/train_continuous_ssl.py +++ b/benchmarks/DASB/Libri2Mix/separation/crdnn/train_continuous_ssl.py @@ -372,7 +372,7 @@ def on_stage_end(self, stage, stage_loss, epoch=None): ) hparams["train_logger"].log_stats( stats_meta={ - f"SSL parameters/buffers (M)": f"{ssl_params / 1e6:.2f}", + "SSL parameters/buffers (M)": f"{ssl_params / 1e6:.2f}", "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", }, ) diff --git a/benchmarks/DASB/Libri2Mix/separation/crdnn/train_dac.py b/benchmarks/DASB/Libri2Mix/separation/crdnn/train_dac.py index eed798594..0b9eb26eb 100644 --- a/benchmarks/DASB/Libri2Mix/separation/crdnn/train_dac.py +++ b/benchmarks/DASB/Libri2Mix/separation/crdnn/train_dac.py @@ -128,7 +128,7 @@ def toks_to_sig(self, toks): ) hparams["train_logger"].log_stats( stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", + "Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", }, ) diff 
--git a/benchmarks/DASB/Libri2Mix/separation/crdnn/train_discrete_ssl.py b/benchmarks/DASB/Libri2Mix/separation/crdnn/train_discrete_ssl.py index a288450c3..0f4e085f5 100644 --- a/benchmarks/DASB/Libri2Mix/separation/crdnn/train_discrete_ssl.py +++ b/benchmarks/DASB/Libri2Mix/separation/crdnn/train_discrete_ssl.py @@ -155,7 +155,7 @@ def toks_to_sig(self, toks): ) hparams["train_logger"].log_stats( stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", + "Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", }, ) diff --git a/benchmarks/DASB/Libri2Mix/separation/crdnn/train_encodec.py b/benchmarks/DASB/Libri2Mix/separation/crdnn/train_encodec.py index b2f2fed53..fa28c948a 100644 --- a/benchmarks/DASB/Libri2Mix/separation/crdnn/train_encodec.py +++ b/benchmarks/DASB/Libri2Mix/separation/crdnn/train_encodec.py @@ -376,7 +376,7 @@ def on_stage_end(self, stage, stage_loss, epoch=None): ) hparams["train_logger"].log_stats( stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", + "Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", }, ) diff --git a/benchmarks/DASB/Libri2Mix/separation/crdnn/train_speech_tokenizer.py b/benchmarks/DASB/Libri2Mix/separation/crdnn/train_speech_tokenizer.py index 926376b56..7691a247c 100644 --- a/benchmarks/DASB/Libri2Mix/separation/crdnn/train_speech_tokenizer.py +++ b/benchmarks/DASB/Libri2Mix/separation/crdnn/train_speech_tokenizer.py @@ -121,7 +121,7 @@ def toks_to_sig(self, toks): ) hparams["train_logger"].log_stats( stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", + "Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", }, ) diff --git a/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/train.py b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/train.py index 938ce8b96..098986565 100644 --- a/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/train.py +++ b/benchmarks/DASB/LibriSpeech/ASR-on-the-fly/train.py @@ -387,7 +387,7 @@ def text_pipeline(wrd): ) hparams["train_logger"].log_stats( stats_meta={ - f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", + "Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", }, ) diff --git a/benchmarks/DASB/LibriSpeech/extraction/extract.py b/benchmarks/DASB/LibriSpeech/extraction/extract.py index 814d252be..3a649d24f 100644 --- a/benchmarks/DASB/LibriSpeech/extraction/extract.py +++ b/benchmarks/DASB/LibriSpeech/extraction/extract.py @@ -88,7 +88,7 @@ if hparams["save_embedding"]: save_folder = pl.Path(hparams["save_folder"]) - logger.info(f"Saving embeddings ...") + logger.info("Saving embeddings ...") tokens_extractor.save_pretrained_embeddings( (save_folder / "embeddings").as_posix(), vocab_size=hparams["vocab_size"], diff --git a/benchmarks/DASB/VoiceBank/enhancement/train.py b/benchmarks/DASB/VoiceBank/enhancement/train.py index 2440461fa..ac4ea9b9c 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/train.py +++ b/benchmarks/DASB/VoiceBank/enhancement/train.py @@ -258,7 +258,7 @@ def on_stage_end(self, stage, stage_loss, epoch=None): ) hparams["train_logger"].log_stats( stats_meta={ - # f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", + # "Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", }, ) diff --git 
a/benchmarks/DASB/VoiceBank/enhancement/train_continuous_ssl.py b/benchmarks/DASB/VoiceBank/enhancement/train_continuous_ssl.py index 5acd2b666..adcd8cb26 100644 --- a/benchmarks/DASB/VoiceBank/enhancement/train_continuous_ssl.py +++ b/benchmarks/DASB/VoiceBank/enhancement/train_continuous_ssl.py @@ -289,7 +289,7 @@ def on_stage_end(self, stage, stage_loss, epoch=None): ) hparams["train_logger"].log_stats( stats_meta={ - f"SSL parameters/buffers (M)": f"{ssl_params / 1e6:.2f}", + "SSL parameters/buffers (M)": f"{ssl_params / 1e6:.2f}", "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", }, ) diff --git a/benchmarks/DASB/VoiceBank/extraction/extract.py b/benchmarks/DASB/VoiceBank/extraction/extract.py index 346fae442..36e1b4191 100644 --- a/benchmarks/DASB/VoiceBank/extraction/extract.py +++ b/benchmarks/DASB/VoiceBank/extraction/extract.py @@ -85,7 +85,7 @@ if hparams["save_embedding"]: save_folder = pl.Path(hparams["save_folder"]) - logger.info(f"Saving embeddings ...") + logger.info("Saving embeddings ...") tokens_extractor_in.save_pretrained_embeddings( (save_folder / "embeddings" / "input").as_posix(), vocab_size=hparams["vocab_size"], @@ -101,7 +101,7 @@ if hparams["save_embedding"]: save_folder = pl.Path(hparams["save_folder"]) - logger.info(f"Saving embeddings ...") + logger.info("Saving embeddings ...") tokens_extractor_out.save_pretrained_embeddings( (save_folder / "embeddings" / "output").as_posix(), vocab_size=hparams["vocab_size"], diff --git a/benchmarks/MOABB/utils/parse_results.py b/benchmarks/MOABB/utils/parse_results.py index 5c22445b3..3bbd470f8 100644 --- a/benchmarks/MOABB/utils/parse_results.py +++ b/benchmarks/MOABB/utils/parse_results.py @@ -68,7 +68,7 @@ def visualize_results(paradigm: str, results: dict, vis_metrics: list) -> None: """ print("\n----", paradigm.name, "----") for key in results: - if type(results[key]) == dict: + if isinstance(results[key], dict): for m in vis_metrics: print( key, diff --git a/benchmarks/MOABB/utils/prepare.py b/benchmarks/MOABB/utils/prepare.py index 9f6371be4..4aea975c4 100644 --- a/benchmarks/MOABB/utils/prepare.py +++ b/benchmarks/MOABB/utils/prepare.py @@ -83,11 +83,11 @@ def get_output_dict( ) if verbose == 1: - for l in np.unique(labels): + for lang in np.unique(labels): print( print( "Number of label {0} examples: {1}".format( - l, np.where(labels == l)[0].shape[0] + lang, np.where(labels == lang)[0].shape[0] ) ) ) diff --git a/benchmarks/MP3S/IEMOCAP/ecapa_tdnn/iemocap_prepare.py b/benchmarks/MP3S/IEMOCAP/ecapa_tdnn/iemocap_prepare.py index 48feed81f..4236092bb 100644 --- a/benchmarks/MP3S/IEMOCAP/ecapa_tdnn/iemocap_prepare.py +++ b/benchmarks/MP3S/IEMOCAP/ecapa_tdnn/iemocap_prepare.py @@ -273,7 +273,7 @@ def load_utterInfo(inputFile): # [START_TIME - END_TIME] TURN_NAME EMOTION [V, A, D] # [V, A, D] means [Valence, Arousal, Dominance] pattern = re.compile( - "[\[]*[0-9]*[.][0-9]*[ -]*[0-9]*[.][0-9]*[\]][\t][a-z0-9_]*[\t][a-z]{3}[\t][\[][0-9]*[.][0-9]*[, ]+[0-9]*[.][0-9]*[, ]+[0-9]*[.][0-9]*[\]]", + "[\[]*[0-9]*[.][0-9]*[ -]*[0-9]*[.][0-9]*[\]][\t][a-z0-9_]*[\t][a-z]{3}[\t][\[][0-9]*[.][0-9]*[, ]+[0-9]*[.][0-9]*[, ]+[0-9]*[.][0-9]*[\]]", # noqa re.IGNORECASE, ) # noqa with open(inputFile, "r") as myfile: diff --git a/tests/utils/recipe_tests.py b/tests/utils/recipe_tests.py index 747ac6f1d..f61bb45a3 100644 --- a/tests/utils/recipe_tests.py +++ b/tests/utils/recipe_tests.py @@ -41,11 +41,11 @@ def check_row_for_test(row, filters_fields, filters, test_field): test = True for i, field in enumerate(filters_fields): 
         field_values = filters[i]
-        if type(field_values) == str:
+        if isinstance(field_values, str):
             # ... AND ... filter
             if not (field_values == row[field]):
                 test = False
-        elif type(field_values) == list:  # type(field) == list
+        elif isinstance(field_values, list):  # type(field) == list
             # ... AND (... OR ...) ... filter; at least one entry of the list matches
             test_flag = False
             for filt in field_values:

From 36c0ed4c3767ec9ff544ce7f95af80a6dba6a39f Mon Sep 17 00:00:00 2001
From: Luca Della Libera
Date: Wed, 23 Jul 2025 15:39:15 -0400
Subject: [PATCH 5/5] pre-commit

---
 benchmarks/CL_MASR/whisper/model.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/benchmarks/CL_MASR/whisper/model.py b/benchmarks/CL_MASR/whisper/model.py
index 1eaf6c508..39554a163 100644
--- a/benchmarks/CL_MASR/whisper/model.py
+++ b/benchmarks/CL_MASR/whisper/model.py
@@ -383,8 +383,8 @@ def _greedy_search(
             alive_mask_unchanged = gen_token_ids != endoftext_id
             if not alive_mask_unchanged.all():
                 alive_mask[
-                    alive_mask == True
-                ] = alive_mask_unchanged  # noqa: E712
+                    alive_mask == True  # noqa: E712
+                ] = alive_mask_unchanged
                 if not alive_mask.any():
                     break
             # B* x S x F
@@ -567,8 +567,8 @@ def _beam_search(
             alive_mask_unchanged = end_idxes < beam_size
             if not alive_mask_unchanged.all():
                 alive_mask[
-                    alive_mask == True
-                ] = alive_mask_unchanged  # noqa: E712
+                    alive_mask == True  # noqa: E712
+                ] = alive_mask_unchanged
                 if not alive_mask.any():
                     break
             # N x B* x S x F