
Commit 17928ad

Merge pull request #6292 from hpcaitech/grpo-latest-dev-reward-update
[feat] Update reward verification
2 parents (5fd4bcb + d06042b), commit 17928ad

File tree

9 files changed: +307 −82 lines


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -165,3 +165,4 @@ applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
 applications/ColossalChat/model
+applications/ColossalChat/eval

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 27 additions & 1 deletion
@@ -36,6 +36,7 @@ def __init__(
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
+        eval_interval: int = -1,
     ):
         self.num_producers = num_producers
         self.num_episodes = num_episodes
@@ -51,6 +52,7 @@ def __init__(
         self.save_dir = save_dir
         assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
         self.num_microbatches = batch_size // minibatch_size
+        self.eval_interval = eval_interval

         self.model_config = model_config
         self.plugin_config = plugin_config
@@ -93,7 +95,6 @@ def setup(self) -> None:
         cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")

         self.buffer = []
-
         self.recv_cnt = 0

     def state_dict(self) -> Dict[str, torch.Tensor]:
@@ -110,6 +111,27 @@ def loop(self) -> None:
             with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
                 for step in pbar:
                     i = 0
+                    if self.eval_interval > 0 and step % self.eval_interval == 0:
+                        eval_statistics = None
+                        eval_global_step = None
+                        for r in range(self.num_producers):
+                            print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
+                            local_eval_result = ray_broadcast_tensor_dict(
+                                None, src=0, device=self.device, group_name=f"sync_data_{r}"
+                            )
+                            assert "consumer_global_step" in local_eval_result
+                            eval_global_step = local_eval_result.pop("consumer_global_step").item()
+                            if eval_statistics is None:
+                                eval_statistics = local_eval_result
+                            else:
+                                eval_statistics = {
+                                    k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
+                                }
+                        eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
+                        if dist.get_rank() == 0:
+                            if hasattr(self, "wandb_run"):
+                                self.wandb_run.log(eval_statistics, step=eval_global_step)
+                            print(f"Eval statistics: {eval_statistics}")
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -195,6 +217,7 @@ def __init__(
         minibatch_size=1,
         save_interval: int = 100,
         save_dir="./model",
+        eval_interval: int = -1,
     ):
         super().__init__(
             num_producers,
@@ -209,6 +232,9 @@ def __init__(
             model_config,
             plugin_config,
             minibatch_size,
+            save_interval,
+            save_dir,
+            eval_interval,
         )
         path = model_config.pop("path")
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
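
The new evaluation hook in `loop` reduces per-producer statistics in which each metric maps to a two-element tensor; the final `(v[0] / v[1])` division suggests these are (sum, count)-style pairs, so adding them element-wise across producers and then dividing yields a dataset-wide mean. A minimal standalone sketch of that reduction (metric names and values are made up):

```python
import torch

# Hypothetical per-producer eval results: each metric maps to a tensor
# holding [sum, count] over that producer's shard of the eval set.
producer_results = [
    {"reward": torch.tensor([12.0, 16.0]), "accuracy": torch.tensor([9.0, 16.0])},
    {"reward": torch.tensor([10.0, 16.0]), "accuracy": torch.tensor([11.0, 16.0])},
]

eval_statistics = None
for local_eval_result in producer_results:
    if eval_statistics is None:
        eval_statistics = local_eval_result
    else:
        # Element-wise add keeps the [sum, count] pairs aligned per metric.
        eval_statistics = {k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics}

# Global mean per metric = aggregated sum / aggregated count.
eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
print(eval_statistics)  # {'eval/reward': 0.6875, 'eval/accuracy': 0.625}
```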

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,7 @@ def __init__(
         project_name=None,
         save_interval: int = 100,
         save_dir="./model",
+        eval_interval: int = -1,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -72,6 +73,7 @@ def __init__(
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
+            eval_interval=eval_interval,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -528,4 +530,5 @@ def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
         state_dict = model.state_dict()
+        state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
         return state_dict
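
Appending `consumer_global_step` to the model `state_dict` piggybacks the trainer's step counter on the existing weight broadcast, which is how producers can later stamp their eval results with the step the weights correspond to (the consumer pops that key back out in `loop` above). A sketch of the receiving side under that assumption (the function and model names here are illustrative, not the repo's exact producer code):

```python
import torch

def load_broadcast_weights(state_dict: dict, inference_model: torch.nn.Module) -> int:
    # Pop the smuggled step counter before loading: it is not a real
    # model parameter, so load_state_dict would reject it otherwise.
    consumer_global_step = int(state_dict.pop("consumer_global_step").item())
    inference_model.load_state_dict(state_dict)
    return consumer_global_step
```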

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 6 additions & 4 deletions
@@ -205,7 +205,8 @@ def __init__(
         generate_config = generate_config.copy()
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         generate_config.update({"n": num_generations})
-        self.generate_config = SamplingParams(**generate_config)
+        self.generate_config = generate_config
+        self.sample_params = SamplingParams(**generate_config)
         self.model_config = model_config
         self.tokenizer = tokenizer
         self.num_generations = num_generations
@@ -219,8 +220,9 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs):
         micro_batch_input_ids_no_padding = [
             micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
         ]
+        sample_params = kwargs.get("sample_params", self.sample_params)
         outputs = self.llm.generate(
-            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
+            prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
         )
         out_tokens = []
         out_len = []
@@ -266,11 +268,11 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs):
             "response_idx": response_idx,
         }

-        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+        data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}

         if "gt_answer" in kwargs:
             # repeat gt_answer for each prompt.
-            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
+            data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
         data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data
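
Keeping the raw `generate_config` dict alongside the prebuilt `SamplingParams` lets callers override sampling per call via `kwargs["sample_params"]`, e.g. fewer samples per prompt during evaluation; the switch to `view(micro_batch_size, -1, ...)` and to repeating `gt_answer` by `data["input_ids"].size(1)` then adapts the output shapes to whatever `n` the override used instead of the fixed `num_generations`. A sketch assuming the vLLM backend (values are placeholders):

```python
from vllm import SamplingParams

# Training config as built in __init__: n samples per prompt.
generate_config = {"temperature": 1.0, "max_tokens": 512, "n": 8}
train_sample_params = SamplingParams(**generate_config)

# Eval-time override: a single greedy completion per prompt, derived from
# the stored dict so every other field stays consistent with training.
eval_sample_params = SamplingParams(**{**generate_config, "temperature": 0.0, "n": 1})

# Hypothetical call; generate() falls back to self.sample_params when no
# override is passed:
# rollouts = backend.generate(input_ids, attention_mask, sample_params=eval_sample_params)
```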

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 12 additions & 4 deletions
@@ -34,7 +34,7 @@ def launch_distributed(
     inference_microbatch_size: int,
     train_batch_size: int,
     train_minibatch_size: int,
-    dataset_config: Dict[str, Any],
+    train_dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
@@ -50,6 +50,9 @@ def launch_distributed(
     project_name: Optional[str] = None,
     save_interval: int = 100,
     save_dir: str = "./model",
+    eval_dataset_config: Optional[Dict[str, Any]] = None,
+    eval_interval: int = 100,
+    eval_save_dir: Optional[str] = None,
 ):

     if core_algo not in ALGO_MAP:
@@ -60,9 +63,9 @@ def launch_distributed(
     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

-    dataset_path = dataset_config["path"]
+    dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
-    global_inference_batch_size = inference_batch_size * num_producers
+    global_inference_batch_size = inference_batch_size * num_producers  # TODO: this doesn't support TP on producer
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size

@@ -74,7 +77,7 @@ def launch_distributed(
             num_consumer_procs=num_consumer_procs,
             num_episodes=num_episodes,
             batch_size=inference_batch_size,
-            dataset_config=dataset_config,
+            train_dataset_config=train_dataset_config,
             dataloaders_config=dataloaders_config,
             model_config=inference_model_config,
             generate_config=generate_config,
@@ -83,6 +86,10 @@ def launch_distributed(
             backend=inference_backend,
             num_generations=num_generations,
             consumer_plugin_config=plugin_config,
+            eval_dataset_config=eval_dataset_config,
+            eval_interval=eval_interval * num_recv_per_update,
+            evaluation_function_type=grpo_config["reward_fn_type"],
+            eval_save_dir=eval_save_dir,
         )
         procs.append(producer)
         generate_config_consumer = copy.deepcopy(generate_config)
@@ -111,6 +118,7 @@ def launch_distributed(
             project_name=project_name,
             save_interval=save_interval,
             save_dir=save_dir,
+            eval_interval=eval_interval,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
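
Producers are handed `eval_interval * num_recv_per_update` while the consumer keeps the raw `eval_interval`, presumably because producers count microbatch-level steps whereas the consumer counts whole update steps (each update receives `num_recv_per_update` microbatches per producer, as seen in `consumer.loop` above). A hypothetical slice of the new evaluation-related arguments, with placeholder values:

```python
# Placeholder values for the new evaluation arguments; every other argument
# to launch_distributed is unchanged by this PR.
eval_kwargs = dict(
    eval_dataset_config={"path": "data/eval.jsonl"},  # producers presumably skip eval when None (producer changes not shown here)
    eval_interval=10,  # the consumer evaluates when step % eval_interval == 0
    eval_save_dir="applications/ColossalChat/eval",  # matches the new .gitignore entry
)
# launch_distributed(..., **eval_kwargs)
```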
