diff --git a/src/fairseq2/assets/cards/models/llama.yaml b/src/fairseq2/assets/cards/models/llama.yaml index 7570312a4..99dccb6b0 100644 --- a/src/fairseq2/assets/cards/models/llama.yaml +++ b/src/fairseq2/assets/cards/models/llama.yaml @@ -168,4 +168,34 @@ model_arch: llama3_1_8b checkpoint: "hg://deepseek-ai/DeepSeek-R1-Distill-Llama-8B" tokenizer: "hg://deepseek-ai/DeepSeek-R1-Distill-Llama-8B" tokenizer_family: llama -use_v2_tokenizer: true \ No newline at end of file +use_v2_tokenizer: true + +--- + +name: octothinker_8b_hybrid +model_family: llama +model_arch: llama3_1_8b +checkpoint: /datasets/pretrained-llms/OctoThinker-8B-Hybrid-Base/ +tokenizer: /datasets/pretrained-llms/OctoThinker-8B-Hybrid-Base/ +tokenizer_family: llama +use_v2_tokenizer: true + +--- + +name: octothinker_8b_long +model_family: llama +model_arch: llama3_1_8b +checkpoint: /datasets/pretrained-llms/OctoThinker-8B-Long-Base/ +tokenizer: /datasets/pretrained-llms/OctoThinker-8B-Long-Base/ +tokenizer_family: llama +use_v2_tokenizer: true + +--- + +name: octothinker_8b_short +model_family: llama +model_arch: llama3_1_8b +checkpoint: /datasets/pretrained-llms/OctoThinker-8B-Short-Base/ +tokenizer: /datasets/pretrained-llms/OctoThinker-8B-Short-Base/ +tokenizer_family: llama +use_v2_tokenizer: true diff --git a/src/fairseq2/recipes/lm/__init__.py b/src/fairseq2/recipes/lm/__init__.py index 907c8a312..2f4347536 100644 --- a/src/fairseq2/recipes/lm/__init__.py +++ b/src/fairseq2/recipes/lm/__init__.py @@ -48,6 +48,12 @@ from fairseq2.recipes.lm._online_finetune._generative_judge import ( J1PairwiseScoreExtractorHandler as J1PairwiseScoreExtractorHandler, ) +from fairseq2.recipes.lm._online_finetune._generative_judge import ( + J1KwiseScoreExtractor as J1KwiseScoreExtractor, +) +from fairseq2.recipes.lm._online_finetune._generative_judge import ( + J1KwiseScoreExtractorHandler as J1KwiseScoreExtractorHandler, +) from fairseq2.recipes.lm._online_finetune._generative_judge import ( J1PointwiseExtractor as J1PointwiseExtractor, ) @@ -84,6 +90,12 @@ from fairseq2.recipes.lm._online_finetune._remote_model import ( NoEnvGeneralVerifierPipeline as NoEnvGeneralVerifierPipeline, ) +from fairseq2.recipes.lm._online_finetune._remote_model import ( + NoEnvAceMathRMPipeline as NoEnvAceMathRMPipeline, +) +from fairseq2.recipes.lm._online_finetune._remote_model import ( + NoEnvSkyworkRMPipeline as NoEnvSkyworkRMPipeline, +) from fairseq2.recipes.lm._online_finetune._remote_model import ( RemoteModelHandler as RemoteModelHandler, ) @@ -93,6 +105,18 @@ from fairseq2.recipes.lm._online_finetune._rewards import ( AtheneVerifierHandler as AtheneVerifierHandler, ) +from fairseq2.recipes.lm._online_finetune._rewards import ( + SkyworkVerifier as SkyworkVerifier, +) +from fairseq2.recipes.lm._online_finetune._rewards import ( + SkyworkVerifierHandler as SkyworkVerifierHandler, +) +from fairseq2.recipes.lm._online_finetune._rewards import ( + AceMathVerifier as AceMathVerifier, +) +from fairseq2.recipes.lm._online_finetune._rewards import ( + AceMathVerifierHandler as AceMathVerifierHandler, +) from fairseq2.recipes.lm._online_finetune._rewards import ( GenerativePairwiseVerifier as GenerativePairwiseVerifier, ) @@ -105,6 +129,12 @@ from fairseq2.recipes.lm._online_finetune._rewards import ( GenerativePointwiseVerifierHandler as GenerativePointwiseVerifierHandler, ) +from fairseq2.recipes.lm._online_finetune._rewards import ( + GenerativeKwiseVerifier as GenerativeKwiseVerifier, +) +from fairseq2.recipes.lm._online_finetune._rewards import ( + 
GenerativeKwiseVerifierHandler as GenerativeKwiseVerifierHandler, +) from fairseq2.recipes.lm._online_finetune._rewards import GSM8kVerifier as GSM8kVerifier from fairseq2.recipes.lm._online_finetune._rewards import ( GSM8kVerifierHandler as GSM8kVerifierHandler, ) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_common.py b/src/fairseq2/recipes/lm/_online_finetune/_common.py index c5ed7aa70..399efe369 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_common.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_common.py @@ -8,6 +8,7 @@ import contextlib import io +import re from dataclasses import dataclass from typing import List, cast @@ -17,14 +18,8 @@ from torch import Tensor from vllm import RequestOutput -from fairseq2.data import ( - CollateOptionsOverride, - Collater, - SequenceData, -) -from fairseq2.datasets import ( - SequenceBatch, -) +from fairseq2.data import CollateOptionsOverride, Collater, SequenceData +from fairseq2.datasets import SequenceBatch from fairseq2.datasets.preference import PreferenceBatch from fairseq2.datasets.prompt import PromptBatch from fairseq2.gang import Gang, Gangs @@ -93,9 +88,13 @@ def collate_with_target_mask( seq_data = cast(SequenceData, collater(to_collate)) + seq_lens = seq_data["seqs"]["seq_lens"] + assert isinstance(seq_lens, Tensor) or isinstance(seq_lens, list) + if isinstance(seq_lens, Tensor): + seq_lens = seq_lens.tolist() batch = SequenceBatch( seq_data["seqs"]["seqs"], - seq_data["seqs"]["seq_lens"], + seq_lens, target_mask=seq_data["target_loss_mask"]["seqs"], ) batch.to(device) @@ -395,6 +394,8 @@ def log_rollouts(prompt_batch: PromptBatch, rollouts, split_name, num_rollouts=1 prompt = prompt_batch.meta_info.get("prompt_raw")[0] elif "raw_prompt" in prompt_batch.meta_info: prompt = prompt_batch.meta_info.get("raw_prompt")[0] + elif "problem" in prompt_batch.meta_info: + prompt = prompt_batch.meta_info.get("problem")[0] else: # raw text prompt doesn't exist for this dataset prompt = "DUMMY PROMPT" @@ -416,6 +417,48 @@ def get_rollout_lengths(rollouts: List[SequenceData]): return rollout_lengths +def strip_think_tokens(rollouts: List[SequenceData]): + count_stripped, count_not_stripped, total_count, think_present = 0, 0, 0, 0 + for sample in rollouts: + for rollout in sample.outputs: + rollout_text = rollout.text + if "<think>" in rollout_text: + think_present += 1 + if rollout.finish_reason == "length": + count_not_stripped += 1 + if rollout.finish_reason == "stop": + count_stripped += 1 + total_count += 1 + rollout.text = re.sub( + r"<think>.*?</think>", "", rollout_text, flags=re.DOTALL + ).strip() + + log.info(f"Total count: {total_count}") + log.info(f"Think present: {think_present}") + log.info(f"Fraction stripped: {count_stripped/total_count}") + log.info(f"Fraction not stripped: {count_not_stripped/total_count}") + + return rollouts + +def get_failed_to_parse_answers(reward_output: dict, batch_size: int): + if "answers" in reward_output: + log.info(f"Answers: {reward_output['answers']}") + failed_to_parse = sum(answer is None for rollouts in reward_output["answers"] for answer in rollouts) + return failed_to_parse/batch_size + else: + return 0.0 + +def strip_for_octothinker(rollouts: List[SequenceData]): + for sample in rollouts: + for rollout in sample.outputs: + rollout_text = rollout.text + if "\nUser:" in rollout_text: + rollout_text = rollout_text[:rollout_text.find("\nUser:")] + rollout.text = rollout_text + + return rollouts + + class StatefulRolloutBag: """A stateful container for managing and reusing model rollouts across multiple
micro-batches. @@ -504,11 +547,23 @@ def update_num_dummy_batches( @torch.inference_mode() def update_avg_reward(metric_bag: MetricBag, avg_reward): metric_bag.get(Mean, "avg_reward").update(avg_reward, weight=1) + +@torch.inference_mode() +def update_avg_second_reward(metric_bag: MetricBag, avg_reward): + metric_bag.get(Mean, "avg_second_reward").update(avg_reward, weight=1) + +@torch.inference_mode() +def update_reward_matches(metric_bag: MetricBag, reward_matches): + metric_bag.get(Mean, "reward_matches").update(reward_matches, weight=1) @torch.inference_mode() def update_std_reward(metric_bag: MetricBag, std_reward): metric_bag.get(Mean, "std_reward").update(std_reward, weight=1) + +@torch.inference_mode() +def update_failed_to_parse_answers(metric_bag: MetricBag, failed_to_parse_answers): + metric_bag.get(Mean, "failed_to_parse_answers").update(failed_to_parse_answers, weight=1) @torch.inference_mode() diff --git a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py index f261e7d4e..8c7dabdc6 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_generative_judge.py @@ -1,13 +1,13 @@ +# --------------------------------------- Pointwise prompts ---------------------------------------- # + POINTWISE_J1_PROMPT = """ You are given a user question and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. You will be shown multiple responses to the same prompt, but only one at a time. Evaluate each response independently. -Think carefully about how to assess the quality of the response, and enclose your reasoning within and tags. Your reasoning should include your evaluation criteria, a clear understanding of what an ideal response would look like for this particular question, and a concrete example of such an ideal or reference answer if possible. Then compare the assistant's response to your ideal or reference answer, explaining how it aligns with or deviates from your expectations. Be specific and avoid vague or overly general judgments. Remain as objective as possible. - -Finally, assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision. A higher score should indicate a higher-quality response. Enclose the score within and tags. +Think carefully about how to assess the quality of the response and assign the assistant's response a score 1 if the response is correct, and 0 if not. Enclose the score within and tags. Format your output like this: your_thinking_process - your_score + 0 or 1 Below are the user's question and the assistant's response: @@ -19,16 +19,60 @@ [The End of the Assistant's Answer] """ -PAIRWISE_J1_PROMPT = """ -You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. +# Uncomment this for non-verifiable prompt + +# POINTWISE_J1_PROMPT = """ +# You are given a user question and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. You will be shown multiple responses to the same prompt, but only one at a time. Evaluate each response independently. 
+ +# Think carefully about how to assess the quality of the response and assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision. A higher score should indicate a higher-quality response. Enclose the score within and tags. + +# Format your output like this: +# your_thinking_process +# your_score + +# Below are the user's question and the assistant's response: + +# [User Question] +# {instruction} + +# [The Start of the Assistant's Answer] +# {response} +# [The End of the Assistant's Answer] +# """ + + +POINTWISE_J1_PROMPT_WITH_REF_ANSWER = """ +You are given a user question, a reference answer and a response from an AI assistant. Your task is to act as an impartial judge and evaluate how well the response fulfills the user's instructions. You will be shown multiple responses to the same prompt, but only one at a time. Evaluate each response independently. -First, provide your reasoning within and tags. This should include your evaluation criteria for a high-quality response, a detailed comparison of the two responses, and when helpful, a reference answer as part of your evaluation. Be explicit in your thought process, referencing your criteria and explaining how each response aligns with or deviates from them. +Think carefully about how to assess the quality of the response and assign the assistant's response a score 1 if the response is correct, and 0 if not. Enclose the score within and tags. -Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +Format your output like this: + your_thinking_process + 0 or 1 + +Below are the user's question, reference answer and the assistant's response: + +[User Question] +{instruction} -Finally, provide your verdict within and tags, strictly following this format: -- [[A]] if Assistant A is better -- [[B]] if Assistant B is better +[Reference Answer] +{reference_answer} + +[The Start of the Assistant's Answer] +{response} +[The End of the Assistant's Answer] +""" + +# --------------------------------------- Pairwise prompts ---------------------------------------- # + +PAIRWISE_WITH_SCORES_J1_PROMPT = """ +You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. + +Think carefully about how to assess the quality of the responses and assign each response a score 1 if the response is correct, and 0 if not. Enclose the scores within the tags , and . + +Format your output like this: + your_thinking_process + 0 or 1 0 or 1 Below are the user's question and the two responses: @@ -44,20 +88,42 @@ [The End of Assistant B's Answer] """ -PAIRWISE_WITH_SCORES_J1_PROMPT = """ -You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. 
+# Uncomment this for non-verifiable prompt + +# PAIRWISE_WITH_SCORES_J1_PROMPT = """ +# You are given a user question and two responses from two AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. -First, provide your reasoning within and tags. This should include your evaluation criteria for a high-quality response, a detailed comparison of the two responses, and when helpful, a reference answer as part of your evaluation. Be explicit in your thought process, referencing your criteria and explaining how each response aligns with or deviates from them. +# Think carefully about how to assess the quality of the responses and assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . -Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. +# Format your output like this: +# your_thinking_process +# your_score_a your_score_b -Finally, assign the assistant's response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and . +# Below are the user's question and the two responses: + +# [User Question] +# {instruction} + +# [The Start of Assistant A's Answer] +# {response_A} +# [The End of Assistant A's Answer] + +# [The Start of Assistant B's Answer] +# {response_B} +# [The End of Assistant B's Answer] +# """ + +PAIRWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """ +You are given a user question, two responses from two AI assistants and the parsed version of the responses, and a reference answer. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. + +Think carefully about how to assess the quality of the responses and finally, utilize the reference answer for your judgement. Note that the parsed version of the responses are automatically extracted and may contain errors, therefore you should primarily rely on the original responses for your judgement. +Finally, assign each response a score 1 if the response is correct, and 0 if not. Enclose the scores within the tags , and . 
Format your output like this: your_thinking_process - your_score_a your_score_b + 0 or 1 0 or 1 -Below are the user's question and the two responses: +Below are the user's question, two responses and the parsed versions of the responses, and the reference answer: [User Question] {instruction} @@ -69,6 +135,64 @@ [The Start of Assistant B's Answer] {response_B} [The End of Assistant B's Answer] + +[The Parsed Version of Assistant A's Answer] +{parsed_response_A} + +[The Parsed Version of Assistant B's Answer] +{parsed_response_B} + +[Reference Answer] +{reference_answer} +""" + + +# --------------------------------------- K-wise prompts ---------------------------------------- # + +KWISE_WITH_SCORES_J1_PROMPT = """ +You are given a user question and {k} responses from {k} AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. + +Think carefully about how to assess the quality of the responses and finally, assign each response a score from 0 to 10, using either an integer or a decimal with up to 0.1 precision, with a higher score indicating a higher-quality response that better satisfies the criteria. Enclose the scores within the tags , and so on. + +Format your output like this: + your_thinking_process + your_score_1 + your_score_2 + your_score_3 +... + +Below are the user's question and the responses: + +[User Question] +{instruction} + +{responses} +""" + +KWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER = """ +You are given a user question, a reference answer, and {k} responses with the parsed versions from AI assistants. Your task is to act as an impartial judge and evaluate which response better follows the user's instructions and provides a higher-quality answer. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. + +Think carefully about how to assess the quality of the responses and finally, utilize the reference answer for your judgement. Note that the parsed version of the responses are automatically extracted and may contain errors, therefore you should primarily rely on the original responses for your judgement. +Finally, assign each response a score 1 if the response is correct, and 0 if not. Enclose the scores within the tags , and so on. + +Format your output like this: + your_thinking_process + 0 or 1 + 0 or 1 + 0 or 1 +... + +Below are the user's question, reference answer, responses and the parsed versions of the responses: + +[User Question] +{instruction} + +[Reference Answer] +{reference_answer} + +{responses} + +{parsed_responses} """ @@ -83,15 +207,18 @@ class JudgmentExtractorHandler(ABC): @abstractmethod - def create(self): ... + def create(self, tokenizer): + ... @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @property @abstractmethod - def config_kls(self) -> type[object]: ... + def config_kls(self) -> type[object]: + ... """ @@ -110,10 +237,12 @@ class JudgmentExtractor(ABC): """ @abstractmethod - def prompt(self) -> str: ... 
+ def prompt(self) -> str: + ... @abstractmethod - def format_prompt(self, prompt_text, **kwargs: Any) -> str: ... + def format_prompt(self, prompt_text, **kwargs: Any) -> str: + ... """ Format the prompt text and additional arguments into a string suitable for input to the reward model. @@ -126,7 +255,8 @@ def format_prompt(self, prompt_text, **kwargs: Any) -> str: ... """ @abstractmethod - def extract(self, generation) -> float | str: ... + def extract(self, generation) -> float | str: + ... """ Extract the final scalar reward score from the model's response. @@ -145,7 +275,8 @@ def extract(self, generation) -> float | str: ... """ @abstractmethod - def aggregate(self, judgments) -> float | str: ... + def aggregate(self, judgments) -> float | str: + ... """ Aggregate multiple responses (judgments) from the reward model into a single value. @@ -165,8 +296,8 @@ def __init__(self): pass @override - def create(self): - return GeneralVerifierExtractor() + def create(self, tokenizer): + return GeneralVerifierExtractor(tokenizer) @property @override @@ -180,7 +311,7 @@ def config_kls(self): class GeneralVerifierExtractor(JudgmentExtractor): - def __init__(self): + def __init__(self, tokenizer): try: from math_verify import parse from math_verify.parser import ( @@ -253,8 +384,8 @@ def __init__(self): pass @override - def create(self): - return J1PointwiseExtractor() + def create(self, tokenizer): + return J1PointwiseExtractor(tokenizer) @property @override @@ -268,20 +399,35 @@ def config_kls(self): class J1PointwiseExtractor(JudgmentExtractor): - def __init__(self): - pass + def __init__(self, tokenizer): + self.tokenizer = tokenizer @override - def prompt(self): - return POINTWISE_J1_PROMPT + def prompt(self, reference_answer): + return ( + POINTWISE_J1_PROMPT + if reference_answer is None + else POINTWISE_J1_PROMPT_WITH_REF_ANSWER + ) @override def format_prompt(self, prompt_text, rollout_text, reference_answer): - content = self.prompt().format(instruction=prompt_text, response=rollout_text) + prompt_template = self.prompt(reference_answer) + content = ( + prompt_template.format(instruction=prompt_text, response=rollout_text) + if reference_answer is None + else prompt_template.format( + instruction=prompt_text, + reference_answer=reference_answer, + response=rollout_text, + ) + ) + wrapped_text = [{"role": "user", "content": content}] chat_str = self.tokenizer.apply_chat_template( wrapped_text, tokenize=False, add_generation_prompt=True ) + # log.info(f"Judge input = {chat_str}") return chat_str @override @@ -307,8 +453,8 @@ def __init__(self): pass @override - def create(self): - return J1PairwiseScoreExtractor() + def create(self, tokenizer): + return J1PairwiseScoreExtractor(tokenizer) @property @override @@ -322,20 +468,66 @@ def config_kls(self): class J1PairwiseScoreExtractor(JudgmentExtractor): - def __init__(self): - pass + def __init__(self, tokenizer): + self.tokenizer = tokenizer + try: + from math_verify import parse + from math_verify.parser import ( + ExprExtractionConfig, + LatexExtractionConfig, + NormalizationConfig, + ) + except ImportError: + raise ImportError( + "install mathverify from https://github.com/huggingface/Math-Verify" + ) + + self.student_extraction_config = ( + LatexExtractionConfig(boxed_match_priority=0), + ) + self.parse = parse @override - def prompt(self): - return PAIRWISE_WITH_SCORES_J1_PROMPT + def prompt(self, reference_answer): + return ( + PAIRWISE_WITH_SCORES_J1_PROMPT + if reference_answer is None + else 
PAIRWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER + ) + + def get_preferred_index(self, lst): + """ + math_verify parse returns a list of parsed answers, we want want the item at idex 1, which is a string + """ + if len(lst) > 1: + return lst[1] + elif len(lst) == 1: + return lst[0] + else: + return "None" @override - def format_prompt(self, prompt_text, rollout_A_text, rollout_B_text): - content = self.prompt().format( - instruction=prompt_text, - response_A=rollout_A_text, - response_B=rollout_B_text, + def format_prompt( + self, prompt_text, rollout_A_text, rollout_B_text, reference_answer + ): + prompt_template = self.prompt(reference_answer) + content = ( + prompt_template.format( + instruction=prompt_text, + response_A=rollout_A_text, + response_B=rollout_B_text, + ) + if reference_answer is None + else prompt_template.format( + instruction=prompt_text, + response_A=rollout_A_text, + response_B=rollout_B_text, + parsed_response_A=self.get_preferred_index(self.parse(rollout_A_text, self.student_extraction_config)), + parsed_response_B=self.get_preferred_index(self.parse(rollout_B_text, self.student_extraction_config)), + reference_answer=reference_answer, + ) ) + wrapped_text = [{"role": "user", "content": content}] chat_str = self.tokenizer.apply_chat_template( wrapped_text, tokenize=False, add_generation_prompt=True @@ -372,18 +564,18 @@ def aggregate(self, judgments): ) -class J1PairwisePreferenceExtractorHandler(JudgmentExtractorHandler): +class J1KwiseScoreExtractorHandler(JudgmentExtractorHandler): def __init__(self): pass @override - def create(self): - return J1PairwisePreferenceExtractor() + def create(self, tokenizer, k): + return J1KwiseScoreExtractor(tokenizer, k) @property @override def name(self): - return "j1_pairwise_preference_extractor" + return "j1_kwise_score_extractor" @property @override @@ -391,22 +583,92 @@ def config_kls(self): return None -class J1PairwisePreferenceExtractor(JudgmentExtractor): - def __init__(self): - pass +class J1KwiseScoreExtractor(JudgmentExtractor): + def __init__(self, tokenizer, k): + self.tokenizer = tokenizer + self.k = k + try: + from math_verify import parse + from math_verify.parser import ( + ExprExtractionConfig, + LatexExtractionConfig, + NormalizationConfig, + ) + except ImportError: + raise ImportError( + "install mathverify from https://github.com/huggingface/Math-Verify" + ) + + self.student_extraction_config = ( + LatexExtractionConfig(boxed_match_priority=0), + ) + self.parse = parse + + def get_preferred_index(self, lst): + """ + math_verify parse returns a list of parsed answers, we want want the item at idex 1, which is a string + """ + if len(lst) > 1: + return lst[1] + elif len(lst) == 1: + return lst[0] + else: + return "None" @override - def prompt(self): - return PAIRWISE_J1_PROMPT + def prompt(self, reference_answer): + return ( + KWISE_WITH_SCORES_J1_PROMPT + if reference_answer is None + else KWISE_WITH_SCORES_J1_PROMPT_WITH_REF_ANSWER + ) @override - def extract(self, generation): - matches = list( - re.findall(r"\s*\[\[(A|B)\]\]\s*", generation.strip()) + def format_prompt(self, prompt_text, rollouts, reference_answer): + prompt_template = self.prompt(reference_answer) + content = ( + prompt_template.format( + k=self.k, + instruction=prompt_text, + responses="".join([f"[Start of Assistant {assistant_id+1}'s Answer]\n{rollout}\n[End of Assistant {assistant_id+1}'s Answer]\n\n" for assistant_id, rollout in enumerate(rollouts)]) + ) + if reference_answer is None + else prompt_template.format( + k=self.k, + 
instruction=prompt_text, + responses="".join([f"[Start of Assistant {assistant_id+1}'s Answer]\n{rollout}\n[End of Assistant {assistant_id+1}'s Answer]\n\n" for assistant_id, rollout in enumerate(rollouts)]), + parsed_responses="".join([f"[The Parsed Version of Assistant {assistant_id+1}'s Answer]\n{self.get_preferred_index(self.parse(rollout, self.student_extraction_config))}\n\n" for assistant_id, rollout in enumerate(rollouts)]), + reference_answer=reference_answer, + ) ) - return matches[-1].strip() if matches else None + wrapped_text = [{"role": "user", "content": content}] + chat_str = self.tokenizer.apply_chat_template( + wrapped_text, tokenize=False, add_generation_prompt=True + ) + return chat_str + + @override + def extract(self, generation): + scores = [] + for i in range(self.k): + score_matches = re.findall( + rf"<score_{i+1}>\s*([0-9]+(?:\.[0-9])?)\s*(?:/10)?\s*</score_{i+1}>", + generation, + ) + if score_matches: + scores.append(float(score_matches[-1].strip())) + else: + scores.append(0.0) + + return scores @override def aggregate(self, judgments): - pass + avg_score = [0.0] * self.k + for scores in judgments: + for i, score in enumerate(scores): + avg_score[i] += score + + avg_score = [round(avg_score[i] / len(judgments), 4) for i in range(self.k)] + return avg_score diff --git a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py index 58cd226a7..d5493843c 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_grpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_grpo.py @@ -17,9 +17,7 @@ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from fairseq2.context import RuntimeContext -from fairseq2.datasets import ( - SequenceBatch, -) +from fairseq2.datasets import SequenceBatch from fairseq2.datasets.preference import PreferenceBatch from fairseq2.datasets.prompt import PromptBatch from fairseq2.gang import Gang, Gangs @@ -36,10 +34,15 @@ collate_with_target_mask, compute_reference_logps, compute_token_level_entropy, + strip_for_octothinker, generate_rollouts, get_rollout_lengths, log_rollouts, + get_failed_to_parse_answers, + strip_think_tokens, update_avg_reward, + update_avg_second_reward, + update_reward_matches, update_avg_reward_len_norm, update_avg_rollout_length, update_batch_metrics, @@ -47,6 +50,7 @@ update_grpo_loss, update_logit_entropy, update_std_reward, + update_failed_to_parse_answers ) from fairseq2.recipes.lm._online_finetune._handler import OnlineFinetuneUnitHandler from fairseq2.recipes.lm._online_finetune._remote_model import ( @@ -139,6 +143,7 @@ class GrpoFinetuneUnit(TrainUnit[SequenceBatch]): _config: GrpoFinetuneConfig _model_update_group: PyNcclCommunicator _reward: VLLMOutputReward + _second_reward: VLLMOutputReward _display_name: str _rollout_bag: StatefulRolloutBag @@ -149,6 +154,7 @@ def __init__( vllm_model: RemoteVllmModel, vllm_actors: List[Union[RemoteVllmModel, RemoteHFModel]], reward, + second_reward, gangs: Gangs, config: GrpoFinetuneConfig, ) -> None: @@ -160,6 +166,7 @@ def __init__( self._vllm_model = vllm_model self._gangs = gangs self._reward = reward + self._second_reward = second_reward self._rollout_bag = StatefulRolloutBag( max_bag_steps=int( config.loss_config.group_size / config.loss_config.forward_group_size @@ -189,15 +196,19 @@ def validate_reward( ) -> tuple[Tensor, int]: if self._gangs.dp.rank == 0: policy_sampling_params = copy(self._vllm_model.sampling_params) - # For a pairwise RM, need to sample at least two judgments - policy_sampling_params.n = ( - 2 if
self._reward.reward_name == "generative_pairwise_verifier" else 1 - ) for ( k, v, ) in self._config.loss_config.validation_vllm_sampling_params.items(): policy_sampling_params.__setattr__(k, v) + + # For a pairwise RM, need to sample at least two rollouts + if self._reward.reward_name == "generative_pairwise_verifier": + policy_sampling_params.n = 2 + elif self._reward.reward_name == "generative_kwise_verifier": + policy_sampling_params.n = self._config.reward.config.k + else: + policy_sampling_params.n = 1 else: policy_sampling_params = None rollouts = generate_rollouts( @@ -206,12 +217,30 @@ def validate_reward( vllm_model=self._vllm_model, sampling_params=policy_sampling_params, ) + if self._config.reward.config.strip_thinking: + rollouts = strip_think_tokens(rollouts) + else: + rollouts = strip_for_octothinker(rollouts) + + log.info("After stripping") if self._config.loss_config.log_rollouts: log_rollouts(prompt_batch, rollouts, "Valid") + log.info(f"Sampling params: {len(rollouts[0].outputs)}") + log.info(f"Rollouts: {len(rollouts[0].outputs)}") reward_output = self._reward.process_rollouts(rollouts, prompt_batch) log.info(f"Rewards: {reward_output['rewards']}") avg_reward = torch.tensor(reward_output["rewards"]).float().mean() std_reward = torch.tensor(reward_output["rewards"]).float().std() + failed_to_parse_answers = get_failed_to_parse_answers(reward_output, prompt_batch.batch_size) + + second_reward_output = self._second_reward.process_rollouts(rollouts, prompt_batch) + log.info(f"Second Rewards: {second_reward_output['rewards']}") + avg_second_reward = torch.tensor(second_reward_output["rewards"]).float().mean() + update_avg_second_reward(metric_bag, avg_second_reward) + + reward_matches = (torch.tensor(reward_output["rewards"]) == torch.tensor(second_reward_output["rewards"])).all(dim=1).float().mean() + log.info(f"Reward matches: {reward_matches}") + update_reward_matches(metric_bag, reward_matches) rollout_lengths = get_rollout_lengths(rollouts) avg_rollout_length = torch.tensor(rollout_lengths).float().mean() @@ -223,6 +252,7 @@ def validate_reward( update_avg_reward(metric_bag, avg_reward) update_batch_metrics(metric_bag, prompt_batch, train=False) update_std_reward(metric_bag, std_reward) + update_failed_to_parse_answers(metric_bag, failed_to_parse_answers) # returning dummy loss since trainer expects it return torch.tensor(0.0, device=self._gangs.dp.device), prompt_batch.batch_size @@ -262,9 +292,16 @@ def __call__( dp_gang=self._gangs.dp, vllm_model=self._vllm_model, ) + # if self._config.loss_config.log_rollouts: + # log_rollouts(prompt_batch, rollouts, "Train") + + if self._config.reward.config.strip_thinking: + rollouts = strip_think_tokens(rollouts) + else: + rollouts = strip_for_octothinker(rollouts) + log.info("After stripping") if self._config.loss_config.log_rollouts: log_rollouts(prompt_batch, rollouts, "Train") - reward_output = self._reward.process_rollouts(rollouts, prompt_batch) self._rollout_bag.save(rollouts, reward_output) @@ -349,6 +386,9 @@ def __call__( update_std_reward(metric_bag, std_reward) update_avg_reward(metric_bag, avg_reward) + + failed_to_parse_answers = get_failed_to_parse_answers(reward_output, prompt_batch.batch_size) + update_failed_to_parse_answers(metric_bag, failed_to_parse_answers) loss = grpo_loss @@ -466,6 +506,9 @@ class GrpoFinetuneConfig: vllm_reward_model_actor_name: str | None = None """Optional name of the Ray vLLM actor used as a reward model.""" + + vllm_second_reward_model_actor_name: str | None = None + """Optional name 
of the Ray vLLM actor used as a second reward model.""" vllm_reference_model_actor_name: str | None = None """Optional name of the Ray vLLM actor used as a reference model.""" @@ -474,6 +517,10 @@ class GrpoFinetuneConfig: default_factory=lambda: RewardSection(name="gsm8k_verifier") ) """Configuration for the reward function that evaluates generated rollouts.""" + + second_reward: RewardSection = field( + default_factory=lambda: RewardSection(name="gsm8k_verifier") + ) vllm_sync: VllmSyncSection = field(default_factory=lambda: VllmSyncSection()) @@ -513,6 +560,8 @@ def create( vllm_model.sampling_params.n = config.loss_config.group_size vllm_reward_model = vllm_actors.get(config.vllm_reward_model_actor_name, None) + vllm_second_reward_model = vllm_actors.get(config.vllm_second_reward_model_actor_name, None) + reward_registry = self._context.get_registry(VLLMOutputRewardHandler) reward_name = config.reward.name reward_handler = reward_registry.get(reward_name) @@ -523,6 +572,16 @@ def create( gangs=gangs, context=self._context, ) + + second_reward_name = config.second_reward.name + second_reward_handler = reward_registry.get(second_reward_name) + second_reward = second_reward_handler.create( + reward_model=vllm_second_reward_model, + reward_name=second_reward_name, + reward_config=config.second_reward.config, + gangs=gangs, + context=self._context, + ) # sync models here before we start training if config.vllm_sync.sync_model_every_n_steps > 0: @@ -533,7 +592,7 @@ def create( log.info("GRPO setup complete.") return GrpoFinetuneUnit( - model, reference_model, vllm_model, vllm_actors, reward, gangs, config + model, reference_model, vllm_model, vllm_actors, reward, second_reward, gangs, config ) @property diff --git a/src/fairseq2/recipes/lm/_online_finetune/_handler.py b/src/fairseq2/recipes/lm/_online_finetune/_handler.py index 943528f51..0badf2b10 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_handler.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_handler.py @@ -19,15 +19,18 @@ class OnlineFinetuneUnitHandler(ABC): @abstractmethod def create( self, model: Model, gangs: Gangs, recipe_config: object, vllm_actors: object - ) -> TrainUnit[SequenceBatch]: ... + ) -> TrainUnit[SequenceBatch]: + ... @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @property @abstractmethod - def config_kls(self) -> type[object]: ... + def config_kls(self) -> type[object]: + ...
class UnknownOnlineFinetuneUnitError(Exception): diff --git a/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py b/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py index 8f01c2ef7..3547be2df 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_online_dpo.py @@ -23,12 +23,7 @@ from fairseq2.context import RuntimeContext from fairseq2.data import CollateOptionsOverride, Collater, SequenceData -from fairseq2.datasets import ( - LengthBatching, - SequenceBatch, - StaticBatching, - SyncMode, -) +from fairseq2.datasets import LengthBatching, SequenceBatch, StaticBatching, SyncMode from fairseq2.datasets.preference import PreferenceBatch from fairseq2.datasets.prompt import PromptBatch from fairseq2.gang import Gang, Gangs @@ -44,10 +39,7 @@ from fairseq2.recipes import Model, TrainUnit from fairseq2.recipes.common import setup_reference_model from fairseq2.recipes.common._distributed import broadcast_model -from fairseq2.recipes.config import ( - ReferenceModelSection, - TrainerSection, -) +from fairseq2.recipes.config import ReferenceModelSection, TrainerSection from fairseq2.recipes.lm._instruction_finetune import update_nll_loss from fairseq2.recipes.lm._online_finetune._common import ( VllmSyncSection, @@ -62,6 +54,11 @@ update_avg_rollout_length, update_batch_metrics, update_dpo_loss, + update_grpo_batch_metrics, + compute_reference_logps, + collate_with_target_mask, + update_avg_loss_zeroer, + strip_think_tokens, update_logit_entropy, ) from fairseq2.recipes.lm._online_finetune._handler import OnlineFinetuneUnitHandler @@ -142,12 +139,16 @@ def validate_reward( ) -> tuple[Tensor, int]: if self._gangs.dp.rank == 0: policy_sampling_params = copy(self._vllm_model.sampling_params) - policy_sampling_params.n = 1 for ( k, v, ) in self._config.loss_config.validation_vllm_sampling_params.items(): policy_sampling_params.__setattr__(k, v) + + # For a pairwise RM, need to sample at least two rollouts + policy_sampling_params.n = ( + 2 if self._reward.reward_name == "generative_pairwise_verifier" else 1 + ) else: policy_sampling_params = None rollouts = generate_rollouts( @@ -158,6 +159,8 @@ def validate_reward( ) if self._config.loss_config.log_rollouts: log_rollouts(prompt_batch, rollouts, "Valid") + + rollouts = strip_think_tokens(rollouts) reward_output = self._reward.process_rollouts(rollouts, prompt_batch) avg_reward = torch.tensor(reward_output["rewards"]).float().mean() @@ -204,6 +207,8 @@ def __call__( if self._config.loss_config.log_rollouts: log_rollouts(prompt_batch, rollouts, "Train") + rollouts = strip_think_tokens(rollouts) + batch: PreferenceBatch batch, is_bad_batch, reward_output = self._reward.prepare_preference_batch( prompt_batch, rollouts @@ -459,6 +464,24 @@ def create( context=self._context, ) + + # TODO: decide converter as part of the model handler + if "llama" in model.name: + from fairseq2.models.llama._hg import _convert_parameter + + model._convert_parameter = _convert_parameter + else: + from fairseq2.models.qwen._hg import _convert_parameter + + model._convert_parameter = _convert_parameter + + # sync models here before we start training + if config.vllm_sync.sync_model_every_n_steps > 0: + maybe_sync_model(gangs, model, vllm_model, -1, -1, force_sync=True) + if config.vllm_sync.sync_ref_model_every_n_steps > 0: + maybe_sync_model(gangs, model, reference_model, -1, -1, force_sync=True) + + return OnlineDpoFinetuneUnit( model, reference_model, vllm_model, vllm_actors, reward, gangs, config ) diff 
--git a/src/fairseq2/recipes/lm/_online_finetune/_recipe.py b/src/fairseq2/recipes/lm/_online_finetune/_recipe.py index 3bfbf60ea..c580545d0 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_recipe.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_recipe.py @@ -67,9 +67,7 @@ OnlineCriterionSection, get_parameter_converter, ) -from fairseq2.recipes.lm._online_finetune._grpo import ( - GrpoFinetuneConfig, -) +from fairseq2.recipes.lm._online_finetune._grpo import GrpoFinetuneConfig from fairseq2.recipes.lm._online_finetune._handler import ( OnlineFinetuneUnitHandler, UnknownOnlineFinetuneUnitError, ) diff --git a/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py b/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py index 607a3cbac..560c885f7 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_remote_model.py @@ -27,10 +27,12 @@ from fairseq2.gang import Gangs from fairseq2.logging import log from fairseq2.nn._batch_layout import BatchLayout +from fairseq2.recipes.lm._online_finetune.third_party.ace_math import AceMathRMPipeline from fairseq2.recipes.lm._online_finetune.third_party.athene import AtheneRewardPipeline from fairseq2.recipes.lm._online_finetune.third_party.general_verifier import ( GeneralVerifierPipeline, ) +from fairseq2.recipes.lm._online_finetune.third_party.skywork import SkyworkRMPipeline from fairseq2.utils.structured import StructureError, structure @@ -48,6 +50,7 @@ class VllmEngineArgs: tokenizer: str = "/datasets/pretrained-llms/Llama-3.1-8B-Instruct" task: str = "generate" tensor_parallel_size: int = 4 + max_model_len: int | None = None trust_remote_code: bool = False model_impl: str = "auto" enforce_eager: bool = True @@ -90,6 +93,7 @@ def __init__(self, *args, **kwargs): # at the top-level del os.environ["CUDA_VISIBLE_DEVICES"] # os.environ["VLLM_USE_V1"] = "1" + os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" super().__init__(*args, **kwargs) self.ready = True # Set a flag or return a signal @@ -138,6 +142,48 @@ def name(self): return "general_verifier_pipeline" +@ray.remote +class NoEnvAceMathRMPipeline(AceMathRMPipeline): + """ + This is for running Ace Math RM pipeline with HF backend. + """ + + def __init__(self, *args, **kwargs): + # stop ray from manipulating CUDA_VISIBLE_DEVICES + # at the top-level + del os.environ["CUDA_VISIBLE_DEVICES"] + super().__init__(*args, **kwargs) + self.ready = True # Set a flag or return a signal + + def is_ready(self): + return self.ready + + @property + def name(self): + return "ace_math_rm_pipeline" + + +@ray.remote +class NoEnvSkyworkRMPipeline(SkyworkRMPipeline): + """ + This is for running Skywork RM pipeline with HF backend. + """ + + def __init__(self, *args, **kwargs): + # stop ray from manipulating CUDA_VISIBLE_DEVICES + # at the top-level + del os.environ["CUDA_VISIBLE_DEVICES"] + super().__init__(*args, **kwargs) + self.ready = True # Set a flag or return a signal + + def is_ready(self): + return self.ready + + @property + def name(self): + return "skywork_rm_pipeline" + + class WorkerExtension: """ The class for vLLM's worker to inherit from.
@@ -309,6 +355,7 @@ def setup_vllm_worker(self, ray_actor_name, vllm_engine_args, gangs: Gangs): ).remote( model=vllm_engine_args.model, tokenizer=vllm_engine_args.tokenizer, + max_model_len=vllm_engine_args.max_model_len, enforce_eager=vllm_engine_args.enforce_eager, worker_extension_cls="fairseq2.recipes.lm._online_finetune._remote_model.WorkerExtension", tensor_parallel_size=vllm_engine_args.tensor_parallel_size, @@ -438,6 +485,8 @@ def reward_from_model(self, prompt_list, batch_size=64): ray_outputs_flat = [o for sublist in ray_outputs for o in sublist] rewards = [o.outputs.data.item() for o in ray_outputs_flat] + log.info(f"Rewards = {rewards}") + return rewards @@ -537,7 +586,7 @@ def rollout_from_model(self, prompt_list, sampling_params=None, string_input=Fal "RemoteHFModel.rollout_from_model is not implemented. " ) - def reward_from_model(self, prompt_list, batch_size=64): + def reward_from_model(self, prompt_list, batch_size=2): # NOTE: need to batch inputs to hf.encode model for current models that aren't supported by hf rewards = [] outputs = [] @@ -564,15 +613,18 @@ class RemoteModelHandler(ABC): @abstractmethod def create( self, gangs: Gangs, unit_config: object - ) -> Union[RemoteVllmModel, RemoteHFModel]: ... + ) -> Union[RemoteVllmModel, RemoteHFModel]: + ... @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @property @abstractmethod - def config_kls(self) -> type[object]: ... + def config_kls(self) -> type[object]: + ... class RemoteRayModelHandler(RemoteModelHandler): diff --git a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py index a4c0bae8c..b168acf72 100644 --- a/src/fairseq2/recipes/lm/_online_finetune/_rewards.py +++ b/src/fairseq2/recipes/lm/_online_finetune/_rewards.py @@ -6,6 +6,9 @@ from __future__ import annotations +import itertools +import math +import random import re from abc import ABC, abstractmethod from dataclasses import dataclass, field @@ -20,6 +23,7 @@ from fairseq2.datasets.preference import PreferenceBatch from fairseq2.datasets.prompt import PromptBatch from fairseq2.gang import Gangs +from fairseq2.logging import log from fairseq2.recipes.lm._online_finetune._common import ( _mute_output, collate_with_target_mask, @@ -38,6 +42,9 @@ class RewardModelConfig: prompt_key: str = "prompt" tokenizer: str | None = None judgment_extractor: str | None = None + pair_type: str | None = None + k: int | None = None + strip_thinking: bool | None = None @dataclass(kw_only=True) @@ -50,23 +57,28 @@ class VLLMOutputRewardHandler(ABC): @abstractmethod def create( self, reward_model: Any, gangs: Gangs, reward_config: object - ) -> VLLMOutputReward: ... + ) -> VLLMOutputReward: + ... @property @abstractmethod - def name(self) -> str: ... + def name(self) -> str: + ... @property @abstractmethod - def config_kls(self) -> type[object]: ... + def config_kls(self) -> type[object]: + ... class VLLMOutputReward(ABC): @abstractmethod - def process_rollouts(self, vllm_outputs: list[RequestOutput]): ... + def process_rollouts(self, vllm_outputs: list[RequestOutput]): + ... @abstractmethod - def prepare_preference_batch(self, prompt_batch: PromptBatch, rollouts): ... + def prepare_preference_batch(self, prompt_batch: PromptBatch, rollouts): + ... 
class GSM8kVerifierHandler(VLLMOutputRewardHandler): @@ -242,6 +254,7 @@ def process_rollouts( batch_text = [] batch_tokens = [] batch_rewards = [] + batch_answers = [] reference_answers = prompt_batch.meta_info.get(self.answer_key) @@ -250,6 +263,7 @@ def process_rollouts( rollouts_tokens = [] i_reference_answer = reference_answers[i] rollouts_rewards = [] + rollouts_answers = [] for rollout_output in i_batch_request_output.outputs: rollouts_text.append(rollout_output.text) rollouts_tokens.append(rollout_output.token_ids) @@ -257,9 +271,126 @@ def process_rollouts( rollout_output.text, i_reference_answer ) rollouts_rewards.append(predicted_reward) + rollouts_answers.append(predicted_answer) batch_text.append(rollouts_text) batch_tokens.append(rollouts_tokens) batch_rewards.append(rollouts_rewards) + batch_answers.append(rollouts_answers) + + return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards, "answers": batch_answers} + + def prepare_preference_batch( + self, prompt_batch: PromptBatch, rollouts + ) -> PreferenceBatch: + + reward_output = self.process_rollouts(rollouts, prompt_batch) + + batch, is_bad_batch = prepare_preference_batch_random_pair( + prompt_batch=prompt_batch, reward_output=reward_output, gangs=self._gangs + ) + + return batch, is_bad_batch, reward_output + + +class SkyworkVerifierHandler(VLLMOutputRewardHandler): + def __init__(self): + pass + + @override + def create(self, reward_model, reward_name, reward_config, gangs, context): + if reward_config.tokenizer is not None: + tokenizer = reward_config.tokenizer + else: + tokenizer = "Skywork/Skywork-Reward-V2-Llama-3.1-8B" + + return SkyworkVerifier( + gangs, + context, + reward_model, + reward_name=reward_name, + answer_key=reward_config.answer_key, + prompt_key=reward_config.prompt_key, + tokenizer=tokenizer, + ) + + @property + @override + def name(self): + return "skywork_verifier" + + @property + @override + def config_kls(self): + return None + + +class SkyworkVerifier(VLLMOutputReward): + def __init__( + self, + gangs, + context, + reward_model, + reward_name, + answer_key, + prompt_key, + tokenizer, + ): + self.answer_key = answer_key + self.prompt_key = prompt_key + self._gangs = gangs + self._context = context + self.reward_model = reward_model + self.reward_name = reward_name + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + + def wrap_text(self, prompt_text, rollout_text): + wrapped_text = [ + {"role": "user", "content": prompt_text}, + {"role": "assistant", "content": rollout_text}, + ] + chat_str = self.tokenizer.apply_chat_template(wrapped_text, tokenize=False) + if self.tokenizer.bos_token is not None and chat_str.startswith( + self.tokenizer.bos_token + ): + chat_str = chat_str[len(self.tokenizer.bos_token) :] + + return chat_str + + @override + def process_rollouts( + self, vllm_outputs: list[RequestOutput], prompt_batch: PromptBatch + ): + vllm_inputs = [] + batch_text = [] + batch_tokens = [] + + if vllm_outputs is None: + vllm_outputs = [None] * len(prompt_batch.prompts) + + text_prompts = prompt_batch.meta_info.get(self.prompt_key) + for i, (i_batch_request_output, prompt_text) in enumerate( + zip(vllm_outputs, text_prompts) + ): + + rollouts_text = [] + rollouts_tokens = [] + for rollout_output in i_batch_request_output.outputs: + rollout_text = rollout_output.text + vllm_input = self.wrap_text(prompt_text, rollout_text) + vllm_inputs.append(vllm_input) + rollouts_text.append(rollout_output.text) + rollouts_tokens.append(rollout_output.token_ids) + + 
batch_text.append(rollouts_text) + batch_tokens.append(rollouts_tokens) + + batch_rewards = generate_rewards( + vllm_inputs, dp_gang=self._gangs.dp, vllm_model=self.reward_model + ) + + # reshape batch_rewards to [Batch, Rollouts] + B, R = len(batch_text), len(batch_text[0]) # batch size, rollouts + batch_rewards = [batch_rewards[i * R : (i + 1) * R] for i in range(B)] return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards} @@ -269,8 +400,261 @@ def prepare_preference_batch( reward_output = self.process_rollouts(rollouts, prompt_batch) - batch, is_bad_batch = prepare_preference_batch_random_pair( - prompt_batch=prompt_batch, reward_output=reward_output, gangs=self._gangs + chosen_batch = [] + rejected_batch = [] + prompt_lens = [] + dummy_batch_ids = [] # keep posiitons of dummy pairs here + + # choosing first rollouts with reward 1 as chosen and 0 as rejected (sort of random given that we sample rollouts randomly) + for i_batch, (i_batch_rewards, i_batch_tokens) in enumerate( + zip(reward_output["rewards"], reward_output["tokens"]) + ): + chosen_rollout_position = i_batch_rewards.index(max(i_batch_rewards)) + rejected_rollout_position = i_batch_rewards.index(min(i_batch_rewards)) + + if chosen_rollout_position == rejected_rollout_position: + # cant form preference pair when we dont have such rollouts + # this will be dummy batch and we zero out loss + dummy_batch_ids.append(i_batch) + + chosen_rollout_tokens = list(i_batch_tokens[chosen_rollout_position]) + rejected_rollout_tokens = list(i_batch_tokens[rejected_rollout_position]) + prompt_tokens = prompt_batch.prompts[i_batch] + + chosen_tokens = prompt_tokens + chosen_rollout_tokens + chosen_batch.append(chosen_tokens) + + rejected_tokens = prompt_tokens + rejected_rollout_tokens + rejected_batch.append(rejected_tokens) + + prompt_lens.append(len(prompt_tokens)) + + filter_batch = lambda batch: [ + item for index, item in enumerate(batch) if index not in dummy_batch_ids + ] + + if len(dummy_batch_ids) == len(reward_output["tokens"]): + # entire batch does not have a valid preference pair + # we use it as dummy batch and zero the loss in the end + is_bad_batch = True + else: + # removing dummy pairs from the batch + chosen_batch = filter_batch(chosen_batch) + rejected_batch = filter_batch(rejected_batch) + prompt_lens = filter_batch(prompt_lens) + is_bad_batch = False + + prompt_lens = torch.tensor(prompt_lens) + + chosen_batch = [ + torch.tensor(sequence, device=self._gangs.dp.device) + for sequence in chosen_batch + ] + chosen_batch = collate_with_target_mask( + chosen_batch, prompt_lens, device=self._gangs.dp.device + ) + + rejected_batch = [ + torch.tensor(sequence, device=self._gangs.dp.device) + for sequence in rejected_batch + ] + rejected_batch = collate_with_target_mask( + rejected_batch, prompt_lens, device=self._gangs.dp.device + ) + + batch = PreferenceBatch( + chosen=chosen_batch, + rejected=rejected_batch, + reference_score_chosen=None, + reference_score_rejected=None, + ) + + return batch, is_bad_batch, reward_output + + +class AceMathVerifierHandler(VLLMOutputRewardHandler): + def __init__(self): + pass + + @override + def create(self, reward_model, reward_name, reward_config, gangs, context): + if reward_config.tokenizer is not None: + tokenizer = reward_config.tokenizer + else: + tokenizer = "nvidia/AceMath-7B-RM" + + return AceMathVerifier( + gangs, + context, + reward_model, + reward_name=reward_name, + answer_key=reward_config.answer_key, + prompt_key=reward_config.prompt_key, + 
tokenizer=tokenizer, + ) + + @property + @override + def name(self): + return "acemath_verifier" + + @property + @override + def config_kls(self): + return None + + +class AceMathVerifier(VLLMOutputReward): + def __init__( + self, + gangs, + context, + reward_model, + reward_name, + answer_key, + prompt_key, + tokenizer, + ): + self.answer_key = answer_key + self.prompt_key = prompt_key + self._gangs = gangs + self._context = context + self.reward_model = reward_model + self.reward_name = reward_name + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + + def wrap_text(self, prompt_text, rollout_text): + wrapped_text = [ + { + "role": "system", + "content": "Please reason step by step, and check your final answer within \\boxed{}.", + }, + {"role": "user", "content": prompt_text}, + {"role": "assistant", "content": rollout_text}, + ] + chat_str = self.tokenizer.apply_chat_template( + wrapped_text, tokenize=False, add_generation_prompt=False + ) + if self.tokenizer.bos_token is not None and chat_str.startswith( + self.tokenizer.bos_token + ): + chat_str = chat_str[len(self.tokenizer.bos_token) :] + + return chat_str + + @override + def process_rollouts( + self, vllm_outputs: list[RequestOutput], prompt_batch: PromptBatch + ): + vllm_inputs = [] + batch_text = [] + batch_tokens = [] + + if vllm_outputs is None: + vllm_outputs = [None] * len(prompt_batch.prompts) + + text_prompts = prompt_batch.meta_info.get(self.prompt_key) + for i, (i_batch_request_output, prompt_text) in enumerate( + zip(vllm_outputs, text_prompts) + ): + + rollouts_text = [] + rollouts_tokens = [] + for rollout_output in i_batch_request_output.outputs: + rollout_text = rollout_output.text + vllm_input = self.wrap_text(prompt_text, rollout_text) + vllm_inputs.append(vllm_input) + rollouts_text.append(rollout_output.text) + rollouts_tokens.append(rollout_output.token_ids) + + batch_text.append(rollouts_text) + batch_tokens.append(rollouts_tokens) + + batch_rewards = generate_rewards( + vllm_inputs, dp_gang=self._gangs.dp, vllm_model=self.reward_model + ) + + log.info(f"Batch rewards = {batch_rewards}") + + # reshape batch_rewards to [Batch, Rollouts] + B, R = len(batch_text), len(batch_text[0]) # batch size, rollouts + batch_rewards = [batch_rewards[i * R : (i + 1) * R] for i in range(B)] + + return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards} + + def prepare_preference_batch( + self, prompt_batch: PromptBatch, rollouts + ) -> PreferenceBatch: + + reward_output = self.process_rollouts(rollouts, prompt_batch) + + chosen_batch = [] + rejected_batch = [] + prompt_lens = [] + dummy_batch_ids = [] # keep posiitons of dummy pairs here + + # choosing first rollouts with reward 1 as chosen and 0 as rejected (sort of random given that we sample rollouts randomly) + for i_batch, (i_batch_rewards, i_batch_tokens) in enumerate( + zip(reward_output["rewards"], reward_output["tokens"]) + ): + chosen_rollout_position = i_batch_rewards.index(max(i_batch_rewards)) + rejected_rollout_position = i_batch_rewards.index(min(i_batch_rewards)) + + if chosen_rollout_position == rejected_rollout_position: + # cant form preference pair when we dont have such rollouts + # this will be dummy batch and we zero out loss + dummy_batch_ids.append(i_batch) + + chosen_rollout_tokens = list(i_batch_tokens[chosen_rollout_position]) + rejected_rollout_tokens = list(i_batch_tokens[rejected_rollout_position]) + prompt_tokens = prompt_batch.prompts[i_batch] + + chosen_tokens = prompt_tokens + chosen_rollout_tokens + 
chosen_batch.append(chosen_tokens) + + rejected_tokens = prompt_tokens + rejected_rollout_tokens + rejected_batch.append(rejected_tokens) + + prompt_lens.append(len(prompt_tokens)) + + filter_batch = lambda batch: [ + item for index, item in enumerate(batch) if index not in dummy_batch_ids + ] + + if len(dummy_batch_ids) == len(reward_output["tokens"]): + # entire batch does not have a valid preference pair + # we use it as dummy batch and zero the loss in the end + is_bad_batch = True + else: + # removing dummy pairs from the batch + chosen_batch = filter_batch(chosen_batch) + rejected_batch = filter_batch(rejected_batch) + prompt_lens = filter_batch(prompt_lens) + is_bad_batch = False + + prompt_lens = torch.tensor(prompt_lens) + + chosen_batch = [ + torch.tensor(sequence, device=self._gangs.dp.device) + for sequence in chosen_batch + ] + chosen_batch = collate_with_target_mask( + chosen_batch, prompt_lens, device=self._gangs.dp.device + ) + + rejected_batch = [ + torch.tensor(sequence, device=self._gangs.dp.device) + for sequence in rejected_batch + ] + rejected_batch = collate_with_target_mask( + rejected_batch, prompt_lens, device=self._gangs.dp.device + ) + + batch = PreferenceBatch( + chosen=chosen_batch, + rejected=rejected_batch, + reference_score_chosen=None, + reference_score_rejected=None, ) return batch, is_bad_batch, reward_output @@ -527,7 +911,7 @@ def __init__( JudgmentExtractorHandler ) judgment_extractor_handler = judgment_extractor_registry.get(judgment_extractor) - self.judgment_extractor = judgment_extractor_handler.create() + self.judgment_extractor = judgment_extractor_handler.create(self.tokenizer) @override def process_rollouts( @@ -542,17 +926,18 @@ def process_rollouts( text_prompts = prompt_batch.meta_info.get(self.prompt_key) reference_answers = prompt_batch.meta_info.get(self.answer_key) + if reference_answers is None: + reference_answers = [None] * len(prompt_batch.prompts) for i, (i_batch_request_output, prompt_text) in enumerate( zip(vllm_outputs, text_prompts) ): rollouts_text = [] rollouts_tokens = [] - i_reference_answer = reference_answers[i] for rollout_output in i_batch_request_output.outputs: rollout_text = rollout_output.text vllm_input = self.judgment_extractor.format_prompt( - prompt_text, rollout_text, i_reference_answer + prompt_text, rollout_text, reference_answers[i] ) vllm_inputs.append(vllm_input) rollouts_text.append(rollout_output.text) @@ -565,6 +950,8 @@ def process_rollouts( vllm_inputs, dp_gang=self._gangs.dp, vllm_model=self.reward_model ) + log.info(f"Sample judgment: {batch_judgments[0].outputs[0].text}") + batch_rewards = [] for per_rollout_judgments in batch_judgments: per_rollout_rewards = [ @@ -672,12 +1059,18 @@ def create(self, reward_model, reward_name, reward_config, gangs, context): "Generative judges require implementing and specifying a judgment extractor" ) + if reward_config.pair_type is None: + raise RuntimeError( + "Pairwise generative judges require specifying how the pairs should be created" + ) + return GenerativePairwiseVerifier( gangs, context, reward_model, reward_name, judgment_extractor=reward_config.judgment_extractor, + pair_type=reward_config.pair_type, answer_key=reward_config.answer_key, prompt_key=reward_config.prompt_key, tokenizer=reward_config.tokenizer, @@ -702,6 +1095,7 @@ def __init__( reward_model, reward_name, judgment_extractor, + pair_type, answer_key, prompt_key, tokenizer, @@ -713,13 +1107,166 @@ def __init__( self.reward_model = reward_model self.reward_name = reward_name 
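+        # pair_type controls how judge comparisons are built per prompt:
+        # "all_pairs" judges every ordered pair of distinct rollouts (R*(R-1) judge calls),
+        # "pivot" judges each rollout against one randomly chosen pivot rollout in both orders (2*R calls),
+        # and "random_pairs" judges R randomly sampled ordered pairs (R calls).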
self.judgment_extractor = judgment_extractor + self.pair_type = pair_type self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) judgment_extractor_registry = self._context.get_registry( JudgmentExtractorHandler ) judgment_extractor_handler = judgment_extractor_registry.get(judgment_extractor) - self.judgment_extractor = judgment_extractor_handler.create() + self.judgment_extractor = judgment_extractor_handler.create(self.tokenizer) + + def construct_all_pairs( + self, + prompt_text, + i_batch_request_output, + vllm_inputs, + batch_pairwise_indices, + reference_answer, + ): + for a in range(len(i_batch_request_output.outputs)): + for b in range(len(i_batch_request_output.outputs)): + if a != b: + rollout_A_text = i_batch_request_output.outputs[a].text + rollout_B_text = i_batch_request_output.outputs[b].text + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, rollout_A_text, rollout_B_text, reference_answer + ) + vllm_inputs.append(vllm_input) + batch_pairwise_indices.append((a, b)) + + return vllm_inputs, batch_pairwise_indices + + def construct_pairs_with_pivot( + self, + prompt_text, + i_batch_request_output, + vllm_inputs, + batch_pairwise_indices, + batch_pivot_pos, + reference_answer, + ): + pivot_idx = random.randint(0, len(i_batch_request_output.outputs) - 1) + pivot_rollout = i_batch_request_output.outputs[pivot_idx].text + for a in range(len(i_batch_request_output.outputs)): + rollout_A_text = i_batch_request_output.outputs[a].text + rollout_B_text = pivot_rollout + + batch_pairwise_indices.append((a, pivot_idx)) + batch_pivot_pos.append(1) # specifies which position is the pivot index + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, rollout_A_text, rollout_B_text, reference_answer + ) + vllm_inputs.append(vllm_input) + + batch_pairwise_indices.append((pivot_idx, a)) + batch_pivot_pos.append(0) + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, rollout_B_text, rollout_A_text, reference_answer + ) + vllm_inputs.append(vllm_input) + + return vllm_inputs, batch_pairwise_indices, batch_pivot_pos + + def construct_random_pairs( + self, + prompt_text, + i_batch_request_output, + vllm_inputs, + batch_pairwise_indices, + reference_answer, + ): + all_pairs = [ + (i, j) + for i in range(len(i_batch_request_output.outputs)) + for j in range(len(i_batch_request_output.outputs)) + if i != j + ] + random_pairs = random.sample(all_pairs, len(i_batch_request_output.outputs)) + + for a in range(len(i_batch_request_output.outputs)): + for b in range(len(i_batch_request_output.outputs)): + if (a, b) in random_pairs: + rollout_A_text = i_batch_request_output.outputs[a].text + rollout_B_text = i_batch_request_output.outputs[b].text + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, rollout_A_text, rollout_B_text, reference_answer + ) + vllm_inputs.append(vllm_input) + batch_pairwise_indices.append((a, b)) + + return vllm_inputs, batch_pairwise_indices + + def convert_pairwise_rewards_to_pointwise( + self, + batch_pairwise_rewards, + batch_pairwise_indices, + batch_text, + batch_tokens, + pair_type, + batch_pivot_pos, + ): + B, R = len(batch_text), len(batch_text[0]) # batch size, rollouts + batch_pointwise_rewards = [] + + for i in range(B): + # Extract the pairwise rewards for each input + if pair_type == "pivot": + idx_start, idx_end = i * 2 * R, (i + 1) * 2 * R # 2R pairs + elif pair_type == "random_pairs": + idx_start, idx_end = i * R, (i + 1) * R # R pairs + elif pair_type == "all_pairs": + idx_start, idx_end = i * R * 
(R - 1), (i + 1) * R * ( + R - 1 + ) # R(R-1) pairs + + prompt_pairwise_rewards = batch_pairwise_rewards[idx_start:idx_end] + prompt_pairwise_indices = batch_pairwise_indices[idx_start:idx_end] + + # If not pivot, create dummy pivots because both rewards will be considered + prompt_pivot_pos = ( + batch_pivot_pos[idx_start:idx_end] + if pair_type == "pivot" + else [0] * (idx_end - idx_start) + ) + + # Sum the rewards for each rollout and count how many times each rollout appears in pairwise judgments + prompt_rewards = [0.0] * R + counts = [0] * R + + for index, rewards, pivot_pos in zip( + prompt_pairwise_indices, prompt_pairwise_rewards, prompt_pivot_pos + ): + non_pivot_pos = 1 - pivot_pos + prompt_rewards[index[non_pivot_pos]] += rewards[non_pivot_pos] + counts[index[non_pivot_pos]] += 1 + + # If not a pivot setup, also credit the rollout in the other position of the pair + if pair_type != "pivot": + prompt_rewards[index[pivot_pos]] += rewards[pivot_pos] + counts[index[pivot_pos]] += 1 + + log.info(f"Counts of each rollout: {counts}") + + log.info(f"Number of rollouts wrt batch tokens = {len(batch_tokens[i])}") + assert len(batch_tokens[i]) == R + + # Compute average pointwise rewards + avg_prompt_rewards = [0.0] * R + for j in range(len(prompt_rewards)): + if counts[j] > 0: + avg_prompt_rewards[j] = round(prompt_rewards[j] / counts[j], 4) + # num_tokens = len(batch_tokens[i][j]) + # log.info(f"Num tokens: {num_tokens}") + # correctness_reward = prompt_rewards[j] / counts[j] + # log.info(f"Correctness reward: {correctness_reward}") + # length_penalty = 0.001 * num_tokens + # avg_prompt_rewards[j] = round(correctness_reward - length_penalty, 4) + log.info(f"Overall reward: {avg_prompt_rewards[j]}") + + batch_pointwise_rewards.append(avg_prompt_rewards) + + return batch_pointwise_rewards @override def process_rollouts( @@ -729,11 +1276,15 @@ def process_rollouts( batch_text = [] batch_tokens = [] batch_pairwise_indices = [] + batch_pivot_pos = [] if vllm_outputs is None: vllm_outputs = [None] * len(prompt_batch.prompts) text_prompts = prompt_batch.meta_info.get(self.prompt_key) + reference_answers = prompt_batch.meta_info.get(self.answer_key) + if reference_answers is None: + reference_answers = [None] * len(prompt_batch.prompts) for i, (i_batch_request_output, prompt_text) in enumerate( zip(vllm_outputs, text_prompts) ): @@ -747,19 +1298,35 @@ def process_rollouts( batch_text.append(rollouts_text) batch_tokens.append(rollouts_tokens) - prompt_pairwise_indices = [] - for a in range(len(i_batch_request_output.outputs)): - for b in range(len(i_batch_request_output.outputs)): - if a != b: - rollout_A_text = i_batch_request_output.outputs[a].text - rollout_B_text = i_batch_request_output.outputs[b].text - vllm_input = self.judgment_extractor.format_prompt( - prompt_text, rollout_A_text, rollout_B_text - ) - vllm_inputs.append(vllm_input) - prompt_pairwise_indices.append((a, b)) - - batch_pairwise_indices.append(prompt_pairwise_indices) + if self.pair_type == "all_pairs": + vllm_inputs, batch_pairwise_indices = self.construct_all_pairs( + prompt_text, + i_batch_request_output, + vllm_inputs, + batch_pairwise_indices, + reference_answers[i], + ) + elif self.pair_type == "pivot": + ( + vllm_inputs, + batch_pairwise_indices, + batch_pivot_pos, + ) = self.construct_pairs_with_pivot( + prompt_text, + i_batch_request_output, + vllm_inputs, + batch_pairwise_indices, + batch_pivot_pos, + reference_answers[i], + ) + elif self.pair_type == "random_pairs": + vllm_inputs, batch_pairwise_indices
= self.construct_random_pairs( + prompt_text, + i_batch_request_output, + vllm_inputs, + batch_pairwise_indices, + reference_answers[i], + ) batch_pairwise_judgments = generate_rewards_generative( vllm_inputs, @@ -767,6 +1334,12 @@ def process_rollouts( vllm_model=self.reward_model, ) + log.info(f"Number of pairwise comparisons: {len(batch_pairwise_judgments)}") + log.info( + f"Number of judgments per pairwise comparison: {len(batch_pairwise_judgments[0].outputs)}" + ) + log.info(f"Sample judgment: {batch_pairwise_judgments[0].outputs[0].text}") + batch_pairwise_rewards = [] for per_rollout_judgments in batch_pairwise_judgments: per_rollout_rewards = [ @@ -777,29 +1350,346 @@ def process_rollouts( self.judgment_extractor.aggregate(per_rollout_rewards) ) + batch_rewards = self.convert_pairwise_rewards_to_pointwise( + batch_pairwise_rewards, + batch_pairwise_indices, + batch_text, + batch_tokens, + self.pair_type, + batch_pivot_pos, + ) + + log.info(f"Batch Rewards: {batch_rewards}") + + return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards} + + def prepare_preference_batch( + self, prompt_batch: PromptBatch, rollouts + ) -> PreferenceBatch: + + reward_output = self.process_rollouts(rollouts, prompt_batch) + + chosen_batch = [] + rejected_batch = [] + prompt_lens = [] + dummy_batch_ids = [] # keep posiitons of dummy pairs here + + # choosing first rollouts with reward 1 as chosen and 0 as rejected (sort of random given that we sample rollouts randomly) + for i_batch, (i_batch_rewards, i_batch_tokens) in enumerate( + zip(reward_output["rewards"], reward_output["tokens"]) + ): + + chosen_rollout_position = i_batch_rewards.index(max(i_batch_rewards)) + rejected_rollout_position = i_batch_rewards.index(min(i_batch_rewards)) + + if chosen_rollout_position == rejected_rollout_position: + # cant form preference pair when we dont have such rollouts + # this will be dummy batch and we zero out loss + dummy_batch_ids.append(i_batch) + + chosen_rollout_tokens = list(i_batch_tokens[chosen_rollout_position]) + rejected_rollout_tokens = list(i_batch_tokens[rejected_rollout_position]) + prompt_tokens = prompt_batch.prompts[i_batch] + + chosen_tokens = prompt_tokens + chosen_rollout_tokens + chosen_batch.append(chosen_tokens) + + rejected_tokens = prompt_tokens + rejected_rollout_tokens + rejected_batch.append(rejected_tokens) + + prompt_lens.append(len(prompt_tokens)) + + filter_batch = lambda batch: [ + item for index, item in enumerate(batch) if index not in dummy_batch_ids + ] + + if len(dummy_batch_ids) == len(reward_output["tokens"]): + # entire batch does not have a valid preference pair + # we use it as dummy batch and zero the loss in the end + is_bad_batch = True + else: + # removing dummy pairs from the batch + chosen_batch = filter_batch(chosen_batch) + rejected_batch = filter_batch(rejected_batch) + prompt_lens = filter_batch(prompt_lens) + is_bad_batch = False + + prompt_lens = torch.tensor(prompt_lens) + + chosen_batch = [ + torch.tensor(sequence, device=self._gangs.dp.device) + for sequence in chosen_batch + ] + chosen_batch = collate_with_target_mask( + chosen_batch, prompt_lens, device=self._gangs.dp.device + ) + + rejected_batch = [ + torch.tensor(sequence, device=self._gangs.dp.device) + for sequence in rejected_batch + ] + rejected_batch = collate_with_target_mask( + rejected_batch, prompt_lens, device=self._gangs.dp.device + ) + + batch = PreferenceBatch( + chosen=chosen_batch, + rejected=rejected_batch, + reference_score_chosen=None, + 
reference_score_rejected=None, + ) + + return batch, is_bad_batch, reward_output + + +class GenerativeKwiseVerifierHandler(VLLMOutputRewardHandler): + def __init__(self): + pass + + @override + def create(self, reward_model, reward_name, reward_config, gangs, context): + if reward_config.tokenizer is None: + raise RuntimeError("Generative judges require tokenizer") + + if reward_config.judgment_extractor is None: + raise RuntimeError( + "Generative judges require implementing and specifying a judgment extractor" + ) + + if reward_config.k is None: + raise RuntimeError( + "Kwise generative judges require specifying the size of the tuple k" + ) + + return GenerativeKwiseVerifier( + gangs, + context, + reward_model, + reward_name, + judgment_extractor=reward_config.judgment_extractor, + k=reward_config.k, + answer_key=reward_config.answer_key, + prompt_key=reward_config.prompt_key, + tokenizer=reward_config.tokenizer, + ) + + @property + @override + def name(self): + return "generative_kwise_verifier" + + @property + @override + def config_kls(self): + return None + + +class GenerativeKwiseVerifier(VLLMOutputReward): + def __init__( + self, + gangs, + context, + reward_model, + reward_name, + judgment_extractor, + k, + answer_key, + prompt_key, + tokenizer, + ): + self.answer_key = answer_key + self.prompt_key = prompt_key + self._gangs = gangs + self._context = context + self.reward_model = reward_model + self.reward_name = reward_name + self.judgment_extractor = judgment_extractor + self.k = k + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + + judgment_extractor_registry = self._context.get_registry( + JudgmentExtractorHandler + ) + judgment_extractor_handler = judgment_extractor_registry.get(judgment_extractor) + self.judgment_extractor = judgment_extractor_handler.create( + self.tokenizer, self.k + ) + + # def construct_all_k_tuples( + # self, + # prompt_text, + # i_batch_request_output, + # vllm_inputs, + # batch_kwise_indices, + # reference_answer, + # R, + # k, + # ): + # all_k_tuples = list(itertools.combinations(list(range(R)), k)) + # for k_tuple in all_k_tuples: + # k_list = list(k_tuple) + # random.shuffle(k_list) + # batch_kwise_indices.append(k_list) + # response_string = "" + # for assistant_id, idx in enumerate(k_list): + # rollout = i_batch_request_output.outputs[idx].text + # response_string += f"[Start of Assistant {assistant_id+1} Answer]\n{rollout}\n[End of Assistant {assistant_id+1} Answer]\n\n" + # response_string = response_string.strip() + + # vllm_input = self.judgment_extractor.format_prompt( + # prompt_text, response_string, reference_answer + # ) + # vllm_inputs.append(vllm_input) + + # return vllm_inputs, batch_kwise_indices + + def construct_all_k_tuples( + self, + prompt_text, + i_batch_request_output, + vllm_inputs, + batch_kwise_indices, + reference_answer, + R, + k, + ): + all_k_tuples = list(itertools.combinations(list(range(R)), k)) + for k_tuple in all_k_tuples: + for k_list in itertools.permutations(k_tuple): + k_list = list(k_list) + batch_kwise_indices.append(k_list) + response_list = [i_batch_request_output.outputs[idx].text for idx in k_list] + + vllm_input = self.judgment_extractor.format_prompt( + prompt_text, response_list, reference_answer + ) + vllm_inputs.append(vllm_input) + return vllm_inputs, batch_kwise_indices + + + def convert_kwise_rewards_to_pointwise( + self, + batch_kwise_rewards, + batch_kwise_indices, + batch_text, + batch_tokens, + k, + ): B, R = len(batch_text), len(batch_text[0]) # batch size, rollouts + 
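+        # Each rollout's pointwise reward is the average of the judge scores it
+        # receives over all k-permutations that include it; counts[j] tracks how
+        # many judged tuples rollout j appears in.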
batch_pointwise_rewards = [] - # Logic to convert pairwise scores into pointwise rewards - # Can be done differently too - batch_rewards = [] - for i in range(B): - prompt_pairwise_rewards = batch_pairwise_rewards[ - i * R * (R - 1) : (i + 1) * R * (R - 1) - ] - prompt_pairwise_indices = batch_pairwise_indices[i] + for prompt_idx in range(B): + # Extract the kwise rewards for each input + # num = math.comb(R, k) + num = math.perm(R, k) + idx_start, idx_end = ( + prompt_idx * num, + (prompt_idx + 1) * num, + ) # R choose k tuples + + prompt_kwise_rewards = batch_kwise_rewards[idx_start:idx_end] + prompt_kwise_indices = batch_kwise_indices[idx_start:idx_end] + + # Sum the rewards for each rollout and count how many times each rollout appears in pairwise judgments prompt_rewards = [0.0] * R - for index, rewards in zip(prompt_pairwise_indices, prompt_pairwise_rewards): - prompt_rewards[index[0]] += rewards[0] - prompt_rewards[index[1]] += rewards[1] - - # Average score over 2*(R-1) pairwise comparisons - if (R - 1) > 0: - prompt_rewards = [ - round(prompt_reward / (2 * (R - 1)), 4) - for prompt_reward in prompt_rewards - ] - - batch_rewards.append(prompt_rewards) + counts = [0] * R + + # For example, indices would be [0, 3, 4] which means rollout 0, 3 and 4 + # rewards would be [7, 8, 9] which means rollout 0 has reward 7 and so on + for indices, rewards in zip(prompt_kwise_indices, prompt_kwise_rewards): + for rollout_idx in range(k): + prompt_rewards[indices[rollout_idx]] += rewards[rollout_idx] + counts[indices[rollout_idx]] += 1 + + log.info(f"Counts of each rollout: {counts}") + + log.info( + f"Number of rollouts wrt batch tokens = {len(batch_tokens[prompt_idx])}" + ) + assert len(batch_tokens[prompt_idx]) == R + + # Compute average pointwise rewards + avg_prompt_rewards = [0.0] * R + for j in range(R): + if counts[j] > 0: + avg_prompt_rewards[j] = round(prompt_rewards[j] / counts[j], 4) + log.info(f"Overall reward: {avg_prompt_rewards[j]}") + + batch_pointwise_rewards.append(avg_prompt_rewards) + + return batch_pointwise_rewards + + @override + def process_rollouts( + self, vllm_outputs: list[RequestOutput], prompt_batch: PromptBatch + ): + vllm_inputs = [] + batch_text = [] + batch_tokens = [] + batch_kwise_indices = [] + + if vllm_outputs is None: + vllm_outputs = [None] * len(prompt_batch.prompts) + + text_prompts = prompt_batch.meta_info.get(self.prompt_key) + reference_answers = prompt_batch.meta_info.get(self.answer_key) + if reference_answers is None: + reference_answers = [None] * len(prompt_batch.prompts) + for i, (i_batch_request_output, prompt_text) in enumerate( + zip(vllm_outputs, text_prompts) + ): + rollouts_text = [ + rollout_output.text for rollout_output in i_batch_request_output.outputs + ] + rollouts_tokens = [ + rollout_output.token_ids + for rollout_output in i_batch_request_output.outputs + ] + batch_text.append(rollouts_text) + batch_tokens.append(rollouts_tokens) + + R = len(rollouts_text) + vllm_inputs, batch_kwise_indices = self.construct_all_k_tuples( + prompt_text, + i_batch_request_output, + vllm_inputs, + batch_kwise_indices, + reference_answers[i], + R, + self.k, + ) + + batch_kwise_judgments = generate_rewards_generative( + vllm_inputs, + dp_gang=self._gangs.dp, + vllm_model=self.reward_model, + ) + + log.info(f"Number of kwise comparisons: {len(batch_kwise_judgments)}") + log.info( + f"Number of judgments per kwise comparison: {len(batch_kwise_judgments[0].outputs)}" + ) + log.info(f"Sample judgment: {batch_kwise_judgments[0].outputs[0].text}") + + 
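# Illustrative, self-contained sketch (not part of the patch) of the k-wise ->
# pointwise averaging performed by convert_kwise_rewards_to_pointwise above,
# using toy judge scores for R = 3 rollouts scored in k = 2 permutations.
import itertools

R, k = 3, 2
kwise_indices = [list(p) for p in itertools.permutations(range(R), k)]
# one assumed judge score per position of each ordered tuple, aligned with kwise_indices
kwise_rewards = [[0.9, 0.1], [0.7, 0.3], [0.2, 0.8], [0.5, 0.5], [0.4, 0.6], [0.6, 0.4]]

sums, counts = [0.0] * R, [0] * R
for indices, scores in zip(kwise_indices, kwise_rewards):
    for pos in range(k):
        sums[indices[pos]] += scores[pos]
        counts[indices[pos]] += 1

pointwise = [round(s / c, 4) if c else 0.0 for s, c in zip(sums, counts)]
print(pointwise)  # [0.75, 0.3, 0.45] -- one averaged score per rollout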
batch_kwise_rewards = [] + for per_rollout_judgments in batch_kwise_judgments: + per_rollout_rewards = [ + self.judgment_extractor.extract(judgment.text) + for judgment in per_rollout_judgments.outputs + ] + batch_kwise_rewards.append( + self.judgment_extractor.aggregate(per_rollout_rewards) + ) + + batch_rewards = self.convert_kwise_rewards_to_pointwise( + batch_kwise_rewards, + batch_kwise_indices, + batch_text, + batch_tokens, + self.k, + ) + + log.info(f"Batch Rewards: {batch_rewards}") return {"text": batch_text, "tokens": batch_tokens, "rewards": batch_rewards} diff --git a/src/fairseq2/recipes/lm/_online_finetune/third_party/ace_math.py b/src/fairseq2/recipes/lm/_online_finetune/third_party/ace_math.py new file mode 100644 index 000000000..3fc33275d --- /dev/null +++ b/src/fairseq2/recipes/lm/_online_finetune/third_party/ace_math.py @@ -0,0 +1,33 @@ +import torch +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +from fairseq2.logging import log + + +class AceMathRMPipeline: + def __init__(self, *args, **kwargs): + model_path = "/datasets/pretrained-llms/AceMath-7B-RM" + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True + ) + self.model = AutoModelForSequenceClassification.from_pretrained( + model_path, + num_labels=1, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + device_map="auto", + ).eval() + self.model.config.pad_token_id = self.tokenizer.pad_token_id + + def __call__(self, prompt_chunk): + inputs = self.tokenizer( + prompt_chunk, return_tensors="pt", padding=True, add_special_tokens=False + ).to(self.model.device) + + outputs = self.model(**inputs)[0] + log.info(f"outputs = {outputs}") + rewards = [output[0] for output in outputs] + + log.info(f"Length of rewards = {len(rewards)}") + + return rewards diff --git a/src/fairseq2/recipes/lm/_online_finetune/third_party/skywork.py b/src/fairseq2/recipes/lm/_online_finetune/third_party/skywork.py new file mode 100644 index 000000000..3a60da8e7 --- /dev/null +++ b/src/fairseq2/recipes/lm/_online_finetune/third_party/skywork.py @@ -0,0 +1,30 @@ +import torch +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +from fairseq2.logging import log + + +class SkyworkRMPipeline: + def __init__(self, *args, **kwargs): + model_path = "/datasets/pretrained-llms/Skywork-Reward-V2-Llama-3.1-8B" + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True + ) + self.model = AutoModelForSequenceClassification.from_pretrained( + model_path, + num_labels=1, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + device_map="auto", + ).eval() + self.model.config.pad_token_id = self.tokenizer.pad_token_id + + def __call__(self, prompt_chunk): + inputs = self.tokenizer( + prompt_chunk, return_tensors="pt", padding=True, add_special_tokens=False + ).to(self.model.device) + + outputs = self.model(**inputs)[0] + rewards = [output[0] for output in outputs] + + return rewards diff --git a/src/fairseq2/setup/_metrics.py b/src/fairseq2/setup/_metrics.py index 0a38a92f2..071d78044 100644 --- a/src/fairseq2/setup/_metrics.py +++ b/src/fairseq2/setup/_metrics.py @@ -79,6 +79,8 @@ def register(name: str, *args: Any, **kwargs: Any) -> None: register("simpo_loss", "SimPO Loss", 0, format_as_float) register("grpo_loss", "GRPO Loss", 0, format_as_float) register("avg_reward", "Reward", 1, format_as_float) + register("avg_second_reward", "Second Reward", 1, format_as_float) + register("reward_matches", "Reward Matches", 1, format_as_float) 
register("std_reward", "StdDev Reward", 1, format_as_float) register("avg_reward_len_norm","Length Normalized Reward", 1, format_as_float) register("chosen_logps", "Chosen Sequence Log Probabilities", 50, format_as_float) diff --git a/src/fairseq2/setup/_po_finetune_units.py b/src/fairseq2/setup/_po_finetune_units.py index 4db4bd48e..b52b880b3 100644 --- a/src/fairseq2/setup/_po_finetune_units.py +++ b/src/fairseq2/setup/_po_finetune_units.py @@ -11,19 +11,25 @@ from fairseq2.context import RuntimeContext from fairseq2.recipes.lm import ( # GroupDpoFinetuneUnitHandler, AtheneVerifierHandler, + SkyworkVerifierHandler, + AceMathVerifierHandler, CpoFinetuneUnitHandler, DpoFinetuneUnitHandler, GeneralVerifierExtractorHandler, GenerativePairwiseVerifierHandler, + GenerativeKwiseVerifierHandler, GenerativePointwiseVerifierHandler, GrpoFinetuneUnitHandler, GSM8kVerifierHandler, J1PairwiseScoreExtractorHandler, + J1KwiseScoreExtractorHandler, J1PointwiseExtractorHandler, JudgmentExtractorHandler, MathVerifyHandler, NoEnvAtheneRewardPipeline, NoEnvGeneralVerifierPipeline, + NoEnvAceMathRMPipeline, + NoEnvSkyworkRMPipeline, OnlineDpoFinetuneUnitHandler, OnlineFinetuneUnitHandler, OrpoFinetuneUnitHandler, @@ -86,6 +92,14 @@ def _register_online_finetune_units(context: RuntimeContext) -> None: # GSM8kVerifier handler = GSM8kVerifierHandler() registry.register(handler.name, handler) + + # SkyworkVerifier + handler = SkyworkVerifierHandler() + registry.register(handler.name, handler) + + # AceMath RM + handler = AceMathVerifierHandler() + registry.register(handler.name, handler) # AtheneVerifier handler = AtheneVerifierHandler() @@ -102,6 +116,10 @@ def _register_online_finetune_units(context: RuntimeContext) -> None: # GenerativePairwiseVerifier handler = GenerativePairwiseVerifierHandler() registry.register(handler.name, handler) + + # GenerativeKwiseVerifier + handler = GenerativeKwiseVerifierHandler() + registry.register(handler.name, handler) registry = context.get_registry(RemoteModelHandler) @@ -112,6 +130,14 @@ def _register_online_finetune_units(context: RuntimeContext) -> None: # NoEnvGeneralVerifierPipeline handler = NoEnvGeneralVerifierPipeline registry.register(handler.name, handler) + + # NoEnvAceMathRMPipeline + handler = NoEnvAceMathRMPipeline + registry.register(handler.name, handler) + + # NoEnvAceMathRMPipeline + handler = NoEnvSkyworkRMPipeline + registry.register(handler.name, handler) # Generative judgment extractors registry = context.get_registry(JudgmentExtractorHandler) @@ -121,6 +147,9 @@ def _register_online_finetune_units(context: RuntimeContext) -> None: handler = J1PairwiseScoreExtractorHandler() registry.register(handler.name, handler) + + handler = J1KwiseScoreExtractorHandler() + registry.register(handler.name, handler) handler = GeneralVerifierExtractorHandler() registry.register(handler.name, handler)