diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py
index 256338fca..a04fc7cbe 100644
--- a/slime/backends/megatron_utils/loss.py
+++ b/slime/backends/megatron_utils/loss.py
@@ -381,6 +381,58 @@ def _extract_per_sample(
     return log_probs_list, entropy_list
 
 
+def _build_full_loss_mask(
+    T: int,
+    device: torch.device,
+    loss_masks: list[torch.Tensor],
+    total_lengths: list[int],
+    response_lengths: list[int],
+    qkv_format: str,
+    max_seq_lens: list[int] | None,
+) -> torch.Tensor:
+    """Build a boolean mask over all T logit positions marking response tokens to keep.
+
+    For each sample, the loss_mask covers the response tokens. This function maps
+    those per-sample masks onto the full packed/padded logit tensor layout (matching
+    the layout used by _build_shifted_tokens and _extract_per_sample for cp_size==1).
+
+    Positions corresponding to prompt tokens or padding are marked True (kept) so that
+    the shifted-token alignment is preserved. Only response positions where
+    loss_mask == 0 are marked False (filtered out).
+
+    Returns:
+        Boolean tensor of shape [T] where True means "compute this position".
+    """
+    # Start with all True — prompts and padding are kept by default
+    full_mask = torch.ones(T, dtype=torch.bool, device=device)
+
+    # Per-sample start offsets into the packed (thd) / padded (bshd) logit layout.
+    if qkv_format == "thd":
+        # Packed: samples are concatenated back to back.
+        offsets = []
+        offset = 0
+        for total_length in total_lengths:
+            offsets.append(offset)
+            offset += total_length
+    else:  # bshd
+        # NOTE(review): uses max_seq_lens[i] * i as the row offset, which is only
+        # correct when max_seq_lens is uniform — confirm against _build_shifted_tokens.
+        offsets = [max_seq_lens[i] * i for i in range(len(total_lengths))]
+
+    for i, (sample_offset, total_length, response_length) in enumerate(
+        zip(offsets, total_lengths, response_lengths, strict=False)
+    ):
+        end = sample_offset + total_length
+        start = end - response_length
+        # The logit predicting token t sits at position t - 1, so the response
+        # tokens' logits occupy [start - 1, end - 1). Guard start == 0 explicitly:
+        # a negative slice index would silently wrap around instead of raising.
+        assert start >= 1, f"sample {i} has no prompt token before the response"
+        full_mask[start - 1 : end - 1] = loss_masks[i].bool()[:response_length]
+
+    return full_mask
+
+
 def get_log_probs_and_entropy(
     logits: torch.Tensor,
     *,
@@ -391,6 +443,7 @@ def get_log_probs_and_entropy(
     with_entropy: bool = False,
     non_loss_data: bool = True,
     max_seq_lens: list[int] | None = None,
+    loss_masks: list[torch.Tensor] | None = None,
 ) -> dict[str, list[torch.Tensor]]:
     """Compute per-token log-probabilities (and optionally entropy) on responses.
 
@@ -398,11 +451,21 @@ def get_log_probs_and_entropy(
     per-sample slicing) so backward traverses ``[T, V]`` only once, then
     extracts per-sample response portions.
 
+    When ``loss_masks`` is provided and context parallelism is disabled
+    (cp_size == 1), positions where the mask is zero are filtered out before
+    the expensive vocab-parallel softmax, and the output is padded back to the
+    original response length with zeros. This avoids computing log-probs and
+    entropy for tokens that will be masked out downstream (e.g., tool-result
+    tokens in multi-turn agent rollouts), reducing memory and compute by the
+    fraction of masked tokens.
+
     When ``entropy_coef == 0``, entropy is computed under ``torch.no_grad()``
     to avoid retaining the computation graph and to skip cloning.
     """
     assert non_loss_data
     qkv_format = args.qkv_format
+    cp_size = mpu.get_context_parallel_world_size()
+    filter_by_mask = loss_masks is not None and cp_size == 1
 
     assert logits.dtype == torch.float32, f"{logits.dtype}"
     assert len(logits.shape) == 3, f"{logits.shape}"
@@ -429,6 +492,20 @@ def get_log_probs_and_entropy(
         T, device, unconcat_tokens, total_lengths, response_lengths, qkv_format, max_seq_lens, args.allgather_cp
     )
 
+    # --- filter by loss_masks: remove masked positions before expensive softmax ---
+    keep_mask = None
+    if filter_by_mask:
+        # Build a full-sequence boolean mask aligned to logits positions
+        keep_mask = _build_full_loss_mask(
+            T, device, loss_masks, total_lengths, response_lengths, qkv_format, max_seq_lens
+        )
+        num_kept = keep_mask.sum().item()
+        if num_kept < T:
+            logits = logits[keep_mask]
+            full_tokens = full_tokens[keep_mask]
+        else:
+            keep_mask = None  # all positions kept, no filtering needed
+
     # --- compute on full [T,V] logits at once via calculate_log_probs_and_entropy ---
     log_prob_full, entropy_full = calculate_log_probs_and_entropy(
         logits,
@@ -439,6 +516,16 @@ def get_log_probs_and_entropy(
     )
     log_prob_full = log_prob_full.squeeze(-1)  # [T, 1] -> [T]
 
+    # --- scatter back to original length if filtering was applied ---
+    if keep_mask is not None:
+        full_lp = torch.zeros(T, device=device, dtype=log_prob_full.dtype)
+        full_lp[keep_mask] = log_prob_full
+        log_prob_full = full_lp
+        if entropy_full is not None:
+            full_ent = torch.zeros(T, device=device, dtype=entropy_full.dtype)
+            full_ent[keep_mask] = entropy_full
+            entropy_full = full_ent
+
     # --- extract per-sample response portions ---
     log_probs_list, entropy_list = _extract_per_sample(
         log_prob_full,
@@ -478,6 +565,7 @@ def get_values(
     with_entropy: bool = False,
     non_loss_data: bool = True,
     max_seq_lens: list[int] | None = None,
+    loss_masks: list[torch.Tensor] | None = None,
 ) -> dict[str, list[torch.Tensor]]:
     """Extract per-token value predictions over response tokens.
 
@@ -839,6 +927,7 @@ def policy_loss_function(
         response_lengths=response_lengths,
         with_entropy=True,
         max_seq_lens=max_seq_lens,
+        loss_masks=batch.get("loss_masks"),
     )
 
     log_probs = log_probs_and_entropy["log_probs"]
@@ -1111,6 +1200,7 @@ def sft_loss_function(
         response_lengths=response_lengths,
         with_entropy=False,
         max_seq_lens=batch.get("max_seq_lens", None),
+        loss_masks=batch.get("loss_masks"),
     )
 
     log_probs = log_probs_and_entropy["log_probs"]
diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py
index f326b1d0d..78331a769 100644
--- a/slime/backends/megatron_utils/model.py
+++ b/slime/backends/megatron_utils/model.py
@@ -339,6 +339,7 @@ def forward_step(
         response_lengths=response_lengths,
         with_entropy=args.use_rollout_entropy,
         max_seq_lens=batch.get("max_seq_lens", None),
+        loss_masks=batch.get("loss_masks"),
     )
 
     # Turn on evaluation mode which disables dropout.
diff --git a/train_async.py b/train_async.py
index 6960bd055..9f208cde9 100644
--- a/train_async.py
+++ b/train_async.py
@@ -31,6 +31,10 @@ def train(args):
     if args.check_weight_update_equal:
         ray.get(rollout_manager.check_weights.remote(action="compare"))
 
+    # Eval before training (parity with train.py).
+    if args.eval_interval is not None and args.start_rollout_id == 0 and not args.skip_eval_before_train:
+        ray.get(rollout_manager.eval.remote(args.start_rollout_id))
+
     # async train loop.
     rollout_data_next_future = rollout_manager.generate.remote(args.start_rollout_id)
     for rollout_id in range(args.start_rollout_id, args.num_rollout):