
Commit d06042b

rewrite reward fn
1 parent: a6085ff · commit: d06042b

File tree: 2 files changed (+68, -8 lines)

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -127,7 +127,7 @@ def loop(self) -> None:
                 eval_statistics = {
                     k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
                 }
-                eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
+                eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
                 if dist.get_rank() == 0:
                     if hasattr(self, "wandb_run"):
                         self.wandb_run.log(eval_statistics, step=eval_global_step)
```
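The only change here prefixes each evaluation metric key with "eval/" before logging. Weights & Biases groups keys that share a slash-delimited prefix into one panel section, so the renamed metrics land in a dedicated "eval" section instead of mixing with training charts. A minimal sketch of the effect (the project name and metric values are illustrative, not from the commit):

```python
import wandb

run = wandb.init(project="demo-grpo")  # hypothetical project name
stats = {"accuracy": 0.82, "reward": 0.41}  # illustrative values
# Prefixing the keys groups these metrics under an "eval" section in the W&B UI.
run.log({"eval/" + k: v for k, v in stats.items()}, step=100)
```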

applications/ColossalChat/coati/distributed/reward/reward_fn.py

Lines changed: 67 additions & 7 deletions
```diff
@@ -1,8 +1,70 @@
 import torch
-from math_verify import parse, verify
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
 
 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
 
+CANNOT_PARSE_GT_ANSWER = -1
+CANNOT_PARSE_PREDICTION = -2
+SUCCESS = 1
+MATCHING_FAIL = 0
+
+
+def verify_math_representation(completion, gt_answer):
+    """
+    Verify if the completion is a valid math representation of the gt_answer.
+    """
+    target = (
+        ExprExtractionConfig(),
+        LatexExtractionConfig(
+            normalization_config=NormalizationConfig(
+                nits=False,
+                malformed_operators=False,
+                basic_latex=True,
+                boxed="all",
+                units=True,
+            ),
+            boxed_match_priority=0,
+        ),
+    )
+    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
+        raise ValueError("gt_answer should be a string, please verify your training data.")
+    if not isinstance(completion, str) or len(completion) == 0:
+        return MATCHING_FAIL
+    try:
+        parsed_gt_answer = parse(gt_answer, extraction_config=target)
+        if len(parsed_gt_answer) == 0:
+            return CANNOT_PARSE_GT_ANSWER
+        parsed_completion = parse(completion, extraction_config=target)
+        if len(parsed_completion) == 0:
+            return CANNOT_PARSE_PREDICTION
+        if verify(parsed_gt_answer, parsed_completion):
+            return SUCCESS
+        else:
+            return MATCHING_FAIL
+    except Exception:
+        return MATCHING_FAIL
+
+
+def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
+    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
+    if math_verify_result == SUCCESS:
+        ans_acc += 1
+        reward += acc_score
+    elif math_verify_result == CANNOT_PARSE_GT_ANSWER or math_verify_result == CANNOT_PARSE_PREDICTION:
+        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(
+            ",", ""
+        ) == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", ""):
+            ans_acc += 1
+            if math_verify_result == CANNOT_PARSE_GT_ANSWER:
+                # plain text answer cannot be parsed, but is correct
+                reward += acc_score
+            else:
+                reward += (
+                    acc_score / 2
+                )  # not a valid latex math representation, but the answer is correct, receive half of the score
+    return reward, ans_acc
+
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
```
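The new verify_math_representation helper parses both strings with an expression extractor and a normalized-LaTeX extractor, then compares them with math_verify's verify; the four module-level constants distinguish a match, a mismatch, and the two parse-failure cases. A hedged usage sketch, assuming the file is importable as coati.distributed.reward.reward_fn and that math_verify treats 1/2 and 0.5 as equivalent:

```python
from coati.distributed.reward.reward_fn import (
    MATCHING_FAIL,
    SUCCESS,
    verify_math_representation,
)

# Equivalent values in different notations should verify as SUCCESS.
print(verify_math_representation("\\boxed{\\frac{1}{2}}", "0.5") == SUCCESS)
# A different value fails the comparison rather than raising.
print(verify_math_representation("\\boxed{2}", "0.5") == MATCHING_FAIL)
```

Returning CANNOT_PARSE_GT_ANSWER (-1) and CANNOT_PARSE_PREDICTION (-2) as distinct codes, rather than collapsing them into MATCHING_FAIL, is what lets verify_model_answer apply the plain-text fallback shown below.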
```diff
@@ -36,9 +98,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
-        ans_acc += 1
-        reward += acc_score
+    if format_valid and final_answer is not None:
+        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
 
     reward = reward + length_reward
 
```
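When parsing fails on either side, verify_model_answer falls back to a string comparison that ignores whitespace, braces, and commas. A standalone sketch of that normalization (the _normalize helper name is illustrative, not in the codebase):

```python
def _normalize(text: str) -> str:
    # Same transformation the diff applies inline: strip surrounding
    # whitespace, then drop all spaces, braces, and commas.
    return text.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")

# "1 024" with digit grouping and "{1,024}" in braces compare equal.
print(_normalize("{1,024}") == _normalize("1 024"))  # True
```

A match under this fallback still increments ans_acc, but the reward is halved when it was the prediction, rather than the ground truth, that failed to parse.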
```diff
@@ -88,9 +149,8 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         reward += format_score
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
-        ans_acc += 1
-        reward += acc_score
+    if format_valid and final_answer is not None:
+        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
 
     reward = reward + length_reward
     if not eval_mode:
```
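boxed_math_reward_fn receives the identical change: both reward functions now thread the running reward and ans_acc through verify_model_answer instead of calling parse/verify inline. A hedged end-to-end sketch of that delegation (same assumed import path as above; acc_score=10.0 is arbitrary):

```python
from coati.distributed.reward.reward_fn import verify_model_answer

# Full credit is expected when math_verify can parse and match both sides.
reward, ans_acc = verify_model_answer(
    decoded_final_answer="\\boxed{0.5}",
    gt_answer="1/2",
    ans_acc=0,
    acc_score=10.0,
    reward=0.0,
)
print(reward, ans_acc)  # expected 10.0 1, assuming math_verify matches the two forms
```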
