Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions skyrl/backends/skyrl_train/utils/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ class PolicyLossType(StrEnum):
SAPO = "sapo"
CROSS_ENTROPY = "cross_entropy"
IMPORTANCE_SAMPLING = "importance_sampling"
DRO = "dro"


class PolicyLossRegistry(BaseFunctionRegistry):
Expand Down Expand Up @@ -500,6 +501,7 @@ def repopulate_registry(cls):
"sapo": [PolicyLossType.SAPO, sapo_policy_loss],
"cross_entropy": [PolicyLossType.CROSS_ENTROPY, cross_entropy_loss],
"importance_sampling": [PolicyLossType.IMPORTANCE_SAMPLING, importance_sampling_loss],
"dro": [PolicyLossType.DRO, dro_policy_loss],
}

for pl_name, (pl_type, pl_func) in pl_types.items():
Expand Down Expand Up @@ -560,7 +562,8 @@ def ppo_policy_loss(
"token_mean",
"sequence_mean",
"seq_mean_token_sum_norm",
], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'"
"sum",
], "loss_reduction must be one of 'token_mean', 'sequence_mean', 'seq_mean_token_sum_norm', or 'sum'"

ratio = safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype)
surr1 = ratio * advantages
Expand Down Expand Up @@ -980,10 +983,71 @@ def importance_sampling_loss(
return loss, {"importance_ratio": mean_ratio.item()}



@register_policy_loss(PolicyLossType.DRO)
def dro_policy_loss(
    log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    config: AlgorithmConfig,
    loss_mask: Optional[torch.Tensor] = None,
    rollout_logprobs: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict[str, float]]:
    """Distributionally Robust Optimization (DRO) policy loss.

    Applies an exponential tilt to the per-token PPO surrogate loss so that
    the optimisation focuses on the worst-case region of the advantage
    distribution. The ``beta`` parameter in ``config.dro`` controls the
    degree of robustness: larger values concentrate more weight on the
    highest-loss tokens/sequences. ``beta`` is assumed to be > 0.

    The loss is:

        L_dro = (1/beta) * log( E[ exp(beta * L_ppo) ] )

    where L_ppo is the standard clipped surrogate loss (per token) and the
    expectation is taken over valid (unmasked) tokens.

    Note: ``config.loss_reduction`` is ignored — DRO uses its own log-mean-exp reduction.

    Args:
        log_probs: Current-policy log-probabilities per token.
        old_log_probs: Behaviour-policy log-probabilities per token.
        advantages: Per-token advantage estimates.
        config: Algorithm config; reads ``dro.beta``, ``eps_clip_low``,
            ``eps_clip_high`` and ``off_policy_correction``.
        loss_mask: Optional 0/1 mask of valid tokens.
        rollout_logprobs: Optional rollout log-probs for off-policy correction.

    Returns:
        Tuple of (scalar loss, metrics dict with ``clip_ratio`` and any
        off-policy-correction metrics).
    """
    beta = config.dro.beta

    ratio = safe_exp_delta(log_probs - old_log_probs, clip=20.0, out_dtype=log_probs.dtype)
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - config.eps_clip_low, 1 + config.eps_clip_high) * advantages
    elementwise_loss = -torch.min(surr1, surr2)

    # fraction of tokens where the clipped surrogate is the active branch
    clip_ratio = masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item()

    loss_metrics = {"clip_ratio": clip_ratio}

    # apply off policy correction (may rewrite the elementwise loss and mask)
    elementwise_loss, loss_mask, off_policy_metrics = apply_off_policy_correction(
        elementwise_loss, old_log_probs, rollout_logprobs, loss_mask, config.off_policy_correction
    )
    loss_metrics.update(off_policy_metrics)

    # DRO exponential tilt, computed with the max-shift trick for numerical
    # stability. The correct identity is:
    #   (1/beta) * log E[exp(beta * L)] = m + (1/beta) * log E[exp(beta * (L - m))]
    # with m = max(L). The shift m is added back OUTSIDE the 1/beta factor
    # (the previous version divided m by beta as well, which is wrong for beta != 1).
    if loss_mask is not None:
        valid = loss_mask.bool()
        # Max over valid tokens only; masked positions must not influence the shift.
        max_val = elementwise_loss.masked_fill(~valid, float("-inf")).max().detach()
        # Fill masked positions with -inf BEFORE exp so they contribute exactly 0
        # (avoids inf * 0 = nan when a masked loss exceeds the shift).
        shifted = (elementwise_loss - max_val).masked_fill(~valid, float("-inf"))
        exp_term = torch.exp(beta * shifted)
        # NOTE(review): a fully-masked batch yields a degenerate (-inf) loss;
        # callers are assumed never to pass one — confirm upstream filtering.
        log_mean_exp = max_val + torch.log(exp_term.sum() / loss_mask.sum().clamp(min=1)) / beta
    else:
        max_val = elementwise_loss.max().detach()
        exp_term = torch.exp(beta * (elementwise_loss - max_val))
        log_mean_exp = max_val + torch.log(exp_term.mean()) / beta

    loss = log_mean_exp

    return loss, loss_metrics


def reduce_loss(
loss: torch.Tensor,
loss_mask: Optional[torch.Tensor],
loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm"],
loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm", "sum"],
max_seq_len: Optional[int] = None,
) -> torch.Tensor:
if loss_reduction == "token_mean":
Expand All @@ -1004,6 +1068,12 @@ def reduce_loss(
# If no mask, assume all tokens are valid
seq_losses = torch.sum(loss, dim=-1) / max_seq_len
loss = torch.mean(seq_losses)
elif loss_reduction == "sum":
# simple sum of all valid (masked) token losses
if loss_mask is not None:
loss = (loss * loss_mask).sum()
else:
loss = loss.sum()
else:
raise ValueError(f"Invalid loss reduction type: {loss_reduction}")
return loss
Expand Down
2 changes: 2 additions & 0 deletions skyrl/train/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
SkyRLGymConfig,
SAPOConfig,
CISPOConfig,
DROConfig,
ClipCovConfig,
KLCovConfig,
KLCtrlConfig,
Expand Down Expand Up @@ -61,6 +62,7 @@
"SkyRLGymConfig",
"SAPOConfig",
"CISPOConfig",
"DROConfig",
"ClipCovConfig",
"KLCovConfig",
"KLCtrlConfig",
Expand Down
11 changes: 10 additions & 1 deletion skyrl/train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,13 @@ class CISPOConfig(BaseConfig):
"""Offset for upper bound of importance sampling ratio clipping (as opposed to PPO token update clipping)."""


@dataclass
class DROConfig(BaseConfig):
    """Hyperparameters for the DRO (Distributionally Robust Optimization) policy loss.

    Only read when ``policy_loss_type`` is ``"dro"`` (see ``AlgorithmConfig.dro``).
    """

    beta: float = 0.1
    """Controls the degree of robustness in the exponential tilt. Larger values focus more on worst-case tokens."""


# see https://docs.skyrl.ai/docs/algorithms/off_policy_correction for more details
@dataclass
class OffPolicyCorrectionConfig(BaseConfig):
Expand Down Expand Up @@ -341,7 +348,7 @@ class AlgorithmConfig(BaseConfig):
policy_loss_type: str = "regular"
"""``"regular"``, ``"dual_clip"``, ``"gspo"``, ``"clip_cov"``, ``"kl_cov"``, or custom via ``PolicyLossRegistry``."""
loss_reduction: str = "token_mean"
"""``"token_mean"``, ``"sequence_mean"``, or ``"seq_mean_token_sum_norm"``."""
"""``"token_mean"``, ``"sequence_mean"``, ``"seq_mean_token_sum_norm"``, or ``"sum"``."""
grpo_norm_by_std: bool = True
zero_variance_filter: bool = False
"""Loss-mask prompts with zero-variance rewards. Only applicable when rewards are response-level."""
Expand All @@ -365,6 +372,8 @@ class AlgorithmConfig(BaseConfig):
"""Only used when ``policy_loss_type="kl_cov"``."""
cispo: CISPOConfig = field(default_factory=CISPOConfig)
"""Only used when ``policy_loss_type="cispo"``."""
dro: DROConfig = field(default_factory=DROConfig)
"""Only used when ``policy_loss_type="dro"``."""
max_seq_len: Optional[int] = None
"""Used for ``seq_mean_token_sum_norm`` loss reduction; set explicitly for multi-turn.
If ``None``, calculated as ``generator.max_input_length + generator.sampling_params.max_generate_length``."""
Expand Down
6 changes: 5 additions & 1 deletion skyrl/train/config/ppo_base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ trainer:
# this adds training batch level normalization to advantages
advantage_batch_normalize: false
value_head_prefix: "value_head"
policy_loss_type: "regular" # "regular", "dual_clip", "gspo", "clip_cov", "kl_cov", or customizable with PolicyLossRegistry
policy_loss_type: "regular" # "regular", "dual_clip", "gspo", "clip_cov", "kl_cov", "dro", or customizable with PolicyLossRegistry
loss_reduction: "token_mean" # "token_mean", "sequence_mean", "seq_mean_token_sum_norm", "sum"
grpo_norm_by_std: true # set to false to disable normalization by std in GRPO
zero_variance_filter: false # set to true to loss mask out prompts with zero variance rewards. only applicable when rewards are response-level.
Expand Down Expand Up @@ -194,6 +194,10 @@ trainer:
cispo:
cispo_eps_clip_low: 0 # offset for lower bound of importance sampling ratio clipping (as opposed to PPO token update clipping)
cispo_eps_clip_high: 5 # offset for upper bound of importance sampling ratio clipping (as opposed to PPO token update clipping)

# dro parameters (only used when policy_loss_type: "dro")
dro:
beta: 0.1

# Fully async specific knobs. For more see https://docs.skyrl.ai/docs/tutorials/fully_async#step-2-config-knobs-to-tune-for-fully-async-training
fully_async:
Expand Down
3 changes: 2 additions & 1 deletion skyrl/train/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,10 @@ def validate_cfg(cfg: SkyRLTrainConfig):
"token_mean",
"sequence_mean",
"seq_mean_token_sum_norm",
"sum",
), (
f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. "
f"Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`"
f"Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm', 'sum']`"
)

# TODO (erictang000): remove this after deprecation period
Expand Down
Loading
Loading