[feat] Update reward verification #6292

Merged
merged 5 commits on May 3, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -165,3 +165,4 @@ applications/ColossalChat/logs
applications/ColossalChat/tests/logs
applications/ColossalChat/wandb
applications/ColossalChat/model
applications/ColossalChat/eval
28 changes: 27 additions & 1 deletion applications/ColossalChat/coati/distributed/consumer.py
@@ -36,6 +36,7 @@ def __init__(
minibatch_size: int = 1,
save_interval: int = 100,
save_dir: str = "./model",
eval_interval: int = -1,
):
self.num_producers = num_producers
self.num_episodes = num_episodes
@@ -51,6 +52,7 @@ def __init__(
self.save_dir = save_dir
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
self.num_microbatches = batch_size // minibatch_size
self.eval_interval = eval_interval

self.model_config = model_config
self.plugin_config = plugin_config
@@ -93,7 +95,6 @@ def setup(self) -> None:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")

self.buffer = []

self.recv_cnt = 0

def state_dict(self) -> Dict[str, torch.Tensor]:
@@ -110,6 +111,27 @@ def loop(self) -> None:
with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
for step in pbar:
i = 0
if self.eval_interval > 0 and step % self.eval_interval == 0:
eval_statistics = None
eval_global_step = None
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
local_eval_result = ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_data_{r}"
)
assert "consumer_global_step" in local_eval_result
eval_global_step = local_eval_result.pop("consumer_global_step").item()
if eval_statistics is None:
eval_statistics = local_eval_result
else:
eval_statistics = {
k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
}
eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
if dist.get_rank() == 0:
if hasattr(self, "wandb_run"):
self.wandb_run.log(eval_statistics, step=eval_global_step)
print(f"Eval statistics: {eval_statistics}")
for _ in range(self.num_recv_per_update):
# receive data from producers
for r in range(self.num_producers):
@@ -195,6 +217,7 @@ def __init__(
minibatch_size=1,
save_interval: int = 100,
save_dir="./model",
eval_interval: int = -1,
):
super().__init__(
num_producers,
@@ -209,6 +232,9 @@ def __init__(
model_config,
plugin_config,
minibatch_size,
save_interval,
save_dir,
eval_interval,
)
path = model_config.pop("path")
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
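For context on the new eval branch in `loop()` above: the consumer waits for every producer to broadcast its local evaluation results, sums them key-wise across producers, and only then reduces each entry to a scalar `eval/<metric>` value. Below is a minimal sketch of that aggregation, assuming each producer reports a metric as a 2-element tensor `[metric_sum, sample_count]`; the producer-side payload is not part of this diff, so the helper name and shapes are illustrative only.

```python
import torch

def aggregate_eval_results(per_producer_results):
    """Hypothetical helper mirroring the aggregation in the eval branch of loop()."""
    merged = None
    for result in per_producer_results:
        result = dict(result)
        # scalar used only as the wandb logging step, not a metric
        result.pop("consumer_global_step", None)
        if merged is None:
            merged = result
        else:
            merged = {k: merged[k] + result[k] for k in merged}
    # reduce [sum, count] -> mean, matching (v[0] / v[1]).item() in consumer.loop()
    return {"eval/" + k: (v[0] / v[1]).item() for k, v in merged.items()}

# Example: two producers each report accuracy as [num_correct, num_prompts]
stats = aggregate_eval_results(
    [
        {"accuracy": torch.tensor([3.0, 8.0]), "consumer_global_step": torch.tensor([10])},
        {"accuracy": torch.tensor([5.0, 8.0]), "consumer_global_step": torch.tensor([10])},
    ]
)
print(stats)  # {'eval/accuracy': 0.5}
```

Summing numerators and denominators before dividing yields an exact micro-average over all producers' samples, rather than an average of per-producer means.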
@@ -40,6 +40,7 @@ def __init__(
project_name=None,
save_interval: int = 100,
save_dir="./model",
eval_interval: int = -1,
):
print(f"Using GRPO config: {grpo_config}")
if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -72,6 +73,7 @@ def __init__(
minibatch_size,
save_interval=save_interval,
save_dir=save_dir,
eval_interval=eval_interval,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -528,4 +530,5 @@ def state_dict(self):
self.policy_model._force_wait_all_gather()
model = self.policy_model.unwrap()
state_dict = model.state_dict()
state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
return state_dict
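The `state_dict()` change above piggybacks the trainer's `global_step` on the tensors broadcast to producers during weight sync, so the receiving side has to strip that entry before loading weights. A sketch of that handling follows; the producer-side code is not included in this excerpt, so the helper and the commented flow are assumptions about the counterpart, not the actual implementation.

```python
import torch

def split_synced_state_dict(state_dict):
    """Separate the bookkeeping counter added in state_dict() above from the weights."""
    global_step = int(state_dict.pop("consumer_global_step").item())
    return global_step, state_dict

# Tiny demo with a fake weight dict:
synced = {"lm_head.weight": torch.zeros(2, 2), "consumer_global_step": torch.tensor([42])}
step, weights = split_synced_state_dict(synced)
assert step == 42 and "consumer_global_step" not in weights

# Assumed producer-side flow (not shown in this diff):
#   step, weights = split_synced_state_dict(received_state_dict)
#   self.model.load_state_dict(weights)
#   # `step` is what producers later return as "consumer_global_step" with eval results
```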
@@ -205,7 +205,8 @@ def __init__(
generate_config = generate_config.copy()
generate_config.update(self.FORCE_GENERATE_CONFIG)
generate_config.update({"n": num_generations})
self.generate_config = SamplingParams(**generate_config)
self.generate_config = generate_config
self.sample_params = SamplingParams(**generate_config)
self.model_config = model_config
self.tokenizer = tokenizer
self.num_generations = num_generations
@@ -219,8 +220,9 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
micro_batch_input_ids_no_padding = [
micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
]
sample_params = kwargs.get("sample_params", self.sample_params)
outputs = self.llm.generate(
prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
)
out_tokens = []
out_len = []
@@ -266,11 +268,11 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
"response_idx": response_idx,
}

data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}

if "gt_answer" in kwargs:
# repeat gt_answer for each prompt.
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
data = {k: v.to(get_current_device()) for k, v in data.items()}
return data

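The changes above keep the raw `generate_config` dict around and let callers override the vLLM `SamplingParams` per call via `kwargs["sample_params"]`, so an evaluation pass can sample differently from training rollouts; the `-1` in the reshape and the matching `repeat_interleave` on `gt_answer` make the output layout work for any `n`. Below is a sketch of how such an override could be built, assuming single-sample greedy decoding for eval, which this diff does not prescribe.

```python
from vllm import SamplingParams

def build_eval_sample_params(generate_config: dict) -> SamplingParams:
    """Derive eval-time sampling from the stored training config.

    Hypothetical helper: n=1 greedy decoding is an assumed choice for the eval
    pass; the diff only shows that generate() accepts a 'sample_params' override,
    not which settings the producer actually uses.
    """
    eval_config = generate_config.copy()
    eval_config.update({"n": 1, "temperature": 0.0})
    return SamplingParams(**eval_config)

# Usage sketch (backend, input_ids, attention_mask, gt_answer assumed to exist):
#   eval_out = backend.generate(
#       input_ids, attention_mask, gt_answer=gt_answer,
#       sample_params=build_eval_sample_params(backend.generate_config),
#   )
# Training calls omit sample_params and fall back to self.sample_params (n = num_generations).
```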
16 changes: 12 additions & 4 deletions applications/ColossalChat/coati/distributed/launch.py
@@ -34,7 +34,7 @@ def launch_distributed(
inference_microbatch_size: int,
train_batch_size: int,
train_minibatch_size: int,
dataset_config: Dict[str, Any],
train_dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
inference_model_config: Dict[str, Any],
generate_config: Dict[str, Any],
Expand All @@ -50,6 +50,9 @@ def launch_distributed(
project_name: Optional[str] = None,
save_interval: int = 100,
save_dir: str = "./model",
eval_dataset_config: Optional[Dict[str, Any]] = None,
eval_interval: int = 100,
eval_save_dir: Optional[str] = None,
):

if core_algo not in ALGO_MAP:
@@ -60,9 +63,9 @@
train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

dataset_path = dataset_config["path"]
dataset_path = train_dataset_config["path"]
num_samples = get_jsonl_size_fast(dataset_path)
global_inference_batch_size = inference_batch_size * num_producers
global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer
num_update_per_episode = num_samples // global_inference_batch_size
num_recv_per_update = inference_batch_size // inference_microbatch_size

@@ -74,7 +77,7 @@
num_consumer_procs=num_consumer_procs,
num_episodes=num_episodes,
batch_size=inference_batch_size,
dataset_config=dataset_config,
train_dataset_config=train_dataset_config,
dataloaders_config=dataloaders_config,
model_config=inference_model_config,
generate_config=generate_config,
@@ -83,6 +86,10 @@
backend=inference_backend,
num_generations=num_generations,
consumer_plugin_config=plugin_config,
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval * num_recv_per_update,
evaluation_function_type=grpo_config["reward_fn_type"],
eval_save_dir=eval_save_dir,
)
procs.append(producer)
generate_config_consumer = copy.deepcopy(generate_config)
@@ -111,6 +118,7 @@
project_name=project_name,
save_interval=save_interval,
save_dir=save_dir,
eval_interval=eval_interval,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])
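One subtlety in the wiring above: the consumer receives `eval_interval` directly (measured in update steps), while the producer receives `eval_interval * num_recv_per_update`, presumably because the producer loop ticks once per microbatch it sends rather than once per consumer update. The sketch below reproduces the scheduling arithmetic from the diff with illustrative numbers; the values are examples, not repo defaults.

```python
# Illustrative numbers only.
num_producers = 4
inference_batch_size = 8
inference_microbatch_size = 2
num_samples = 1024            # get_jsonl_size_fast(train_dataset_config["path"])
eval_interval = 10            # in consumer update steps, as passed to the consumer

global_inference_batch_size = inference_batch_size * num_producers       # 32
num_update_per_episode = num_samples // global_inference_batch_size      # 32
num_recv_per_update = inference_batch_size // inference_microbatch_size  # 4

# Scaled so both sides agree on "evaluate every eval_interval consumer steps".
producer_eval_interval = eval_interval * num_recv_per_update             # 40
print(num_update_per_episode, num_recv_per_update, producer_eval_interval)
```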