Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,54 @@ def _extract_per_sample(
return log_probs_list, entropy_list


def _build_full_loss_mask(
T: int,
device: torch.device,
loss_masks: list[torch.Tensor],
total_lengths: list[int],
response_lengths: list[int],
qkv_format: str,
max_seq_lens: list[int] | None,
) -> torch.Tensor:
"""Build a boolean mask over all T logit positions marking response tokens to keep.

For each sample, the loss_mask covers the response tokens. This function maps
those per-sample masks onto the full packed/padded logit tensor layout (matching
the layout used by _build_shifted_tokens and _extract_per_sample for cp_size==1).

Positions corresponding to prompt tokens or padding are marked True (kept) so that
the shifted-token alignment is preserved. Only response positions where
loss_mask == 0 are marked False (filtered out).

Returns:
Boolean tensor of shape [T] where True means "compute this position".
"""
# Start with all True — prompts and padding are kept by default
full_mask = torch.ones(T, dtype=torch.bool, device=device)

if qkv_format == "thd":
offset = 0
for i, (total_length, response_length) in enumerate(zip(total_lengths, response_lengths, strict=False)):
end = offset + total_length
start = end - response_length
# The logits for response tokens are at positions [start-1, end-1)
resp_logit_start = start - 1
resp_logit_end = end - 1
mask_i = loss_masks[i].bool()
full_mask[resp_logit_start:resp_logit_end] = mask_i[:resp_logit_end - resp_logit_start]
offset += total_length
else: # bshd
for i, (total_length, response_length) in enumerate(zip(total_lengths, response_lengths, strict=False)):
end = max_seq_lens[i] * i + total_length
start = end - response_length
resp_logit_start = start - 1
resp_logit_end = end - 1
mask_i = loss_masks[i].bool()
full_mask[resp_logit_start:resp_logit_end] = mask_i[:resp_logit_end - resp_logit_start]

return full_mask


def get_log_probs_and_entropy(
logits: torch.Tensor,
*,
Expand All @@ -391,18 +439,29 @@ def get_log_probs_and_entropy(
with_entropy: bool = False,
non_loss_data: bool = True,
max_seq_lens: list[int] | None = None,
loss_masks: list[torch.Tensor] | None = None,
) -> dict[str, list[torch.Tensor]]:
"""Compute per-token log-probabilities (and optionally entropy) on responses.

Computes on the **full** logits ``[T, V]`` tensor at once (instead of
per-sample slicing) so backward traverses ``[T, V]`` only once, then
extracts per-sample response portions.

When ``loss_masks`` is provided and context parallelism is disabled
(cp_size == 1), positions where the mask is zero are filtered out before
the expensive vocab-parallel softmax, and the output is padded back to the
original response length with zeros. This avoids computing log-probs and
entropy for tokens that will be masked out downstream (e.g., tool-result
tokens in multi-turn agent rollouts), reducing memory and compute by the
fraction of masked tokens.

When ``entropy_coef == 0``, entropy is computed under ``torch.no_grad()``
to avoid retaining the computation graph and to skip cloning.
"""
assert non_loss_data
qkv_format = args.qkv_format
cp_size = mpu.get_context_parallel_world_size()
filter_by_mask = loss_masks is not None and cp_size == 1

assert logits.dtype == torch.float32, f"{logits.dtype}"
assert len(logits.shape) == 3, f"{logits.shape}"
Expand All @@ -429,6 +488,20 @@ def get_log_probs_and_entropy(
T, device, unconcat_tokens, total_lengths, response_lengths, qkv_format, max_seq_lens, args.allgather_cp
)

# --- filter by loss_masks: remove masked positions before expensive softmax ---
keep_mask = None
if filter_by_mask:
# Build a full-sequence boolean mask aligned to logits positions
keep_mask = _build_full_loss_mask(
T, device, loss_masks, total_lengths, response_lengths, qkv_format, max_seq_lens
)
num_kept = keep_mask.sum().item()
if num_kept < T:
logits = logits[keep_mask]
full_tokens = full_tokens[keep_mask]
else:
keep_mask = None # all positions kept, no filtering needed

# --- compute on full [T,V] logits at once via calculate_log_probs_and_entropy ---
log_prob_full, entropy_full = calculate_log_probs_and_entropy(
logits,
Expand All @@ -439,6 +512,16 @@ def get_log_probs_and_entropy(
)
log_prob_full = log_prob_full.squeeze(-1) # [T, 1] -> [T]

# --- scatter back to original length if filtering was applied ---
if keep_mask is not None:
full_lp = torch.zeros(T, device=device, dtype=log_prob_full.dtype)
full_lp[keep_mask] = log_prob_full
log_prob_full = full_lp
if entropy_full is not None:
full_ent = torch.zeros(T, device=device, dtype=entropy_full.dtype)
full_ent[keep_mask] = entropy_full
entropy_full = full_ent

# --- extract per-sample response portions ---
log_probs_list, entropy_list = _extract_per_sample(
log_prob_full,
Expand Down Expand Up @@ -478,6 +561,7 @@ def get_values(
with_entropy: bool = False,
non_loss_data: bool = True,
max_seq_lens: list[int] | None = None,
loss_masks: list[torch.Tensor] | None = None,
) -> dict[str, list[torch.Tensor]]:
"""Extract per-token value predictions over response tokens.

Expand Down Expand Up @@ -839,6 +923,7 @@ def policy_loss_function(
response_lengths=response_lengths,
with_entropy=True,
max_seq_lens=max_seq_lens,
loss_masks=batch.get("loss_masks"),
)

log_probs = log_probs_and_entropy["log_probs"]
Expand Down Expand Up @@ -1111,6 +1196,7 @@ def sft_loss_function(
response_lengths=response_lengths,
with_entropy=False,
max_seq_lens=batch.get("max_seq_lens", None),
loss_masks=batch.get("loss_masks"),
)

log_probs = log_probs_and_entropy["log_probs"]
Expand Down
1 change: 1 addition & 0 deletions slime/backends/megatron_utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def forward_step(
response_lengths=response_lengths,
with_entropy=args.use_rollout_entropy,
max_seq_lens=batch.get("max_seq_lens", None),
loss_masks=batch.get("loss_masks"),
)

# Turn on evaluation mode which disables dropout.
Expand Down
4 changes: 4 additions & 0 deletions train_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def train(args):
if args.check_weight_update_equal:
ray.get(rollout_manager.check_weights.remote(action="compare"))

# Eval before training (parity with train.py).
if args.eval_interval is not None and args.start_rollout_id == 0 and not args.skip_eval_before_train:
ray.get(rollout_manager.eval.remote(args.start_rollout_id))

# async train loop.
rollout_data_next_future = rollout_manager.generate.remote(args.start_rollout_id)
for rollout_id in range(args.start_rollout_id, args.num_rollout):
Expand Down