@@ -1,8 +1,70 @@
 import torch
-from math_verify import parse, verify
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
 
 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
 
+CANNOT_PARSE_GT_ANSWER = -1
+CANNOT_PARSE_PREDICTION = -2
+SUCCESS = 1
+MATCHING_FAIL = 0
+
+
+def verify_math_representation(completion, gt_answer):
+    """
+    Verify if the completion is a valid math representation of the gt_answer.
+    """
+    target = (
+        ExprExtractionConfig(),
+        LatexExtractionConfig(
+            normalization_config=NormalizationConfig(
+                nits=False,
+                malformed_operators=False,
+                basic_latex=True,
+                boxed="all",
+                units=True,
+            ),
+            boxed_match_priority=0,
+        ),
+    )
+    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
+        raise ValueError("gt_answer should be a string, please verify your training data.")
+    if not isinstance(completion, str) or len(completion) == 0:
+        return MATCHING_FAIL
+    try:
+        parsed_gt_answer = parse(gt_answer, extraction_config=target)
+        if len(parsed_gt_answer) == 0:
+            return CANNOT_PARSE_GT_ANSWER
+        parsed_completion = parse(completion, extraction_config=target)
+        if len(parsed_completion) == 0:
+            return CANNOT_PARSE_PREDICTION
+        if verify(parsed_gt_answer, parsed_completion):
+            return SUCCESS
+        else:
+            return MATCHING_FAIL
+    except Exception:
+        return MATCHING_FAIL
+
+
+def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
+    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
+    if math_verify_result == SUCCESS:
+        ans_acc += 1
+        reward += acc_score
+    elif math_verify_result == CANNOT_PARSE_GT_ANSWER or math_verify_result == CANNOT_PARSE_PREDICTION:
+        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(
+            ",", ""
+        ) == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", ""):
+            ans_acc += 1
+            if math_verify_result == CANNOT_PARSE_GT_ANSWER:
+                # the ground-truth answer cannot be parsed as math, but the plain-text match shows the answer is correct
+                reward += acc_score
+            else:
+                reward += (
+                    acc_score / 2
+                )  # not a valid LaTeX math representation, but the answer matches as plain text, so it receives half of the score
+    return reward, ans_acc
+
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
@@ -36,9 +98,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1
 
     # Check answer accuracy; the answer is considered correct only if it is right and the format is valid
-    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
-        ans_acc += 1
-        reward += acc_score
+    if format_valid and final_answer is not None:
+        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
 
     reward = reward + length_reward
 
|
@@ -88,9 +149,8 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         reward += format_score
 
     # Check answer accuracy; the answer is considered correct only if it is right and the format is valid
-    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
-        ans_acc += 1
-        reward += acc_score
+    if format_valid and final_answer is not None:
+        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
 
     reward = reward + length_reward
     if not eval_mode:
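
For reference, a minimal usage sketch of the two helpers introduced by this diff. It is not part of the commit: the import path is an assumption (adjust it to wherever reward_fn.py lives in your package, since the module itself uses relative imports), and the sample strings are illustrative only.

# Minimal usage sketch; the import path below is assumed, not taken from this commit.
from reward_fn import MATCHING_FAIL, SUCCESS, verify_math_representation, verify_model_answer

# Equivalent values written in different notations verify as a match.
print(verify_math_representation("\\boxed{\\frac{1}{2}}", "1/2"))  # expected: 1 (SUCCESS)

# A different value fails to match.
print(verify_math_representation("\\boxed{2}", "3"))  # expected: 0 (MATCHING_FAIL)

# verify_model_answer folds the verification result into the running reward and
# accuracy: full acc_score on a math match, full score when only the ground truth
# is unparseable but the plain-text strings match, and half score when the
# prediction is unparseable but still matches the ground truth as plain text.
reward, ans_acc = verify_model_answer("\\boxed{\\frac{1}{2}}", "1/2", ans_acc=0, acc_score=10.0, reward=0.0)
print(reward, ans_acc)  # expected: 10.0 1

Compared with the previous inline verify(parse(gt_answer), parse(final_answer)) check, the helper distinguishes parse failures from genuine mismatches, which is what enables the plain-text fallback and the half-score case.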