
Commit e37a445

Fix backend (#627)
* Linux pyaudio dependencies
* revert generate.py
* Better bug report & feat request
* Auto-select torchaudio backend
* safety
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* feat: manual seed for restore
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Gradio > 5
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ecaa69e commit e37a445

16 files changed: +185 -109 lines changed

docs/en/inference.md

+1 -1

@@ -122,7 +122,7 @@ python -m tools.webui \
 ```
 
 !!! note
-    You can save the label file and reference audio file in advance to the examples folder in the main directory (which you need to create yourself), so that you can directly call them in the WebUI.
+    You can save the label file and reference audio file in advance to the `references` folder in the main directory (which you need to create yourself), so that you can directly call them in the WebUI.
 
 !!! note
     You can use Gradio environment variables, such as `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` to configure WebUI.

docs/ja/inference.md

+1 -1

@@ -152,7 +152,7 @@ python -m tools.webui \
 ```
 
 !!! note
-    You can save the label file and reference audio file in advance to the examples folder in the main directory (which you need to create yourself), so that you can call them directly in the WebUI.
+    You can save the label file and reference audio file in advance to the `references` folder in the main directory (which you need to create yourself), so that you can call them directly in the WebUI.
 
 !!! note
     You can use Gradio environment variables such as `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` to configure the WebUI.

docs/pt/inference.md

+1 -1

@@ -148,7 +148,7 @@ python -m tools.webui \
 ```
 
 !!! note
-    You can save the label file and the reference audio file in advance in the examples folder of the main directory (which you need to create yourself), so that you can call them directly in the WebUI.
+    You can save the label file and the reference audio file in advance in the `references` folder of the main directory (which you need to create yourself), so that you can call them directly in the WebUI.
 
 !!! note
     You can use Gradio environment variables, such as `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME`, to configure the WebUI.

docs/zh/inference.md

+1 -1

@@ -132,7 +132,7 @@ python -m tools.webui \
 ```
 
 !!! note
-    You can save the label file and reference audio file in advance to the examples folder in the main directory (which you need to create yourself), so that you can call them directly in the WebUI.
+    You can save the label file and reference audio file in advance to the `references` folder in the main directory (which you need to create yourself), so that you can call them directly in the WebUI.
 
 !!! note
     You can use Gradio environment variables such as `GRADIO_SHARE`, `GRADIO_SERVER_PORT`, `GRADIO_SERVER_NAME` to configure the WebUI.

fish_speech/models/text2semantic/llama.py

+4 -1

@@ -369,7 +369,10 @@ def from_pretrained(
            model = simple_quantizer.convert_for_runtime()
 
        weights = torch.load(
-            Path(path) / "model.pth", map_location="cpu", mmap=True
+            Path(path) / "model.pth",
+            map_location="cpu",
+            mmap=True,
+            weights_only=True,
        )
 
        if "state_dict" in weights:

fish_speech/utils/__init__.py

+2 -1

@@ -5,7 +5,7 @@
 from .logger import RankedLogger
 from .logging_utils import log_hyperparameters
 from .rich_utils import enforce_tags, print_config_tree
-from .utils import extras, get_metric_value, task_wrapper
+from .utils import extras, get_metric_value, set_seed, task_wrapper
 
 __all__ = [
     "enforce_tags",
@@ -20,4 +20,5 @@
     "braceexpand",
     "get_latest_checkpoint",
     "autocast_exclude_mps",
+    "set_seed",
 ]

fish_speech/utils/utils.py

+22

@@ -1,7 +1,10 @@
+import random
 import warnings
 from importlib.util import find_spec
 from typing import Callable
 
+import numpy as np
+import torch
 from omegaconf import DictConfig
 
 from .logger import RankedLogger
@@ -112,3 +115,22 @@ def get_metric_value(metric_dict: dict, metric_name: str) -> float:
     log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
 
     return metric_value
+
+
+def set_seed(seed: int):
+    if seed < 0:
+        seed = -seed
+    if seed > (1 << 31):
+        seed = 1 << 31
+
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+
+    if torch.backends.cudnn.is_available():
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
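The new helper seeds Python's `random`, NumPy, and PyTorch (including CUDA when available) and switches cuDNN to deterministic mode, so a given seed should reproduce the same sampling path. A small illustrative check of that behaviour:

```python
import torch

from fish_speech.utils import set_seed

# Same seed, same draws: the helper resets every RNG the project uses.
set_seed(42)
first = torch.rand(3)

set_seed(42)
second = torch.rand(3)

assert torch.equal(first, second)
```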

fish_speech/webui/launch_utils.py

+1 -1

@@ -114,7 +114,7 @@ def __init__(
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
-            button_shadow="*shadow_drop_lg",
+            # button_shadow="*shadow_drop_lg",
            button_small_padding="0px",
            button_large_padding="3px",
        )

fish_speech/webui/manage.py

+1 -1

@@ -794,7 +794,7 @@ def llama_quantify(llama_weight, quantify_mode):
                value="VQGAN",
            )
        with gr.Row():
-            with gr.Tabs():
+            with gr.Column():
                with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
                    gr.HTML("You don't need to train this model!")
 

pyproject.toml

+1 -1

@@ -23,7 +23,7 @@ dependencies = [
     "einops>=0.7.0",
     "librosa>=0.10.1",
     "rich>=13.5.3",
-    "gradio<5.0.0",
+    "gradio>5.0.0",
     "wandb>=0.15.11",
     "grpcio>=1.58.0",
     "kui>=1.6.0",

tools/api.py

+14 -5

@@ -35,7 +35,7 @@
 # from fish_speech.models.vqgan.lit_module import VQGAN
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
-from fish_speech.utils import autocast_exclude_mps
+from fish_speech.utils import autocast_exclude_mps, set_seed
 from tools.commons import ServeTTSRequest
 from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
 from tools.llama.generate import (
@@ -46,6 +46,14 @@
 )
 from tools.vqgan.inference import load_model as load_decoder_model
 
+backends = torchaudio.list_audio_backends()
+if "sox" in backends:
+    backend = "sox"
+elif "ffmpeg" in backends:
+    backend = "ffmpeg"
+else:
+    backend = "soundfile"
+
 
 def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
     buffer = io.BytesIO()
@@ -88,10 +96,7 @@ def load_audio(reference_audio, sr):
        audio_data = reference_audio
        reference_audio = io.BytesIO(audio_data)
 
-    waveform, original_sr = torchaudio.load(
-        reference_audio,
-        backend="soundfile",  # not every linux release supports 'sox' or 'ffmpeg'
-    )
+    waveform, original_sr = torchaudio.load(reference_audio, backend=backend)
 
     if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
@@ -215,6 +220,10 @@ def inference(req: ServeTTSRequest):
    else:
        logger.info("Use same references")
 
+    if req.seed is not None:
+        set_seed(req.seed)
+        logger.warning(f"set seed: {req.seed}")
+
    # LLAMA Inference
    request = dict(
        device=decoder_model.device,
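The backend probe at module import time replaces the previous hard-coded `backend="soundfile"` (and, in `extract_vq.py` below, the `sys.platform` check): whichever of sox, ffmpeg, or soundfile is actually installed gets used, in that order of preference. A standalone sketch to see what your own environment resolves to:

```python
import torchaudio

# Same fallback order as the patch: sox, then ffmpeg, then soundfile.
backends = torchaudio.list_audio_backends()
if "sox" in backends:
    backend = "sox"
elif "ffmpeg" in backends:
    backend = "ffmpeg"
else:
    backend = "soundfile"

print(f"available: {backends}  selected: {backend}")
```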

tools/commons.py

+1

@@ -20,6 +20,7 @@ class ServeTTSRequest(BaseModel):
     # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
     # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
     reference_id: str | None = None
+    seed: int | None = None
     use_memory_cache: Literal["on-demand", "never"] = "never"
     # Normalize text for en & zh, this increase stability for numbers
     normalize: bool = True
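With the new optional field, a request opts into deterministic generation simply by carrying an integer seed; leaving it at the default `None` keeps the old randomized behaviour. A rough sketch of constructing such a request (the `text` field name is an assumption for illustration; only `seed`, `reference_id`, `use_memory_cache`, and `normalize` appear in this diff):

```python
from tools.commons import ServeTTSRequest

# `text` is assumed to be the request's payload field; `seed` is the new knob.
req = ServeTTSRequest(text="Hello, world.", seed=42)

print(req.seed, req.use_memory_cache)  # 42 never
```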

tools/post_api.py

+7

@@ -109,6 +109,12 @@ def parse_args():
        default="never",
        help="Cache encoded references codes in memory",
    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="None means randomized inference, otherwise deterministic",
+    )
 
    return parser.parse_args()
 
@@ -155,6 +161,7 @@ def parse_args():
        "emotion": args.emotion,
        "streaming": args.streaming,
        "use_memory_cache": args.use_memory_cache,
+        "seed": args.seed,
    }
 
    pydantic_data = ServeTTSRequest(**data)
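In practice this means repeated invocations of `tools/post_api.py` with the same `--seed` value should return identical audio, while omitting the flag preserves the previous randomized inference.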

tools/vqgan/extract_vq.py

+8 -1

@@ -24,6 +24,13 @@
 # This file is used to convert the audio files to text files using the Whisper model.
 # It's mainly used to generate the training data for the VQ model.
 
+backends = torchaudio.list_audio_backends()
+if "sox" in backends:
+    backend = "sox"
+elif "ffmpeg" in backends:
+    backend = "ffmpeg"
+else:
+    backend = "soundfile"
 
 RANK = int(os.environ.get("SLURM_PROCID", 0))
 WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
@@ -81,7 +88,7 @@ def process_batch(files: list[Path], model) -> float:
    for file in files:
        try:
            wav, sr = torchaudio.load(
-                str(file), backend="sox" if sys.platform == "linux" else "soundfile"
+                str(file), backend=backend
            )  # Need to install libsox-dev
        except Exception as e:
            logger.error(f"Error reading {file}: {e}")

tools/vqgan/inference.py

+2 -3

@@ -24,8 +24,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):
 
    model = instantiate(cfg)
    state_dict = torch.load(
-        checkpoint_path,
-        map_location=device,
+        checkpoint_path, map_location=device, mmap=True, weights_only=True
    )
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
@@ -37,7 +36,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):
        if "generator." in k
    }
 
-    result = model.load_state_dict(state_dict, strict=False)
+    result = model.load_state_dict(state_dict, strict=False, assign=True)
    model.eval()
    model.to(device)
 
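For reference, `assign=True` makes `load_state_dict` adopt the loaded tensors directly rather than copying them into the module's pre-allocated parameters, which pairs naturally with the memory-mapped, weights-only load above. A tiny illustrative sketch, unrelated to the fish-speech model itself:

```python
import torch
from torch import nn

# With assign=True the module takes ownership of the provided tensors
# instead of copying their values into its existing parameters.
layer = nn.Linear(4, 4)
state = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}
layer.load_state_dict(state, strict=False, assign=True)

print(float(layer.weight.abs().sum()))  # 0.0 -- the zero tensors were adopted
```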
