
Commit 279e9ba

Integrate verl GRPO trainer into train script (#1652)

1 parent b6d6a68 commit 279e9ba
File tree

5 files changed: +186 -82 lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion

@@ -117,7 +117,6 @@ gpu = [
     "nvidia-ml-py>=12.560.30,<12.561",
     "bitsandbytes>=0.45.0,<0.46", # Used for QLora, and PagedAdam implementation
     "verl>=0.3.0,<0.4", # Used for the VERL_GRPO trainer.
-    "ray[default]", # Used for the VERL_GRPO trainer.
     "vllm>=0.7.3,<0.8.0", # For VLLMInferenceEngine
 ]

src/oumi/builders/training.py

Lines changed: 9 additions & 1 deletion

@@ -22,7 +22,7 @@
 from oumi.core.configs import TrainerType, TrainingParams
 from oumi.core.distributed import is_world_process_zero
 from oumi.core.processors.base_processor import BaseProcessor
-from oumi.core.trainers import BaseTrainer, HuggingFaceTrainer
+from oumi.core.trainers import BaseTrainer, HuggingFaceTrainer, VerlGrpoTrainer
 from oumi.core.trainers import Trainer as OumiTrainer
 from oumi.utils.logging import logger
 
@@ -94,6 +94,12 @@ def _init_oumi_trainer(*args, **kwargs) -> BaseTrainer:
 
         return _init_oumi_trainer
 
+    def _create_verl_grpo_builder_fn() -> Callable[..., BaseTrainer]:
+        def _init_verl_grpo_trainer(*args, **kwargs) -> BaseTrainer:
+            return VerlGrpoTrainer(*args, **kwargs)
+
+        return _init_verl_grpo_trainer
+
     if trainer_type == TrainerType.TRL_SFT:
         return _create_hf_builder_fn(trl.SFTTrainer)
     elif trainer_type == TrainerType.TRL_DPO:
@@ -108,5 +114,7 @@ def _init_oumi_trainer(*args, **kwargs) -> BaseTrainer:
             "Prefer to use HF trainer when possible."
         )
         return _create_oumi_builder_fn()
+    elif trainer_type == TrainerType.VERL_GRPO:
+        return _create_verl_grpo_builder_fn()
 
     raise NotImplementedError(f"Trainer type {trainer_type} not supported.")
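The VERL_GRPO branch follows the same builder-function pattern as the existing trainers: the dispatcher returns a callable, and that callable simply forwards its arguments to the trainer class. Below is a minimal, self-contained sketch of the pattern; the names FakeTrainer and create_builder_fn are illustrative stand-ins, not part of the oumi API.

from typing import Callable


class FakeTrainer:
    """Stand-in for a BaseTrainer subclass such as VerlGrpoTrainer."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def train(self) -> None:
        print(f"Training with {self.kwargs}")


def create_builder_fn(trainer_cls) -> Callable[..., FakeTrainer]:
    # Mirrors _create_verl_grpo_builder_fn: capture the trainer class now,
    # defer construction until the caller supplies the constructor arguments.
    def _init_trainer(*args, **kwargs) -> FakeTrainer:
        return trainer_cls(*args, **kwargs)

    return _init_trainer


builder = create_builder_fn(FakeTrainer)
trainer = builder(learning_rate=1e-6)
trainer.train()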

src/oumi/core/configs/params/training_params.py

Lines changed: 28 additions & 4 deletions

@@ -68,6 +68,15 @@ class TrainerType(Enum):
     designed to provide additional flexibility and features.
     """
 
+    VERL_GRPO = "verl_grpo"
+    """Group Relative Policy Optimization trainer from `verl` library.
+
+    This trainer implements the Group Relative Policy Optimization algorithm
+    introduced in the paper https://arxiv.org/pdf/2402.03300
+    for fine-tuning language models.
+    Optionally, supports user-defined reward functions.
+    """
+
 
 class SchedulerType(str, Enum):
     """Enum representing the supported learning rate schedulers.
@@ -153,7 +162,9 @@ class TrainingParams(BaseParams):
     - HF: HuggingFace's Trainer
     - TRL_SFT: TRL's SFT Trainer
     - TRL_DPO: TRL's DPO Trainer
+    - TRL_GRPO: TRL's GRPO Trainer
     - OUMI: Custom generic trainer implementation
+    - VERL_GRPO: verl's GRPO Trainer
     """
 
     enable_gradient_checkpointing: bool = False
@@ -312,8 +323,14 @@ class TrainingParams(BaseParams):
     """The names of the reward function in the Oumi registry to use for reinforcement
     learning.
 
-    Only supported with the TRL_GRPO trainer currently. Refer to
-    https://huggingface.co/docs/trl/main/en/grpo_trainer
+    Only supported with the TRL_GRPO and VERL_GRPO trainers. Currently,
+    VERL_GRPO only supports specifying a single reward function.
+
+    For TRL_GRPO, refer to https://huggingface.co/docs/trl/main/en/grpo_trainer
+    for documentation about the function signature.
+
+    For VERL_GRPO, refer to
+    https://verl.readthedocs.io/en/latest/preparation/reward_function.html
     for documentation about the function signature.
     """
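For illustration, a VERL_GRPO-style reward function might look like the sketch below. The (data_source, solution_str, ground_truth, extra_info) signature is taken from the verl reward-function documentation linked above, not from this diff, so treat it as an assumption; the function name and scoring rule are hypothetical. Its name would then be listed in reward_functions (a single entry for VERL_GRPO).

def compute_score(data_source, solution_str, ground_truth, extra_info=None) -> float:
    # Hypothetical reward: 1.0 for an exact match against the ground truth,
    # 0.0 otherwise. A real reward would parse and verify the model's output.
    return 1.0 if solution_str.strip() == str(ground_truth).strip() else 0.0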

@@ -798,14 +815,21 @@ def __post_init__(self):
 
         if (
             self.trainer_type != TrainerType.TRL_GRPO
+            and self.trainer_type != TrainerType.VERL_GRPO
             and self.reward_functions is not None
         ):
             function_names = [name for name in self.reward_functions if name]
             if len(function_names) > 0:
                 raise ValueError(
-                    "reward_functions may only be defined for the TRL_GRPO trainer. "
-                    f"Actual: {self.trainer_type}"
+                    "reward_functions may only be defined for the TRL_GRPO or VERL_GRPO"
+                    f"trainers. Actual: {self.trainer_type}"
                 )
+        if self.trainer_type == TrainerType.VERL_GRPO:
+            if len(function_names) > 1:
+                raise ValueError(
+                    "VERL_GRPO only supports a single reward function. "
+                    f"Actual: {function_names}"
+                )
 
         # TODO: #1540 - Remove when TRL bug is fixed.
         if (
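To make the new __post_init__ checks concrete, here is a hypothetical usage sketch (not part of the commit). It assumes the remaining TrainingParams fields have usable defaults and that construction alone triggers the validation.

import pytest

from oumi.core.configs import TrainerType, TrainingParams

# A single reward function name is accepted for the VERL_GRPO trainer.
params = TrainingParams(
    trainer_type=TrainerType.VERL_GRPO,
    reward_functions=["my_reward"],
)

# More than one reward function should be rejected by the new check.
with pytest.raises(ValueError, match="single reward function"):
    TrainingParams(
        trainer_type=TrainerType.VERL_GRPO,
        reward_functions=["reward_a", "reward_b"],
    )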

src/oumi/core/configs/training_config.py

Lines changed: 9 additions & 0 deletions

@@ -203,3 +203,12 @@ def __post_init__(self):
                 dataset_params.dataset_kwargs["processor_kwargs"] = {
                     **self.model.processor_kwargs
                 }
+
+        # Verl will error without a validation dataset.
+        if (
+            self.training.trainer_type == TrainerType.VERL_GRPO
+            and not self.data.validation.datasets
+        ):
+            raise ValueError(
+                "At least one validation dataset is required for VERL_GRPO training."
+            )
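For reference, a Python sketch of a config that satisfies this check might look as follows. Only data.validation.datasets and training.trainer_type appear in this diff; the DataParams, DatasetSplitParams, and DatasetParams names and fields are assumptions about the oumi config classes, and other required settings (e.g., the model section) are omitted, so treat this as illustrative only.

from oumi.core.configs import (
    DataParams,          # assumed class name
    DatasetParams,       # assumed class name
    DatasetSplitParams,  # assumed class name
    TrainerType,
    TrainingConfig,
    TrainingParams,
)

# Hypothetical VERL_GRPO config with the required validation split populated.
config = TrainingConfig(
    data=DataParams(
        train=DatasetSplitParams(datasets=[DatasetParams(dataset_name="my_train_set")]),
        validation=DatasetSplitParams(datasets=[DatasetParams(dataset_name="my_eval_set")]),
    ),
    training=TrainingParams(trainer_type=TrainerType.VERL_GRPO),
)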
