Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/gpu_skyrl_train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ jobs:
ANYSCALE_CLI_TOKEN: ${{ secrets.ANYSCALE_CLI_TOKEN }}
ANYSCALE_HOST: https://console.anyscale.com
run: |
anyscale job submit -f ci/anyscale_gpu_ci_skyrl_train.yaml --timeout 10000
anyscale job wait --cloud sky-anyscale-aws-us-east-1 --name skyrl-train-gpu-ci --timeout 10000
anyscale job submit -f ci/anyscale_gpu_ci_skyrl_train.yaml --timeout 12000
anyscale job wait --cloud sky-anyscale-aws-us-east-1 --name skyrl-train-gpu-ci --timeout 12000
57 changes: 51 additions & 6 deletions examples/train/search/run_search.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,32 @@
set -x

# Colocated GRPO training+generation for Qwen2.5-Coder-3B-Instruct on SearchR1 data.
# follow the instructions in examples/search/README.md for setting up the dataset
# and for starting the local search server
# export WANDB_API_KEY=<your_key_here>
# bash examples/train/search/run_search.sh
# Colocated GRPO training+generation for Qwen2.5-3B-Instruct on SearchR1 data.
# Follow the instructions in docs/content/docs/recipes/searchr1.mdx for setup.
#
# Usage:
# export WANDB_API_KEY=<your_key_here>
# bash examples/train/search/run_search.sh
#
# Configurable knobs (override via env vars or command-line args):
# USE_CONVERSATION_MULTI_TURN - set to "true" to use conversation multi-turn format (default: false)
# When true, also enables append_eos_token_after_stop_str_in_multi_turn=true so that
# each turn's response ends with the model's EOS token (required for correct behavior
# when stop strings like </search> or </answer> terminate generation instead of EOS).
# STEP_WISE - set to "true" to enable step-wise training (default: false)
# Requires USE_CONVERSATION_MULTI_TURN=true.
#
# Examples:
# # Default (non-conversation, non-step-wise):
# bash examples/train/search/run_search.sh
#
# # Conversation multi-turn format:
# USE_CONVERSATION_MULTI_TURN=true bash examples/train/search/run_search.sh
#
# # Step-wise with conversation multi-turn:
# USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true bash examples/train/search/run_search.sh
#
# # Override any config via positional args (passed to Hydra):
# bash examples/train/search/run_search.sh trainer.epochs=2 trainer.eval_interval=10

# path for dataset (.parquet files) containing the prompts and metadata for each question
DATA_DIR="$HOME/data/searchR1"
Expand All @@ -14,6 +36,28 @@ RUN_NAME="skyrl-search_4turns_maxgeneratelen_500-multiturn-sync-TIS_2.0"
TIS_TYPE=token
TIS_IMP_RATIO_CAP=2.0

# Configurable knobs with defaults
: "${USE_CONVERSATION_MULTI_TURN:=false}"
: "${STEP_WISE:=false}"

# Build conditional args
MULTI_TURN_ARGS=""
if [ "$USE_CONVERSATION_MULTI_TURN" = "true" ]; then
MULTI_TURN_ARGS="generator.use_conversation_multi_turn=true generator.append_eos_token_after_stop_str_in_multi_turn=true"
else
MULTI_TURN_ARGS="generator.use_conversation_multi_turn=false"
fi

STEP_WISE_ARGS=""
if [ "$STEP_WISE" = "true" ]; then
STEP_WISE_ARGS="generator.step_wise_trajectories=true"
# Step-wise requires conversation multi-turn
if [ "$USE_CONVERSATION_MULTI_TURN" != "true" ]; then
echo "WARNING: STEP_WISE=true requires USE_CONVERSATION_MULTI_TURN=true. Enabling it automatically."
MULTI_TURN_ARGS="generator.use_conversation_multi_turn=true generator.append_eos_token_after_stop_str_in_multi_turn=true"
fi
fi

uv run --isolated --frozen --extra fsdp -m skyrl.train.entrypoints.main_base \
data.train_data="['${DATA_DIR}/train.parquet']" \
data.val_data="['${DATA_DIR}/validation.parquet']" \
Expand Down Expand Up @@ -49,7 +93,8 @@ uv run --isolated --frozen --extra fsdp -m skyrl.train.entrypoints.main_base \
generator.sampling_params.max_generate_length=500 \
generator.inference_engine.async_engine=true \
generator.batched=false \
generator.use_conversation_multi_turn=false \
$MULTI_TURN_ARGS \
$STEP_WISE_ARGS \
generator.n_samples_per_prompt=5 \
generator.max_turns=4 \
generator.sampling_params.temperature=1.0 \
Expand Down
87 changes: 0 additions & 87 deletions examples/train/search/run_search_conversation_format.sh

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,10 @@ def loss_func(logits, data):
return loss, metrics

def forward_step(batch_iter, model):
# NOTE(Charlie): despite the name, methods like `remove_left_padding()` are padding-agnostic
# (can be left, or right) as it uses attention_mask to locate real tokens. Same thing
# for recover_left_padding and setup_per_microbatch_replay_forward. Especially relevant
# after this PR https://github.com/NovaSky-AI/SkyRL/pull/1285.
batch = next(batch_iter)

rollout_expert_indices = batch.pop("rollout_expert_indices", None)
Expand Down
134 changes: 80 additions & 54 deletions skyrl/train/dataset/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import logging
from typing import List, Optional, Tuple

import torch
from jaxtyping import Float, Integer
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)


def _verify_inputs(
prompts: List[List[int]],
Expand Down Expand Up @@ -34,6 +37,7 @@ def convert_prompts_responses_to_batch_tensors(
loss_masks: List[List[int]],
logprobs: Optional[List[List[float]]] = None,
rollout_expert_indices: Optional[List[List[List[List[int]]]]] = None,
max_seq_len: Optional[int] = None,
) -> Tuple[
Float[torch.Tensor, "batch seq_len"],
Float[torch.Tensor, "batch seq_len"],
Expand All @@ -46,12 +50,33 @@ def convert_prompts_responses_to_batch_tensors(
"""
Convert prompts and responses to batch tensors for training.

This function concatenates all prompts and responses to the following format:
Each sequence is laid out as a single left-padded block:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this in general makes sense to me, but can you test the forward pass for megatron in test_megatron_worker::test_forward to make sure things look ok for the logic there that removes padding?

just running a basic forward pass for any model should suffice

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

made a megatron gsm8k run after rebasing, seems normal. Added to the PR description. I'll address some R3 things before merging


| [PAD] [PAD] prompt prompt prompt respon respon |
| [PAD] prompt prompt prompt respon respon respon |
| prompt prompt prompt respon respon respon respon |
|<------ max(prompt_len_i + response_len_i) ------>|

The padded sequence length is ``max(prompt_len_i + response_len_i)``.
This is the tightest bound: no row is padded beyond the longest real (prompt + response) sequence in the batch.

This makes the response-level tensors (action_mask, rewards, loss_masks, logprobs):
| prompt prompt respon respon |
| prompt respon respon respon |
| respon respon respon respon |

So the action_mask is:
| 0 0 1 1 |
| 0 1 1 1 |
| 1 1 1 1 |

Attention mask is 1 for all real tokens, 0 for padding.
Action mask is 1 for the last ``response_len_i`` positions, 0 for padding.

| [PAD] [PAD] token token token | token token [PAD] [PAD] |
| token token token token token | token token [PAD] [PAD] |
| [PAD] [PAD] [PAD] token token | token token token [PAD] |
|<---------- prompt ----------->|<-------- answer ------->|
Response-level tensors are **right-aligned** within ``(batch, max_response_len)``: non-padded
values occupy the last ``response_len_i`` positions, with leading zeros. This matches the model
forward pass which extracts ``log_probs[:, -num_actions-1:-1]`` — response tokens are always at
the end of the sequence, so their logprobs are right-aligned in the slice.

Assumes that the responses already contain an eos token at index -1.

Expand All @@ -62,88 +87,89 @@ def convert_prompts_responses_to_batch_tensors(
rewards: List of rewards for each response
loss_masks: List of loss masks for each response
logprobs: List of rollout log probs for each response
max_seq_len: Optional. If provided and ``max(prompt_i + response_i)``
exceeds it, a warning is logged (no truncation is performed).

Returns:
sequences: Full trajectories (padded and concatenated prompts and responses). Size: (batch, seq_len).
attention_mask: Attention mask for the model. Size: (batch, seq_len)
action_mask: Response mask for the model. Size: (batch, response_len)
rewards: Rewards for each output. Size: (batch, response_len)
loss_masks: Loss masks for each output. Size: (batch, response_len)
sequences: ``(batch, max_total)`` where ``max_total = max(prompt_i + response_i)``.
attention_mask: ``(batch, max_total)``
action_mask: ``(batch, max_response)`` — right-aligned response indicator.
rewards: ``(batch, max_response)`` — right-aligned.
loss_masks: ``(batch, max_response)`` — right-aligned.
logprobs: ``(batch, max_response)`` — right-aligned, or ``None``.
"""
_verify_inputs(prompts, responses, rewards, loss_masks)

max_input_len, max_output_len = 0, 0
prompt_token_lens, response_token_lens = [], []
inputs_token_ids, outputs_token_ids = [], []
for prompt, response in zip(prompts, responses):
prompt_token_lens = [len(p) for p in prompts]
response_token_lens = [len(r) for r in responses]

inputs_token_ids.append(prompt)
outputs_token_ids.append(response)
max_response = max(response_token_lens)
# Pad to the tightest bound: max per-sample total.
max_total = max(p + r for p, r in zip(prompt_token_lens, response_token_lens))

prompt_token_len = len(prompt)
response_token_len = len(response)
prompt_token_lens.append(prompt_token_len)
response_token_lens.append(response_token_len)

max_input_len = max(max_input_len, prompt_token_len)
max_output_len = max(max_output_len, response_token_len)
if max_seq_len is not None and max_total > max_seq_len:
logger.warning(
f"Max sequence length in batch ({max_total}) exceeds max_seq_len ({max_seq_len}). "
f"No truncation is performed; consider checking generator settings."
)

pad_token_id = tokenizer.pad_token_id
sequences = []
attention_masks = []
action_masks = []
for i, prompt in enumerate(prompts):
# left padding input
input_len = prompt_token_lens[i]
input_ids = [pad_token_id] * (max_input_len - input_len) + list(inputs_token_ids[i])
input_attention_mask = [0] * (max_input_len - input_len) + [1] * input_len

# right padding output
output_len = response_token_lens[i]
output_ids = list(outputs_token_ids[i]) + [pad_token_id] * (max_output_len - output_len)
output_attention_mask = [1] * output_len + [0] * (max_output_len - output_len)

# concat input and output
sequences.append(input_ids + output_ids)
attention_masks.append(input_attention_mask + output_attention_mask)
action_masks.append(output_attention_mask)
for i in range(len(prompts)):
total_real = prompt_token_lens[i] + response_token_lens[i]
pad_len = max_total - total_real

# Unified left-pad: [PAD ... PAD PROMPT RESPONSE]
seq = [pad_token_id] * pad_len + prompts[i] + responses[i]
attention_mask_i = [0] * pad_len + [1] * total_real

# Response indicator within the last max_response positions (right-aligned).
resp_pad = max_response - response_token_lens[i]
action_mask_i = [0] * resp_pad + [1] * response_token_lens[i]

sequences.append(seq)
attention_masks.append(attention_mask_i)
action_masks.append(action_mask_i)

sequences = torch.tensor(sequences)
attention_mask = torch.tensor(attention_masks, dtype=torch.int64)
action_mask = torch.tensor(action_masks, dtype=torch.int64)

# initialize ret loss masks to be the same as action mask
ret_loss_masks = torch.zeros_like(action_mask, dtype=torch.float)
for i, loss_mask in enumerate(loss_masks):
ret_loss_masks[i, : len(loss_mask)] = torch.tensor(loss_mask)
# Response-level tensors are RIGHT-ALIGNED to match the model output.
# The model's log_probs[:, -num_actions-1:-1] returns logprobs where
# response tokens occupy the last response_len_i positions.
ret_loss_masks = torch.zeros(len(prompts), max_response, dtype=torch.float)
for i, lm in enumerate(loss_masks):
ret_loss_masks[i, max_response - len(lm) :] = torch.tensor(lm, dtype=torch.float)

# do the same for custom rewards
ret_rewards = torch.zeros_like(action_mask, dtype=torch.float)
# Same thing for rewards.
ret_rewards = torch.zeros(len(prompts), max_response, dtype=torch.float)
for i, custom_reward in enumerate(rewards):
if isinstance(custom_reward, list):
custom_reward = torch.tensor(custom_reward)
ret_rewards[i, : len(custom_reward)] = custom_reward
ret_rewards[i, max_response - len(custom_reward) :] = custom_reward

# Same thing for logprobs.
logprobs_tensor = None
if logprobs:
max_output_len = action_mask.size(1)
padded_logprobs = [
sample_logprobs + [0.0] * (max_output_len - len(sample_logprobs)) for sample_logprobs in logprobs
]
logprobs_tensor = torch.tensor(padded_logprobs, dtype=torch.float)
logprobs_tensor = torch.zeros(len(prompts), max_response, dtype=torch.float)
for i, sample_logprobs in enumerate(logprobs):
lp = torch.tensor(sample_logprobs, dtype=torch.float)
logprobs_tensor[i, max_response - len(sample_logprobs) :] = lp

rollout_expert_indices_tensor = None
if rollout_expert_indices:
first_non_empty = next((x for x in rollout_expert_indices if x), None)
if first_non_empty:
total_seq_len = max_input_len + max_output_len
num_layers = len(first_non_empty[0])
topk = len(first_non_empty[0][0]) if num_layers > 0 else 0
padded = torch.zeros(len(rollout_expert_indices), total_seq_len, num_layers, topk, dtype=torch.int32)
padded = torch.zeros(len(rollout_expert_indices), max_total, num_layers, topk, dtype=torch.int32)
for i, sample_indices in enumerate(rollout_expert_indices):
if sample_indices:
left_pad = max_input_len - prompt_token_lens[i]
n = min(len(sample_indices), total_seq_len - left_pad)
left_pad = max_total - (prompt_token_lens[i] + response_token_lens[i])
n = min(len(sample_indices), max_total - left_pad)
padded[i, left_pad : left_pad + n] = torch.tensor(sample_indices[:n], dtype=torch.int32)
rollout_expert_indices_tensor = padded

Expand Down
Loading
Loading