Commit c8e670b

Configurable reward functions (#2698)
1 parent 541c730 commit c8e670b

10 files changed: +314 -220 lines changed
recipes/configs/dev/qwen3B_async_grpo.yaml (+11)

@@ -97,6 +97,17 @@ training:
   epsilon: 0.2
   seed: null
 
+reward_functions:
+  - _component_: torchtune.dev.rl.rewards.FormattedMathCorrectnessReward
+    answer_tag: answer
+    positive_reward: 10.0
+    negative_reward: 0.0
+  - _component_: torchtune.dev.rl.rewards.ThinkingAnswerFormattingReward
+    think_tag: think
+    answer_tag: answer
+    positive_reward: 1.0
+    negative_reward: 0.0
+
 # All logging args
 metric_logger:
   _component_: torchtune.training.metric_logging.WandBLogger
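
Each entry under reward_functions follows torchtune's _component_ convention, so a recipe can build its reward stack directly from config instead of hard-coding reward logic. Below is a minimal sketch of how a recipe might instantiate the configured rewards with torchtune.config.instantiate; the setup_rewards helper name is an illustration for this note, not part of the commit.

from omegaconf import DictConfig, OmegaConf
from torchtune import config

def setup_rewards(cfg: DictConfig) -> list:
    # Instantiate each configured reward via its _component_ path;
    # fall back to an empty list when reward_functions is omitted.
    return [config.instantiate(rf_cfg) for rf_cfg in cfg.get("reward_functions", [])]

# Example: build one reward from a config fragment like the YAML above.
cfg = OmegaConf.create(
    {
        "reward_functions": [
            {
                "_component_": "torchtune.dev.rl.rewards.ThinkingAnswerFormattingReward",
                "think_tag": "think",
                "answer_tag": "answer",
                "positive_reward": 1.0,
                "negative_reward": 0.0,
            }
        ]
    }
)
rewards = setup_rewards(cfg)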

recipes/configs/dev/qwen3B_sync_grpo.yaml (+10)

@@ -107,6 +107,16 @@ compile: False # pytorch compile, set to true for better perf/memory
 # Reduced precision
 dtype: bf16
 
+reward_functions:
+  - _component_: torchtune.dev.rl.rewards.FormattedMathCorrectnessReward
+    answer_tag: answer
+    positive_reward: 10.0
+    negative_reward: 0.0
+  - _component_: torchtune.dev.rl.rewards.ThinkingAnswerFormattingReward
+    think_tag: think
+    answer_tag: answer
+    positive_reward: 1.0
+    negative_reward: 0.0
 
 # Logging
 metric_logger:
New file (+5)

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
New file (+47)

@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+from torchtune.dev.rl.rewards import RewardOutput
+
+
+class TestRewardOutput:
+    @pytest.fixture
+    def sample_reward_output(self):
+        return RewardOutput(
+            reward_base_name="test_reward",
+            total_reward=torch.tensor([1.0, 2.0, 3.0]),
+            successes=torch.tensor([1.0, 0.0, 1.0]),
+            rewards={
+                "sub_reward_1": torch.tensor([0.5, 1.5, 2.5]),
+                "sub_reward_2": torch.tensor([10.0, 20.0, 30.0]),
+            },
+        )
+
+    def test_log(self, sample_reward_output):
+        log_dict = sample_reward_output.log(prefix="train")
+        expected_log = {
+            "train/test_reward/sub_reward_1": 1.5,
+            "train/test_reward/sub_reward_2": 20.0,
+            "train/test_reward": 2.0,
+            "train/test_reward/successes": 2.0 / 3.0,
+        }
+        assert log_dict.keys() == expected_log.keys()
+        for key in expected_log:
+            assert log_dict[key] == pytest.approx(expected_log[key])
+
+    def test_log_no_prefix(self, sample_reward_output):
+        log_dict = sample_reward_output.log()
+        expected_log = {
+            "test_reward/sub_reward_1": 1.5,
+            "test_reward/sub_reward_2": 20.0,
+            "test_reward": 2.0,
+            "test_reward/successes": 2.0 / 3.0,
+        }
+        assert log_dict.keys() == expected_log.keys()
+        for key in expected_log:
+            assert log_dict[key] == pytest.approx(expected_log[key])
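
The test above pins down the logging contract of RewardOutput: each sub-reward's batch mean is reported under prefix/base_name/key, the mean total reward under prefix/base_name, and the success rate under prefix/base_name/successes. The following is a rough re-creation of a container that satisfies that contract, for illustration only; the real RewardOutput in torchtune.dev.rl.rewards may be implemented differently.

from dataclasses import dataclass, field
from typing import Dict, Optional

import torch

@dataclass
class RewardOutputSketch:
    # Illustrative stand-in for torchtune.dev.rl.rewards.RewardOutput.
    reward_base_name: str
    total_reward: torch.Tensor                 # shape [batch]
    successes: torch.Tensor                    # shape [batch], 1.0 on success
    rewards: Dict[str, torch.Tensor] = field(default_factory=dict)

    def log(self, prefix: Optional[str] = None) -> Dict[str, float]:
        # Report batch means, optionally namespaced under a prefix.
        base = f"{prefix}/{self.reward_base_name}" if prefix else self.reward_base_name
        out = {f"{base}/{name}": t.mean().item() for name, t in self.rewards.items()}
        out[base] = self.total_reward.mean().item()
        out[f"{base}/successes"] = self.successes.mean().item()
        return out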

tests/torchtune/dev/rl/workers/test_postprocessing.py (+5 -8)

@@ -8,6 +8,9 @@
 import time
 
 import pytest
+import torch
+from omegaconf import OmegaConf
+from tests.test_utils import gen_log_file_name, gpu_test, rl_test, skip_if_lt_python_310
 
 _has_ray = importlib.util.find_spec("ray") is not None
 
@@ -25,10 +28,6 @@ def remote(*args, **kwargs):
         return lambda cls: cls
 
 
-import torch
-from omegaconf import OmegaConf
-from tests.test_utils import gen_log_file_name, gpu_test, rl_test, skip_if_lt_python_310
-
 grpo_samples = 4
 max_generated_tokens = 32
 
@@ -130,12 +129,10 @@ def test_run(self, cfg, log_file):
                 ).to(dtype=torch.bool),
                 seq_lens=torch.randint(0, 100, (grpo_samples,)),
                 answers=NonTensorData(["42"] * grpo_samples),
+                sequence_ids=None,
                 policy_version=None,
-                rewards=None,
                 advantages=None,
-                successes=None,
-                reward_metadata=None,
-                sequence_ids=None,
+                reward_outputs=None,
             )
         )
         replay_buffer = []

torchtune/dev/rl/datatypes/trajectory.py (+3 -4)

@@ -4,10 +4,11 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict, List
+from typing import List
 
 import torch
 from tensordict import TensorClass
+from torchtune.dev.rl.rewards import RewardOutput
 
 
 class Trajectory(TensorClass["nocast"]):
@@ -19,8 +20,6 @@ class Trajectory(TensorClass["nocast"]):
     seq_lens: torch.Tensor
     answers: torch.Tensor
     policy_version: int
-    rewards: torch.Tensor
     advantages: torch.Tensor
-    successes: torch.Tensor
-    reward_metadata: Dict[str, List[str]]
+    reward_outputs: List[RewardOutput]
     sequence_ids: List[str]