GRPO prompt issue #1853

@hajararabiu869

    def call(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
        """
        WAV: np.ndarray, WAV_LENS: np.ndarray
        LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
        See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
        """
        # Ensure the default CUDA device is set correctly for this invocation
        torch.cuda.set_device(self.device_id)

        if self.device_id == 0:
            print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}")

        tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))]

        # Decode ground-truth text strings (BYTES → str)
        if GT_TEXT.ndim == 2:
            gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))]
        else:
            gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))]

        wavs = []
        for tokens in tokens_list:
            prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset)
            audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0)
            audio_hat = audio_decode_cosyvoice2(
                audio_tokens,
                prompt_text,
                prompt_speech_16k,
                self.codec_decoder,
            )

Since the prompt is picked at random here for each sample, won't the n samples within one group all end up with different prompts?
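
If the per-sample random draw turns out to be unintended, one possible way to keep the prompt fixed within a group is to derive the choice from something all samples of the group share. The sketch below is only an illustration, not the repository's method. It assumes that the n samples of a group carry the same ground-truth text, and it uses a hypothetical helper `pick_prompt_for_group` and a hypothetical in-memory `prompts` list in place of `get_random_prompt_from_dataset(self.dataset)`.

```python
import hashlib
import random


def pick_prompt_for_group(group_key, prompts, seed=0):
    """Pick one prompt per group key (hypothetical helper).

    The choice is random across groups but identical for every
    sample that shares the same group_key and seed.
    """
    digest = hashlib.sha256(f"{seed}:{group_key}".encode("utf-8")).digest()
    rng = random.Random(int.from_bytes(digest[:8], "big"))
    return prompts[rng.randrange(len(prompts))]


# Hypothetical stand-in for (prompt_text, prompt_speech_16k) pairs.
prompts = [
    ("prompt text A", "speech_a"),
    ("prompt text B", "speech_b"),
    ("prompt text C", "speech_c"),
]

# Assumption: two groups of n = 4 samples, each group sharing one GT text.
gt_texts = ["hello world"] * 4 + ["good morning"] * 4

# One prompt per sample, but constant within each group.
chosen = [pick_prompt_for_group(text, prompts) for text in gt_texts]
```

Keying on the ground-truth text only works if that text really identifies the group. If two groups can share the same text, a group index passed alongside the batch would be a safer key.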
