
Commit eaefc57

Authored by Stardust-minus, YYuX-1145, pre-commit-ci[bot], jiangyuxiaoxiao, and OedoSoldier
Use CLAP to achieve prompt-controlled generation (#223)

* Quickly classify audio and save the yml-format results to the training root directory (#190)
* Add files via upload
* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
* Update models.py
* Update webui.py
* Update infer.py
* Create compress_model.py
* Resubmit: update the Gradio inference UI (#193)
* Update webui.py
* Update webui.py
* Update train_ms.py
* Update models.py
* Update models.py
* Update models.py
* Update train_ms.py
* Update train_ms.py
* Update models.py
* Update preprocess_text.py
* Update config.json
* Update train_ms.py
* Update webui.py (#206)
* Add files via upload (#209)
* Update train_ms.py
* Update train_ms.py
* Update preprocess_text.py
* Update train_ms.py
* fix (#211)
* Update emotion_clustering.py
* Add files via upload
* Update emotion_clustering.py
* Add cluster center save
* Add files via upload
* Update config.py
* Update default_config.yml
* Update config.py
* Update config.py
* Update emotion_clustering.py
* Update emotion_clustering.py
* Update config.py
* Update emotion_clustering.py
* Update emotion_clustering.py
* Update webui.py
* Update emotion_clustering.py
* Update commons.py
* Update emotion_clustering.py
* Update webui.py
* Update webui.py
* Add files via upload
* Update train_ms.py
* Update train_ms.py
* Update train_ms.py
* Update train_ms.py
* Update train_ms.py
* Update webui.py
* Update emotion_clustering.py
* Update emotion_clustering.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
* Fix default_config.yml
* Update infer.py
* feat: support infer 2.1 models
* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
* fix: support infer 2.1 models (compatibility bug fixes)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
* Update train_ms.py
* Add CLAP
* Fix data loader
* Fix infer.py
* Fix webui.py
* Add prompt template
* Update clap_gen.py
* Fix wrong environ value
* Add g for dur disc
* Update clap_gen.py
* Fix multilang generation
* Update config.json
* Prompt mode
* Improve slice segments performance
* Add preprocess webui
* Update webui_preprocess.py
* Update webui_preprocess.py
* Update config.py
* Update default_config.yml
* Update config.py
* Update clap_gen.py
* Delete emo_gen.py
* Delete get_emo.py
* Delete emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim directory
* Update README.md
* Update README
* Split val per lang
* Delete emotion_clustering.py
* Update default_config.yml
* Update default_config.yml
* Update config.py
* Update preprocess_text.py
* Update webui_preprocess.py
* Update default_config.yml
* Update webui_preprocess.py
* Update preprocess_text.py
* Random augmentation for CLAP
* Update data_utils.py
* Update preprocess_text.py
* Add vq for CLAP features to avoid overfitting
* Random dummy inputs
* Update webui.py
* Update models.py
* Update infer.py
* Apply Code Formatter Change
* Update config.json
* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)

Co-authored-by: YYuX-1145 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sora <[email protected]>
Co-authored-by: Sihan Wang <[email protected]>
Co-authored-by: Stardust-minus <[email protected]>
1 parent: 9cc786d

49 files changed: +287,404 -3,426 lines
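
Two items in the change list above carry most of the technical weight: precomputing a CLAP embedding per utterance (clap_gen.py, shown below) and "Add vq for CLAP features to avoid overfitting". The quantization itself lives in models.py, which this commit view does not show, so the following is only a hedged sketch of what a VQ bottleneck over CLAP embeddings typically looks like; the class name ClapVQ, the codebook size, and the dimension are illustrative assumptions, not the project's actual code.

    import torch
    import torch.nn as nn


    class ClapVQ(nn.Module):
        """Illustrative VQ bottleneck (assumption; not the code from models.py).

        Snapping the continuous CLAP embedding to a small learned codebook
        limits how much per-utterance detail the model can memorize.
        """

        def __init__(self, dim=512, codebook_size=10):
            super().__init__()
            self.codebook = nn.Embedding(codebook_size, dim)

        def forward(self, emb):  # emb: (batch, dim)
            # Pick the nearest codebook entry by Euclidean distance.
            dists = torch.cdist(emb, self.codebook.weight)
            quantized = self.codebook(dists.argmin(dim=-1))
            # Straight-through estimator: the forward pass uses the quantized
            # value; gradients flow back to the continuous embedding.
            return emb + (quantized - emb).detach()


    vq = ClapVQ()
    print(vq(torch.randn(4, 512)).shape)  # torch.Size([4, 512])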

README.md (+5)

@@ -5,6 +5,11 @@
 # Bert-VITS2

 VITS2 Backbone with multilingual bert
+
+For quick guide, please refer to `webui_preprocess.py`.
+
+简易教程请参见 `webui_preprocess.py`
+
 ## 请注意，本项目核心思路来源于[anyvoiceai/MassTTS](https://github.com/anyvoiceai/MassTTS) 一个非常好的tts项目
 ## MassTTS的演示demo为[ai版峰哥锐评峰哥本人,并找回了在金三角失落的腰子](https://www.bilibili.com/video/BV1w24y1c7z9)

clap_gen.py (new file, +64)

@@ -0,0 +1,64 @@
+import argparse
+from multiprocessing import Pool, cpu_count
+
+import torch
+import torch.multiprocessing as mp
+from tqdm import tqdm
+
+import utils
+from config import config
+from clap_wrapper import get_clap_audio_feature
+import librosa
+import os
+
+os.environ["OMP_NUM_THREADS"] = "1"
+os.environ["MKL_NUM_THREADS"] = "1"
+
+
+def process_line(line):
+    device = config.emo_gen_config.device
+    if config.emo_gen_config.use_multi_device:
+        rank = mp.current_process()._identity
+        rank = rank[0] if len(rank) > 0 else 0
+        if torch.cuda.is_available():
+            gpu_id = rank % torch.cuda.device_count()
+            device = torch.device(f"cuda:{gpu_id}")
+        else:
+            device = torch.device("cpu")
+    wav_path, _, language_str, text, phones, tone, word2ph = line.strip().split("|")
+
+    clap_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".emo.npy")
+    if os.path.isfile(clap_path):
+        return
+
+    audio = librosa.load(wav_path, sr=48000)[0]
+    # audio = librosa.resample(audio, 44100, 48000)
+
+    clap = get_clap_audio_feature(audio, device)
+    torch.save(clap, clap_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-c", "--config", type=str, default=config.emo_gen_config.config_path
+    )
+    parser.add_argument(
+        "--num_processes", type=int, default=config.emo_gen_config.num_processes
+    )
+    args, _ = parser.parse_known_args()
+    config_path = args.config
+    hps = utils.get_hparams_from_file(config_path)
+    lines = []
+    with open(hps.data.training_files, encoding="utf-8") as f:
+        lines.extend(f.readlines())
+
+    with open(hps.data.validation_files, encoding="utf-8") as f:
+        lines.extend(f.readlines())
+    if len(lines) != 0:
+        num_processes = min(args.num_processes, cpu_count())
+        with Pool(processes=num_processes) as pool:
+            for _ in tqdm(pool.imap_unordered(process_line, lines), total=len(lines)):
+                pass
+
+    print(f"CLAP feature generation finished! Processed {len(lines)} lines (.emo.npy files).")
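
The script above walks the training and validation filelists and caches one CLAP audio embedding per wav, written next to the audio with torch.save despite the .npy extension. A minimal sketch of reading such a cache entry back (the path is a hypothetical example; the real consumer is the data loader in data_utils.py):

    import torch

    # Hypothetical path: clap_gen.py writes "<wav stem>.emo.npy" beside each wav.
    emo_path = "dataset/speaker/utt_0001.emo.npy"

    # The file is written with torch.save, so despite the .npy suffix it must
    # be read with torch.load rather than numpy.load.
    clap_feature = torch.load(emo_path, map_location="cpu")
    print(clap_feature.shape)  # transposed embedding, e.g. (512, 1) for clap-htsat-fused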

clap_wrapper.py (new file, +49)

@@ -0,0 +1,49 @@
+import sys
+
+import torch
+from transformers import ClapModel, ClapProcessor
+
+from config import config
+
+models = dict()
+processor = ClapProcessor.from_pretrained("./emotional/clap-htsat-fused")
+
+
+def get_clap_audio_feature(audio_data, device=config.bert_gen_config.device):
+    if (
+        sys.platform == "darwin"
+        and torch.backends.mps.is_available()
+        and device == "cpu"
+    ):
+        device = "mps"
+    if not device:
+        device = "cuda"
+    if device not in models.keys():
+        models[device] = ClapModel.from_pretrained("./emotional/clap-htsat-fused").to(
+            device
+        )
+    with torch.no_grad():
+        inputs = processor(
+            audios=audio_data, return_tensors="pt", sampling_rate=48000
+        ).to(device)
+        emb = models[device].get_audio_features(**inputs)
+    return emb.T
+
+
+def get_clap_text_feature(text, device=config.bert_gen_config.device):
+    if (
+        sys.platform == "darwin"
+        and torch.backends.mps.is_available()
+        and device == "cpu"
+    ):
+        device = "mps"
+    if not device:
+        device = "cuda"
+    if device not in models.keys():
+        models[device] = ClapModel.from_pretrained("./emotional/clap-htsat-fused").to(
+            device
+        )
+    with torch.no_grad():
+        inputs = processor(text=text, return_tensors="pt").to(device)
+        emb = models[device].get_text_features(**inputs)
+    return emb.T
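
The two functions are deliberately symmetric: CLAP maps audio and text into a shared embedding space, which is what makes prompt control possible. The model trains on audio embeddings, and a text-prompt embedding can be substituted at inference. A rough illustration under that reading (the actual wiring lives in infer.py and webui.py, which are not shown here; the file name and prompt string are made up):

    import librosa

    from clap_wrapper import get_clap_audio_feature, get_clap_text_feature

    # Training-style conditioning: embed the reference audio itself.
    audio, _ = librosa.load("reference.wav", sr=48000)  # hypothetical file
    audio_emb = get_clap_audio_feature(audio, "cuda")

    # Prompt mode: a text description in the same embedding space can stand
    # in for the audio embedding at inference time.
    text_emb = get_clap_text_feature("Happy", "cuda")
    assert audio_emb.shape == text_emb.shape  # both (dim, 1) after the final .T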

commons.py (+6 -14)

@@ -46,26 +46,18 @@ def rand_gumbel_like(x):
 
 
 def slice_segments(x, ids_str, segment_size=4):
-    ret = torch.zeros_like(x[:, :, :segment_size])
-    for i in range(x.size(0)):
-        idx_str = ids_str[i]
-        idx_end = idx_str + segment_size
-        if idx_str < 0:
-            i1 = x.size(2) + idx_str
-            r1 = x[i, :, i1:]
-            r2 = x[i, :, :idx_end]
-            ret[i] = torch.cat([r1, r2], dim=1)
-        else:
-            ret[i] = x[i, :, idx_str:idx_end]
-    return ret
+    gather_indices = ids_str.view(x.size(0), 1, 1).repeat(
+        1, x.size(1), 1
+    ) + torch.arange(segment_size, device=x.device)
+    return torch.gather(x, 2, gather_indices)
 
 
 def rand_slice_segments(x, x_lengths=None, segment_size=4):
     b, d, t = x.size()
     if x_lengths is None:
         x_lengths = t
-    ids_str_max = x_lengths - segment_size + 1
-    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+    ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
+    ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
     ret = slice_segments(x, ids_str, segment_size)
     return ret, ids_str
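
This rewrite is the "Improve slice segments performance" item from the commit message: the per-sample Python loop becomes one batched torch.gather, the old negative-start wraparound branch is dropped, and the new clamp in rand_slice_segments keeps start indices non-negative so gather never sees an out-of-range index. A small self-check, assuming non-negative ids_str, that the two formulations agree:

    import torch


    def slice_segments_loop(x, ids_str, segment_size=4):
        # Old behavior, restricted to the non-negative-start branch.
        ret = torch.zeros_like(x[:, :, :segment_size])
        for i in range(x.size(0)):
            ret[i] = x[i, :, ids_str[i] : ids_str[i] + segment_size]
        return ret


    def slice_segments_gather(x, ids_str, segment_size=4):
        # New behavior: build (batch, channels, segment_size) indices, gather once.
        gather_indices = ids_str.view(x.size(0), 1, 1).repeat(
            1, x.size(1), 1
        ) + torch.arange(segment_size, device=x.device)
        return torch.gather(x, 2, gather_indices)


    x = torch.randn(8, 192, 100)
    ids_str = torch.randint(0, 100 - 4 + 1, (8,))
    assert torch.equal(slice_segments_loop(x, ids_str), slice_segments_gather(x, ids_str))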

compress_model.py (+1)

@@ -1,6 +1,7 @@
 from collections import OrderedDict
 from text.symbols import symbols
 import torch
+
 from tools.log import logger
 import utils
 from models import SynthesizerTrn
