|
35 | 35 | # from fish_speech.models.vqgan.lit_module import VQGAN
|
36 | 36 | from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
37 | 37 | from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
38 |
| -from fish_speech.utils import autocast_exclude_mps |
| 38 | +from fish_speech.utils import autocast_exclude_mps, set_seed |
39 | 39 | from tools.commons import ServeTTSRequest
|
40 | 40 | from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
|
41 | 41 | from tools.llama.generate import (
|
|
46 | 46 | )
|
47 | 47 | from tools.vqgan.inference import load_model as load_decoder_model
|
48 | 48 |
|
# Pick the torchaudio I/O backend for this process.
# Preference order: "sox", then "ffmpeg"; if neither is listed, fall back to
# "soundfile" — presumably because not every platform ships sox/ffmpeg
# support (the fallback mirrors the previous hard-coded "soundfile" choice).
backends = torchaudio.list_audio_backends()
for _candidate in ("sox", "ffmpeg"):
    if _candidate in backends:
        backend = _candidate
        break
else:
    # for/else: runs only when no preferred backend was found
    backend = "soundfile"
49 | 57 |
|
50 | 58 | def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
51 | 59 | buffer = io.BytesIO()
|
@@ -88,10 +96,7 @@ def load_audio(reference_audio, sr):
|
88 | 96 | audio_data = reference_audio
|
89 | 97 | reference_audio = io.BytesIO(audio_data)
|
90 | 98 |
|
91 |
| - waveform, original_sr = torchaudio.load( |
92 |
| - reference_audio, |
93 |
| - backend="soundfile", # not every linux release supports 'sox' or 'ffmpeg' |
94 |
| - ) |
| 99 | + waveform, original_sr = torchaudio.load(reference_audio, backend=backend) |
95 | 100 |
|
96 | 101 | if waveform.shape[0] > 1:
|
97 | 102 | waveform = torch.mean(waveform, dim=0, keepdim=True)
|
@@ -215,6 +220,10 @@ def inference(req: ServeTTSRequest):
|
215 | 220 | else:
|
216 | 221 | logger.info("Use same references")
|
217 | 222 |
|
| 223 | + if req.seed is not None: |
| 224 | + set_seed(req.seed) |
| 225 | + logger.warning(f"set seed: {req.seed}") |
| 226 | + |
218 | 227 | # LLAMA Inference
|
219 | 228 | request = dict(
|
220 | 229 | device=decoder_model.device,
|
|
0 commit comments