Commit dc250ab

add auto_rerank part (#393)
* add auto_rerank part
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* swin to UTF-8
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 04b6c10 commit dc250ab

File tree: 3 files changed, +241 -8 lines changed


tools/api.py (+35)

@@ -32,6 +32,7 @@
 
 # from fish_speech.models.vqgan.lit_module import VQGAN
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
 from tools.llama.generate import (
     GenerateRequest,
     GenerateResponse,
@@ -293,6 +294,39 @@ def inference(req: InvokeRequest):
         yield fake_audios
 
 
+def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
+    if not use_auto_rerank:
+        # If auto_rerank is disabled, call the original inference function directly
+        return inference(req)
+
+    zh_model, en_model = load_model()
+    max_attempts = 5
+    best_wer = float("inf")
+    best_audio = None
+
+    for attempt in range(max_attempts):
+        # Call the original inference function
+        audio_generator = inference(req)
+        fake_audios = next(audio_generator)
+
+        asr_result = batch_asr(
+            zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
+        )[0]
+        wer = calculate_wer(req.text, asr_result["text"])
+
+        if wer <= 0.1 and not asr_result["huge_gap"]:
+            return fake_audios
+
+        if wer < best_wer:
+            best_wer = wer
+            best_audio = fake_audios
+
+        if attempt == max_attempts - 1:
+            break
+
+    return best_audio
+
+
 async def inference_async(req: InvokeRequest):
     for chunk in inference(req):
         yield chunk
@@ -377,6 +411,7 @@ def parse_args():
     parser.add_argument("--max-text-length", type=int, default=0)
     parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
     parser.add_argument("--workers", type=int, default=1)
+    parser.add_argument("--use-auto-rerank", type=bool, default=True)
 
     return parser.parse_args()

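One caveat on the new flag: argparse applies the type callable to the raw command-line string, and bool("False") is True, so --use-auto-rerank False still enables reranking (any non-empty string is truthy). A minimal sketch of a string-to-bool converter that would make the flag genuinely toggleable; str2bool is illustrative, not part of this commit:

import argparse


def str2bool(value: str) -> bool:
    # Map common textual spellings onto a real boolean
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--use-auto-rerank", type=str2bool, default=True)
print(parser.parse_args(["--use-auto-rerank", "false"]).use_auto_rerank)  # False
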
tools/auto_rerank.py (new file, +126)

@@ -0,0 +1,126 @@
+import time
+from threading import Lock
+
+import numpy as np
+import torch
+import torchaudio
+from funasr import AutoModel
+from funasr.models.seaco_paraformer.model import SeacoParaformer
+
+# Monkey patching to disable hotwords
+SeacoParaformer.generate_hotwords_list = lambda self, *args, **kwargs: None
+
+
+def load_model(*, device="cuda"):
+    zh_model = AutoModel(
+        model="paraformer-zh",
+        device=device,
+        disable_pbar=True,
+    )
+    en_model = AutoModel(
+        model="paraformer-en",
+        device=device,
+        disable_pbar=True,
+    )
+
+    return zh_model, en_model
+
+
+@torch.no_grad()
+def batch_asr_internal(model, audios, sr):
+    resampled_audios = []
+    for audio in audios:
+        # Convert NumPy arrays to PyTorch tensors
+        if isinstance(audio, np.ndarray):
+            audio = torch.from_numpy(audio).float()
+
+        # Make sure the audio is one-dimensional
+        if audio.dim() > 1:
+            audio = audio.squeeze()
+
+        audio = torchaudio.functional.resample(audio, sr, 16000)
+        assert audio.dim() == 1
+        resampled_audios.append(audio)
+
+    res = model.generate(input=resampled_audios, batch_size=len(resampled_audios))
+
+    results = []
+    for r, audio in zip(res, audios):
+        text = r["text"]
+        duration = len(audio) / sr * 1000
+        huge_gap = False
+
+        if "timestamp" in r and len(r["timestamp"]) > 2:
+            for timestamp_a, timestamp_b in zip(
+                r["timestamp"][:-1], r["timestamp"][1:]
+            ):
+                # If there is a gap of more than 5 seconds, we consider it as a huge gap
+                if timestamp_b[0] - timestamp_a[1] > 5000:
+                    huge_gap = True
+                    break
+
+            # Doesn't make sense to have a huge gap at the end
+            if duration - r["timestamp"][-1][1] > 3000:
+                huge_gap = True
+
+        results.append(
+            {
+                "text": text,
+                "duration": duration,
+                "huge_gap": huge_gap,
+            }
+        )
+
+    return results
+
+
+global_lock = Lock()
+
+
+def batch_asr(model, audios, sr):
+    return batch_asr_internal(model, audios, sr)
+
+
+def is_chinese(text):
+    return True
+
+
+def calculate_wer(text1, text2):
+    words1 = text1.split()
+    words2 = text2.split()
+
+    # Compute the edit distance
+    m, n = len(words1), len(words2)
+    dp = [[0] * (n + 1) for _ in range(m + 1)]
+
+    for i in range(m + 1):
+        dp[i][0] = i
+    for j in range(n + 1):
+        dp[0][j] = j
+
+    for i in range(1, m + 1):
+        for j in range(1, n + 1):
+            if words1[i - 1] == words2[j - 1]:
+                dp[i][j] = dp[i - 1][j - 1]
+            else:
+                dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
+
+    # Compute the WER
+    edits = dp[m][n]
+    wer = edits / len(words1)
+
+    return wer
+
+
+if __name__ == "__main__":
+    zh_model, en_model = load_model()
+    audios = [
+        torchaudio.load("lengyue.wav")[0][0],
+        torchaudio.load("lengyue.wav")[0][0, : 44100 * 5],
+    ]
+    print(batch_asr(zh_model, audios, 44100))
+
+    start_time = time.time()
+    for _ in range(10):
+        batch_asr(zh_model, audios, 44100)
+    print("Time taken:", time.time() - start_time)

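A few details of this new helper module are easy to miss: is_chinese is a stub that always returns True (so en_model is never selected yet), global_lock is defined but not used by batch_asr, and calculate_wer splits on whitespace, so unsegmented Chinese text collapses into a single token. calculate_wer itself is a standard word-level edit distance (insertions, deletions, substitutions) normalized by the reference length. A quick sanity check on toy inputs; the strings are illustrative, not from the commit:

from tools.auto_rerank import calculate_wer

print(calculate_wer("the cat sat", "the cat sat"))      # 0.0   (identical)
print(calculate_wer("the cat sat", "the cat sit"))      # 0.333 (one substitution)
print(calculate_wer("the cat sat", "the big cat sat"))  # 0.333 (one insertion)
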
tools/webui.py (+80 -8)

@@ -21,6 +21,7 @@
 from fish_speech.i18n import i18n
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 from tools.api import decode_vq_tokens, encode_reference
+from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
 from tools.llama.generate import (
     GenerateRequest,
     GenerateResponse,
@@ -162,7 +163,81 @@ def inference(
     gc.collect()
 
 
-inference_stream = partial(inference, streaming=True)
+def inference_with_auto_rerank(
+    text,
+    enable_reference_audio,
+    reference_audio,
+    reference_text,
+    max_new_tokens,
+    chunk_length,
+    top_p,
+    repetition_penalty,
+    temperature,
+    streaming=False,
+    use_auto_rerank=True,
+):
+    if not use_auto_rerank:
+        return inference(
+            text,
+            enable_reference_audio,
+            reference_audio,
+            reference_text,
+            max_new_tokens,
+            chunk_length,
+            top_p,
+            repetition_penalty,
+            temperature,
+            streaming,
+        )
+
+    zh_model, en_model = load_model()
+    max_attempts = 2
+    best_wer = float("inf")
+    best_audio = None
+    best_sample_rate = None
+
+    for attempt in range(max_attempts):
+        audio_generator = inference(
+            text,
+            enable_reference_audio,
+            reference_audio,
+            reference_text,
+            max_new_tokens,
+            chunk_length,
+            top_p,
+            repetition_penalty,
+            temperature,
+            streaming=False,
+        )
+
+        # Drain the generator; the last yielded item holds the audio data
+        for _ in audio_generator:
+            pass
+        _, (sample_rate, audio), message = _
+
+        if audio is None:
+            return None, None, message
+
+        asr_result = batch_asr(
+            zh_model if is_chinese(text) else en_model, [audio], sample_rate
+        )[0]
+        wer = calculate_wer(text, asr_result["text"])
+
+        if wer <= 0.3 and not asr_result["huge_gap"]:
+            return None, (sample_rate, audio), None
+
+        if wer < best_wer:
+            best_wer = wer
+            best_audio = audio
+            best_sample_rate = sample_rate
+
+        if attempt == max_attempts - 1:
+            break
+
+    return None, (best_sample_rate, best_audio), None
+
+
+inference_stream = partial(inference_with_auto_rerank, streaming=True)
 
 n_audios = 4
 
@@ -186,7 +261,7 @@ def inference_wrapper(
     errors = []
 
     for _ in range(batch_infer_num):
-        items = inference(
+        result = inference_with_auto_rerank(
            text,
            enable_reference_audio,
            reference_audio,
@@ -198,16 +273,13 @@
            temperature,
        )
 
-        try:
-            item = next(items)
-        except StopIteration:
-            print("No more audio data available.")
+        _, audio_data, error_message = result
 
        audios.append(
-            gr.Audio(value=item[1] if (item and item[1]) else None, visible=True),
+            gr.Audio(value=audio_data if audio_data else None, visible=True),
        )
        errors.append(
-            gr.HTML(value=item[2] if (item and item[2]) else None, visible=True),
+            gr.HTML(value=error_message if error_message else None, visible=True),
        )
 
     for _ in range(batch_infer_num, n_audios):

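Both auto_rerank_inference in tools/api.py and inference_with_auto_rerank here implement the same retry-and-rerank loop: synthesize, transcribe with the FunASR model, score the transcript against the request text with WER, accept immediately if the error rate is under the threshold and there is no huge gap, and otherwise fall back to the lowest-WER attempt. A distilled sketch of that pattern; generate_fn and asr_fn are stand-ins for the real synthesis and batch_asr calls, not names from this commit:

from tools.auto_rerank import calculate_wer


def rerank_best(generate_fn, asr_fn, ref_text, max_attempts=5, wer_threshold=0.1):
    # Retry synthesis until the ASR transcript is close enough to ref_text
    best_wer, best_audio = float("inf"), None
    for _ in range(max_attempts):
        audio = generate_fn()
        result = asr_fn(audio)  # expected shape: {"text": ..., "huge_gap": bool}
        wer = calculate_wer(ref_text, result["text"])
        if wer <= wer_threshold and not result["huge_gap"]:
            return audio  # good enough: accept immediately
        if wer < best_wer:
            best_wer, best_audio = wer, audio  # remember the closest attempt
    return best_audio  # no attempt passed; fall back to the best one

One subtlety in the webui variant: the drain loop (for _ in audio_generator: pass, followed by unpacking _) works only because Python leaves the loop variable bound to the last yielded item; if the generator yields nothing, _ is undefined.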