2 changes: 1 addition & 1 deletion benchmarks/DASB/LibriSpeech/ASR-on-the-fly/train.py

@@ -387,7 +387,7 @@ def text_pipeline(wrd):
     )
     hparams["train_logger"].log_stats(
         stats_meta={
-            f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}",
+            "Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}",
             "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}",
         },
     )
2 changes: 1 addition & 1 deletion benchmarks/DASB/LibriSpeech/extraction/extract.py

@@ -88,7 +88,7 @@

 if hparams["save_embedding"]:
     save_folder = pl.Path(hparams["save_folder"])
-    logger.info(f"Saving embeddings ...")
+    logger.info("Saving embeddings ...")
     tokens_extractor.save_pretrained_embeddings(
         (save_folder / "embeddings").as_posix(),
         vocab_size=hparams["vocab_size"],
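Both hunks above make the same fix: dropping a redundant `f` prefix from strings that contain no placeholders, which linters such as flake8 flag as F541. A minimal illustration (hypothetical logger setup, not from the diff):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# An f-string with no {} placeholders is just a plain string; the prefix
# is redundant and linters (e.g., flake8 rule F541) flag it.
logger.info(f"Saving embeddings ...")  # flagged: f-string without placeholders
logger.info("Saving embeddings ...")   # equivalent and lint-clean
```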
(file path not shown)

@@ -4,12 +4,9 @@
 * Luca Della Libera 2024
 """

-import os
-
 import speechbrain as sb
 import torch
 import torchaudio
-from speechbrain.dataio.dataio import merge_csvs
 from transformers.models.hubert.modeling_hubert import (
     HubertEncoderStableLayerNorm,
 )
@@ -100,30 +97,6 @@ def dataio_prepare(
     It also defines the data processing pipeline through user-defined functions.

     """
-    if isinstance(train_csv, (list, tuple)):
-        csvs = [os.path.basename(x) for x in train_csv]
-        save_folder = os.path.dirname(train_csv[0])
-        merge_csvs(
-            save_folder, csvs, "train.csv",
-        )
-        train_csv = os.path.join(save_folder, "train.csv")
-
-    if isinstance(valid_csv, (list, tuple)):
-        csvs = [os.path.basename(x) for x in valid_csv]
-        save_folder = os.path.dirname(valid_csv[0])
-        merge_csvs(
-            save_folder, csvs, "valid.csv",
-        )
-        valid_csv = os.path.join(save_folder, "valid.csv")
-
-    if isinstance(test_csv, (list, tuple)):
-        csvs = [os.path.basename(x) for x in test_csv]
-        save_folder = os.path.dirname(test_csv[0])
-        merge_csvs(
-            save_folder, csvs, "test.csv",
-        )
-        test_csv = os.path.join(save_folder, "test.csv")
-
     train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
         csv_path=train_csv, replacements={"DATA_ROOT": data_folder},
     )
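With the merging logic removed, dataio_prepare now expects train_csv, valid_csv, and test_csv to be single paths. A caller that still has a list of CSVs could merge them upstream using the same merge_csvs helper the deleted block called; a minimal sketch with hypothetical paths:

```python
import os

from speechbrain.dataio.dataio import merge_csvs

# Hypothetical CSV list; previously dataio_prepare accepted this directly.
train_csvs = ["data/train_clean.csv", "data/train_other.csv"]

save_folder = os.path.dirname(train_csvs[0])
merge_csvs(
    save_folder,                                # folder containing the CSVs
    [os.path.basename(x) for x in train_csvs],  # CSV names relative to it
    "train.csv",                                # merged output name
)
train_csv = os.path.join(save_folder, "train.csv")  # single path, as now required
```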
@@ -155,10 +128,37 @@
     )

     datasets = [train_data, valid_data, test_data]
+    output_keys = ["id"]

+    if "tokens_loader_in" in hparams and "tokens_loader_out" in hparams:
+        # Define tokens pipeline:
+        takes = ["id"]
+        provides = ["in_toks", "out_toks"]
+        output_keys += provides
+
+        tokens_loader_in = hparams["tokens_loader_in"]
+        tokens_loader_out = hparams["tokens_loader_out"]
+        num_codebooks = hparams["num_codebooks"]
+
+        def toks_pipeline(id):
+            in_toks = tokens_loader_in.tokens_by_uttid(
+                id, num_codebooks=num_codebooks
+            )
+            yield in_toks
+
+            out_toks = tokens_loader_out.tokens_by_uttid(
+                id, num_codebooks=num_codebooks
+            )
+            yield out_toks
+
+        sb.dataio.dataset.add_dynamic_item(
+            datasets, toks_pipeline, takes, provides
+        )
+
     # Define audio pipeline
     takes = ["clean_wav", "noisy_wav"]
     provides = ["in_sig", "out_sig"]
+    output_keys += provides

     def audio_pipeline(clean_wav, noisy_wav):
         # Clean signal
@@ -185,11 +185,11 @@ def audio_pipeline(clean_wav, noisy_wav):
         yield out_sig

     sb.dataio.dataset.add_dynamic_item(
-        [train_data, valid_data, test_data], audio_pipeline, takes, provides
+        datasets, audio_pipeline, takes, provides
     )

     # Set output
-    sb.dataio.dataset.set_output_keys(datasets, ["id"] + provides)
+    sb.dataio.dataset.set_output_keys(datasets, output_keys)

     return train_data, valid_data, test_data
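The added pipeline exposes in_toks/out_toks alongside the audio keys whenever both token loaders are configured. A rough consumption sketch, assuming SpeechBrain's default PaddedBatch collation and a hypothetical call site (the exact dataio_prepare signature is not shown in this diff):

```python
import speechbrain as sb

# Hypothetical call; argument names mirror the variables used in the
# function body above, but the real signature is not shown here.
train_data, valid_data, test_data = dataio_prepare(
    hparams, data_folder, train_csv, valid_csv, test_csv
)

loader = sb.dataio.dataloader.make_dataloader(train_data, batch_size=4)
for batch in loader:
    ids = batch.id                      # list of utterance IDs
    in_sig, in_sig_lens = batch.in_sig  # padded waveforms + relative lengths
    # Token keys are present only when both token loaders are in hparams:
    if hasattr(batch, "in_toks"):
        in_toks, in_toks_lens = batch.in_toks
```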

1 change: 0 additions & 1 deletion benchmarks/DASB/VoiceBank/enhancement/conformer/utils.py
This file was deleted.

Seven additional files were deleted in this changeset (paths not shown in this view).