@@ -68,6 +68,15 @@ class TrainerType(Enum):
6868 designed to provide additional flexibility and features.
6969 """
7070
71+ VERL_GRPO = "verl_grpo"
72+ """Group Relative Policy Optimization trainer from `verl` library.
73+
74+ This trainer implements the Group Relative Policy Optimization algorithm
75+ introduced in the paper https://arxiv.org/pdf/2402.03300
76+ for fine-tuning language models.
77+ Optionally, supports user-defined reward functions.
78+ """
79+
7180
7281class SchedulerType (str , Enum ):
7382 """Enum representing the supported learning rate schedulers.
@@ -153,7 +162,9 @@ class TrainingParams(BaseParams):
153162 - HF: HuggingFace's Trainer
154163 - TRL_SFT: TRL's SFT Trainer
155164 - TRL_DPO: TRL's DPO Trainer
165+ - TRL_GRPO: TRL's GRPO Trainer
156166 - OUMI: Custom generic trainer implementation
167+ - VERL_GRPO: verl's GRPO Trainer
157168 """
158169
159170 enable_gradient_checkpointing : bool = False
@@ -312,8 +323,14 @@ class TrainingParams(BaseParams):
312323 """The names of the reward function in the Oumi registry to use for reinforcement
313324 learning.
314325
315- Only supported with the TRL_GRPO trainer currently. Refer to
316- https://huggingface.co/docs/trl/main/en/grpo_trainer
326+ Only supported with the TRL_GRPO and VERL_GRPO trainers. Currently,
327+ VERL_GRPO only supports specifying a single reward function.
328+
329+ For TRL_GRPO, refer to https://huggingface.co/docs/trl/main/en/grpo_trainer
330+ for documentation about the function signature.
331+
332+ For VERL_GRPO, refer to
333+ https://verl.readthedocs.io/en/latest/preparation/reward_function.html
317334 for documentation about the function signature.
318335 """
319336
@@ -798,14 +815,21 @@ def __post_init__(self):
798815
799816 if (
800817 self .trainer_type != TrainerType .TRL_GRPO
818+ and self .trainer_type != TrainerType .VERL_GRPO
801819 and self .reward_functions is not None
802820 ):
803821 function_names = [name for name in self .reward_functions if name ]
804822 if len (function_names ) > 0 :
805823 raise ValueError (
806- "reward_functions may only be defined for the TRL_GRPO trainer. "
807- f"Actual: { self .trainer_type } "
824+ "reward_functions may only be defined for the TRL_GRPO or VERL_GRPO "
825+ f"trainers. Actual: { self .trainer_type } "
808826 )
827+ if self .trainer_type == TrainerType .VERL_GRPO :
828+ if len (function_names ) > 1 :
829+ raise ValueError (
830+ "VERL_GRPO only supports a single reward function. "
831+ f"Actual: { function_names } "
832+ )
809833
810834 # TODO: #1540 - Remove when TRL bug is fixed.
811835 if (
0 commit comments