-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Description
def call(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
"""
WAV: np.ndarray, WAV_LENS: np.ndarray
LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
"""
# Ensure the default CUDA device is set correctly for this invocation
torch.cuda.set_device(self.device_id)
if self.device_id == 0:
print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}")
tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))]
# Decode ground-truth text strings (BYTES → str)
if GT_TEXT.ndim == 2:
gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))]
else:
gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))]
wavs = []
for tokens in tokens_list:
prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset)
audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0)
audio_hat = audio_decode_cosyvoice2(
audio_tokens,
prompt_text,
prompt_speech_16k,
self.codec_decoder,
)
这里在循环中为每个样本单独随机挑选 prompt,会导致同一个 group 里的 n 个样本所使用的 prompt 都不一样吧?