From 8f3f54f6e4d4a53850f0ada42ae13da1634a096c Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Thu, 15 Jan 2026 07:10:14 +0000 Subject: [PATCH 1/5] x --- skyrl-train/skyrl_train/dataset/preprocess.py | 34 +++++++++- .../skyrl_train/dataset/replay_buffer.py | 1 + skyrl-train/skyrl_train/generators/base.py | 1 + .../skyrl_train/inference_engines/base.py | 1 + .../inference_engine_client.py | 6 ++ .../inference_engines/vllm/vllm_engine.py | 65 +++++++++++++++++++ skyrl-train/skyrl_train/model_wrapper.py | 6 +- skyrl-train/skyrl_train/trainer.py | 10 +++ skyrl-train/skyrl_train/training_batch.py | 3 + skyrl-train/skyrl_train/utils/torch_utils.py | 35 ++++++++++ .../skyrl_train/utils/trainer_utils.py | 11 ++++ skyrl-train/skyrl_train/workers/worker.py | 2 + .../skyrl_train/workers/worker_utils.py | 3 + 13 files changed, 175 insertions(+), 3 deletions(-) diff --git a/skyrl-train/skyrl_train/dataset/preprocess.py b/skyrl-train/skyrl_train/dataset/preprocess.py index a17cea1b91..6a2efcc46c 100644 --- a/skyrl-train/skyrl_train/dataset/preprocess.py +++ b/skyrl-train/skyrl_train/dataset/preprocess.py @@ -1,7 +1,7 @@ from typing import List, Tuple, Optional import torch from transformers import AutoTokenizer -from jaxtyping import Float +from jaxtyping import Float, Integer def _verify_inputs( @@ -32,6 +32,7 @@ def convert_prompts_responses_to_batch_tensors( rewards: List[List[float]], loss_masks: List[List[int]], logprobs: Optional[List[List[float]]] = None, + sampling_masks: Optional[List[List[List[int]]]] = None, ) -> Tuple[ Float[torch.Tensor, "batch seq_len"], Float[torch.Tensor, "batch seq_len"], @@ -39,6 +40,7 @@ def convert_prompts_responses_to_batch_tensors( Float[torch.Tensor, "batch response_len"], Float[torch.Tensor, "batch response_len"], Optional[Float[torch.Tensor, "batch response_len"]], + Optional[Integer[torch.Tensor, "batch response_len mask_size"]], ]: """ Convert prompts and responses to batch tensors for training. @@ -59,6 +61,7 @@ def convert_prompts_responses_to_batch_tensors( rewards: List of rewards for each response loss_masks: List of loss masks for each response logprobs: List of rollout log probs for each response + sampling_masks: Optional list of sampling masks (top-k/top-p valid token indices) for each response Returns: sequences: Full trajectories (padded and concatenated prompts and responses). Size: (batch, seq_len). @@ -66,6 +69,8 @@ def convert_prompts_responses_to_batch_tensors( action_mask: Response mask for the model. Size: (batch, response_len) rewards: Rewards for each output. Size: (batch, response_len) loss_masks: Loss masks for each output. Size: (batch, response_len) + logprobs_tensor: Rollout log probs for each output. Size: (batch, response_len) + sampling_masks_tensor: Sampling masks tensor. Size: (batch, response_len, max_k) with -1 padding """ _verify_inputs(prompts, responses, rewards, loss_masks) @@ -129,4 +134,29 @@ def convert_prompts_responses_to_batch_tensors( ] logprobs_tensor = torch.tensor(padded_logprobs, dtype=torch.float) - return sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, logprobs_tensor + sampling_masks_tensor = None + if sampling_masks: + batch_size = len(sampling_masks) + max_seq_len = action_mask.size(1) + + max_k = 0 + for sample_masks in sampling_masks: + for step_mask in sample_masks: + max_k = max(max_k, len(step_mask)) + + if max_k > 0: + # shape: (batch_size, seq_len, max_k) + sampling_masks_tensor = torch.full( + (batch_size, max_seq_len, max_k), + fill_value=-1, + dtype=torch.int64, + ) + + for i, sample_masks in enumerate(sampling_masks): + for j, step_mask in enumerate(sample_masks): + if j < max_seq_len: + num_valid = len(step_mask) + if num_valid > 0: + sampling_masks_tensor[i, j, :num_valid] = torch.tensor(step_mask, dtype=torch.int64) + + return sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, logprobs_tensor, sampling_masks_tensor diff --git a/skyrl-train/skyrl_train/dataset/replay_buffer.py b/skyrl-train/skyrl_train/dataset/replay_buffer.py index b4929919a9..57e25e53c7 100644 --- a/skyrl-train/skyrl_train/dataset/replay_buffer.py +++ b/skyrl-train/skyrl_train/dataset/replay_buffer.py @@ -66,6 +66,7 @@ class Experience: loss_mask: Optional[Integer[torch.LongTensor, "batch response_len"]] action_mask: Optional[Integer[torch.Tensor, "batch response_len"]] rollout_logprobs: Optional[Float[torch.Tensor, "batch response_len"]] + sampling_mask: Optional[Integer[torch.Tensor, "batch seq_len mask_size"]] num_actions: int info: Optional[dict] kl: Optional[Float[torch.Tensor, "batch response_len"]] = None diff --git a/skyrl-train/skyrl_train/generators/base.py b/skyrl-train/skyrl_train/generators/base.py index b0c677f852..93c6164a27 100644 --- a/skyrl-train/skyrl_train/generators/base.py +++ b/skyrl-train/skyrl_train/generators/base.py @@ -41,6 +41,7 @@ class GeneratorOutput(TypedDict): trajectory_ids: Optional[List[TrajectoryID]] # Applicable only for step-wise training is_last_step: Optional[List[bool]] + sampling_masks: Optional[List[List[List[int]]]] class MetricsOutput(TypedDict): diff --git a/skyrl-train/skyrl_train/inference_engines/base.py b/skyrl-train/skyrl_train/inference_engines/base.py index 392e2100eb..a49bae94ec 100644 --- a/skyrl-train/skyrl_train/inference_engines/base.py +++ b/skyrl-train/skyrl_train/inference_engines/base.py @@ -29,6 +29,7 @@ class InferenceEngineOutput(TypedDict): response_ids: List[List[int]] stop_reasons: List[str] response_logprobs: Optional[List[List[float]]] + sampling_masks: Optional[List[List[List[int]]]] class InferenceEngineInterface(ABC): diff --git a/skyrl-train/skyrl_train/inference_engines/inference_engine_client.py b/skyrl-train/skyrl_train/inference_engines/inference_engine_client.py index ae4ffc2510..92d2e93200 100644 --- a/skyrl-train/skyrl_train/inference_engines/inference_engine_client.py +++ b/skyrl-train/skyrl_train/inference_engines/inference_engine_client.py @@ -137,6 +137,8 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu # a bit hacky for now add_resp_logprobs = False + sampling_masks: List[List[List[int]]] = [[] for _ in range(n)] + for indices, result in zip(indices_list, results): for local_idx, original_idx in enumerate(indices): responses[original_idx] = result["responses"][local_idx] @@ -145,12 +147,16 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu if result.get("response_logprobs", None): add_resp_logprobs = True response_logprobs[original_idx] = result["response_logprobs"][local_idx] + # TODO(devpatel): see patch in vllm_engine.py for more details. + if result.get("sampling_masks", None): + sampling_masks[original_idx] = result["sampling_masks"][local_idx] return InferenceEngineOutput( responses=responses, stop_reasons=stop_reasons, response_ids=response_ids, response_logprobs=response_logprobs if add_resp_logprobs else None, + sampling_masks=sampling_masks, ) async def _generate_single_with_retry( diff --git a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py index 698c60db3f..42a9240abf 100644 --- a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -1,5 +1,6 @@ import os from typing import List, Any, Dict, Optional, TYPE_CHECKING +import threading if TYPE_CHECKING: from skyrl_train.weight_sync.transfer_strategy import WeightSyncInitInfo @@ -37,6 +38,55 @@ from packaging import version +# TODO(devpatel): This is a hack to get the sampling masks. We should find a better way to do this... fast +_sampling_masks = threading.local() +_sampler_patched = False + + +def _reset_sampling_masks() -> None: + _sampling_masks.items = [] + + +def _append_sampling_mask(mask: torch.Tensor) -> None: + if not hasattr(_sampling_masks, "items"): + _sampling_masks.items = [] + _sampling_masks.items.append(mask) + + +def _consume_sampling_masks() -> Optional[List[torch.Tensor]]: + masks = getattr(_sampling_masks, "items", None) + _sampling_masks.items = [] + return masks + + +def _patch_vllm_sampler() -> None: + global _sampler_patched + if _sampler_patched: + return + try: + from vllm.v1.sample.ops import topk_topp_sampler as sampler + except Exception as exc: + logger.warning(f"Could not import vLLM topk_topp_sampler op and/or Sampler class: {exc}") + return + + original_top_k_top_p = sampler.apply_top_k_top_p + original_top_k_only = sampler.apply_top_k_only + + def _wrapped_top_k_top_p(logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None) -> torch.Tensor: + output = original_top_k_top_p(logits, k, p) + _append_sampling_mask(torch.isfinite(output).to(dtype=torch.bool).cpu()) + return output + + def _wrapped_top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor: + output = original_top_k_only(logits, k) + _append_sampling_mask(torch.isfinite(output).to(dtype=torch.bool).cpu()) + return output + + sampler.apply_top_k_top_p = _wrapped_top_k_top_p + sampler.apply_top_k_only = _wrapped_top_k_only + _sampler_patched = True + + @dataclass class Logprob: logprob: float @@ -137,6 +187,7 @@ class BaseVLLMInferenceEngine(InferenceEngineInterface): def __init__(self, *args, bundle_indices: list = None, **kwargs): setup_envvars_for_vllm(kwargs, bundle_indices) + _patch_vllm_sampler() vllm_v1_disable_multiproc = kwargs.pop("vllm_v1_disable_multiproc", False) if vllm_v1_disable_multiproc or vllm.__version__ == "0.8.2": # https://github.com/vllm-project/vllm/blob/effc5d24fae10b29996256eb7a88668ff7941aed/examples/offline_inference/reproduciblity.py#L11 @@ -169,6 +220,7 @@ def _create_engine(self, *args, **kwargs): def _preprocess_prompts(self, input_batch: InferenceEngineInput): """Common prompt preprocessing logic.""" + _reset_sampling_masks() prompts = input_batch.get("prompts") prompt_token_ids = input_batch.get("prompt_token_ids") request_sampling_params = input_batch.get("sampling_params") @@ -213,11 +265,24 @@ def _postprocess_outputs(self, outputs): if len(response_logprobs) and response_logprobs[0] is None: response_logprobs = None # hack: assume uniform sampling params + sampling_masks = None + masks = _consume_sampling_masks() + if masks: + sampling_masks = [] + # TODO(devpatel): We don't have the request_ids in the sampling metadata, so order by index. + for output_idx in range(len(outputs)): + per_request = [] + for step_mask in masks: + if output_idx < step_mask.shape[0]: + per_request.append(step_mask[output_idx].nonzero(as_tuple=False).squeeze(-1).tolist()) + sampling_masks.append(per_request) + return InferenceEngineOutput( responses=responses, stop_reasons=stop_reasons, response_ids=response_ids, response_logprobs=response_logprobs, + sampling_masks=sampling_masks, ) def _get_engine(self): diff --git a/skyrl-train/skyrl_train/model_wrapper.py b/skyrl-train/skyrl_train/model_wrapper.py index 847dc61027..3b94f8f303 100644 --- a/skyrl-train/skyrl_train/model_wrapper.py +++ b/skyrl-train/skyrl_train/model_wrapper.py @@ -13,7 +13,7 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig import numpy as np from skyrl_train.distributed.ulysses.utils import ulysses_pad_and_slice_inputs, gather_outputs_and_unpad -from skyrl_train.utils.torch_utils import chunked_entropy_from_logits, logprobs_from_logits +from skyrl_train.utils.torch_utils import chunked_entropy_from_logits, logprobs_from_logits, apply_sampling_mask from flash_attn.bert_padding import pad_input, unpad_input from packaging.version import Version @@ -267,6 +267,7 @@ def forward( return_output=False, compute_entropy=False, entropy_requires_grad=True, + sampling_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Returns action log probs""" position_ids = attention_mask.long().cumsum(-1) - 1 @@ -313,6 +314,9 @@ def forward( logits_BSV = output["logits"] logits_BSV.div_(temperature) + if sampling_mask: + logits_BSV = apply_sampling_mask(logits_BSV, sampling_mask) + # NOTE: this is slightly inaccurate with sample packing because last token from nth seq -> first token of n+1th seq loss is added. log_probs = logprobs_from_logits( logits_BSV, diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index 6a9695ebdd..9682c94bc1 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -569,6 +569,8 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis rewards: List[List[float]] = generator_output["rewards"] loss_masks: List[List[int]] = generator_output["loss_masks"] + # TODO(devpatel): test if handoff is working correctly for batching. + sampling_masks: Optional[List[List[List[int]]]] = generator_output.get("sampling_masks", None) logprobs: Optional[List[List[float]]] = generator_output.get("rollout_logprobs", None) ( @@ -578,6 +580,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis rewards_tensor, loss_masks_tensor, rollout_logprobs_tensor, + sampling_masks_tensor, ) = convert_prompts_responses_to_batch_tensors( self.tokenizer, prompt_ids, @@ -585,6 +588,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis rewards, loss_masks, logprobs, + sampling_masks, ) # sanity check for tis if self.cfg.trainer.algorithm.use_tis: @@ -592,6 +596,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis rollout_logprobs_tensor is not None ), "expected non-null rollout logprobs tensor with `trainer.algorithm.use_tis` as `True`" assert rollout_logprobs_tensor.shape == loss_masks_tensor.shape, "Logprobs should look like responses" + training_input = TrainingInputBatch( { "sequences": sequences_tensor, # Full trajectories (padded and concatenated prompts and responses) @@ -605,6 +610,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis if generator_output.get("is_last_step", None) is not None else None ), + "sampling_mask": sampling_masks_tensor, }, ) training_input.metadata = {"uids": uids} @@ -861,6 +867,10 @@ def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch: elif key == "loss_mask": # ensures that padding tensors don't count towards the loss padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device) + elif key == "sampling_mask": + padding_tensor = torch.full( + (pad_size, *additional_dims), fill_value=-1, dtype=tensor.dtype, device=tensor.device + ) else: # ensures all padding tensors are in a valid format by cloning `pad_size` from the original input # `pad_size` is guaranteed to be smaller than batch_size diff --git a/skyrl-train/skyrl_train/training_batch.py b/skyrl-train/skyrl_train/training_batch.py index 1c455db1c3..9dd19de153 100644 --- a/skyrl-train/skyrl_train/training_batch.py +++ b/skyrl-train/skyrl_train/training_batch.py @@ -333,6 +333,9 @@ class TrainingInput(TypedDict, total=False): kl: Float[torch.Tensor, "batch_size seq_len"] rewards: Optional[Float[torch.Tensor, "batch_size seq_len"]] rollout_logprobs: Optional[Float[torch.Tensor, "batch_size seq_len"]] + sampling_mask: Optional[ + Integer[torch.Tensor, "batch_size seq_len mask_size"] + ] ## logits mask for sampling truncation, see https://arxiv.org/pdf/2512.02556 (3.1) class TrainingInputBatch(TensorBatch[TrainingInput]): diff --git a/skyrl-train/skyrl_train/utils/torch_utils.py b/skyrl-train/skyrl_train/utils/torch_utils.py index 06b89a01b1..1b41f955f1 100644 --- a/skyrl-train/skyrl_train/utils/torch_utils.py +++ b/skyrl-train/skyrl_train/utils/torch_utils.py @@ -175,3 +175,38 @@ def logprobs_from_logits_v2( logprobs_labels.append(row_logprobs_labels) logprobs_labels = torch.stack(logprobs_labels) return logprobs_labels + + +# def compute_sampling_mask( +# logits: Float[torch.Tensor, "batch_size seqlen vocab_size"], +# top_k: int = None, +# top_p: float = None, +# min_p: float = None, +# ) -> Float[torch.Tensor, "batch_size seqlen vocab_size"]: +# pass + + +def apply_sampling_mask( + logits: Float[torch.Tensor, "batch_size seqlen top_tokens"], + sampling_mask: Integer[torch.Tensor, "batch_size seqlen mask_size"], +) -> Float[torch.Tensor, "batch_size seqlen top_tokens"]: + + if sampling_mask is None: + return logits + + batch_size, seqlen, vocab_size = logits.shape + device = logits.device + + valid_token_mask = torch.zeros((batch_size, seqlen, vocab_size), dtype=torch.bool, device=device) + + for b in range(batch_size): + for s in range(seqlen): + valid_indices = sampling_mask[b, s] + valid_indices = valid_indices[valid_indices >= 0] + if len(valid_indices) > 0: + valid_token_mask[b, s, valid_indices] = True + + masked_logits = logits.clone() + masked_logits[~valid_token_mask] = float("-inf") + + return masked_logits diff --git a/skyrl-train/skyrl_train/utils/trainer_utils.py b/skyrl-train/skyrl_train/utils/trainer_utils.py index e53aa20172..d29972ad6c 100644 --- a/skyrl-train/skyrl_train/utils/trainer_utils.py +++ b/skyrl-train/skyrl_train/utils/trainer_utils.py @@ -417,6 +417,8 @@ def handle_replace_sampling( if generator_output["rollout_logprobs"]: generator_output["rollout_logprobs"][bad_idx] = generator_output["rollout_logprobs"][replacement_idx] + if generator_output.get("sampling_masks"): + generator_output["sampling_masks"][bad_idx] = generator_output["sampling_masks"][replacement_idx] # Update UIDs accordingly replaced_uids = uids.copy() @@ -555,6 +557,9 @@ def filter_generator_output(output: GeneratorOutput, kept_indices: List[int]) -> "rollout_logprobs": ( [output["rollout_logprobs"][i] for i in kept_indices] if output["rollout_logprobs"] else None ), + "sampling_masks": ( + [output["sampling_masks"][i] for i in kept_indices] if output.get("sampling_masks") else None + ), } if output.get("stop_reasons"): @@ -613,6 +618,7 @@ def validate_generator_output(num_prompts: int, generator_output: GeneratorOutpu "loss_masks", "rewards", "rollout_logprobs", + "sampling_masks", ]: assert len(generator_output[key]) == len(generator_output["response_ids"]), ( f"Generator output {key} length must be equal to response_ids length, " @@ -639,6 +645,11 @@ def validate_generator_output(num_prompts: int, generator_output: GeneratorOutpu f"Response ids and rollout logprobs must have the same length, " f"for sample {i} got {len(response_ids)} and {len(generator_output['rollout_logprobs'][i])}" ) + if generator_output.get("sampling_masks"): + assert len(response_ids) == len(generator_output["sampling_masks"][i]), ( + f"Response ids and sampling masks must have the same length, " + f"for sample {i} got {len(response_ids)} and {len(generator_output['sampling_masks'][i])}" + ) # loss masks should be non-zero for at least one element for trainer if np.concatenate(generator_output["loss_masks"]).sum() == 0: diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 361a0777ec..2df0c52943 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -864,6 +864,7 @@ def _forward_micro_batch(self, micro_batch: TrainingInputBatch) -> TrainingOutpu sequences = micro_batch["sequences"] response_length = micro_batch.metadata["response_length"] attention_mask = micro_batch["attention_mask"] + sampling_mask = micro_batch.get("sampling_mask", None) with torch.no_grad(), torch.autocast(dtype=torch.bfloat16, device_type="cuda"): policy_logprob = self.model( @@ -872,6 +873,7 @@ def _forward_micro_batch(self, micro_batch: TrainingInputBatch) -> TrainingOutpu attention_mask, return_output=False, temperature=self.cfg.generator.sampling_params.temperature, + sampling_mask=sampling_mask, ) policy_logprob = policy_logprob.to("cpu") output = TrainingOutputBatch( diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py index 897d032ea8..25b1ef6b76 100644 --- a/skyrl-train/skyrl_train/workers/worker_utils.py +++ b/skyrl-train/skyrl_train/workers/worker_utils.py @@ -60,6 +60,9 @@ def batch_to_experience(batch: TrainingInputBatch): action_mask=batch["response_mask"], num_actions=batch.metadata["response_length"], # int rollout_logprobs=batch["rollout_logprobs"] if "rollout_logprobs" in batch else None, + sampling_mask=( + batch["sampling_mask"] if "sampling_mask" in batch else None + ), # shape: (batch_size, seq_len, top_tokens) # additional info # can be used to log metrics etc for micro-batches in the worker info={}, From 36ea047d3bc363c0730813a11f93d50cdde4f8cc Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Thu, 15 Jan 2026 16:47:12 +0000 Subject: [PATCH 2/5] x --- skyrl-train/skyrl_train/utils/torch_utils.py | 11 ++++------- skyrl-train/skyrl_train/utils/trainer_utils.py | 11 ----------- skyrl-train/tests/cpu/dataset/test_preprocess.py | 4 +++- .../cpu/inf_engines/test_inference_engine_client.py | 2 +- 4 files changed, 8 insertions(+), 20 deletions(-) diff --git a/skyrl-train/skyrl_train/utils/torch_utils.py b/skyrl-train/skyrl_train/utils/torch_utils.py index 1b41f955f1..6df2dc4070 100644 --- a/skyrl-train/skyrl_train/utils/torch_utils.py +++ b/skyrl-train/skyrl_train/utils/torch_utils.py @@ -197,14 +197,11 @@ def apply_sampling_mask( batch_size, seqlen, vocab_size = logits.shape device = logits.device + # TODO(devpatel) if we sort the tokens, then indices might be wrong valid_token_mask = torch.zeros((batch_size, seqlen, vocab_size), dtype=torch.bool, device=device) - - for b in range(batch_size): - for s in range(seqlen): - valid_indices = sampling_mask[b, s] - valid_indices = valid_indices[valid_indices >= 0] - if len(valid_indices) > 0: - valid_token_mask[b, s, valid_indices] = True + valid = sampling_mask >= 0 + idx = sampling_mask.clamp(min=0) + valid_token_mask.scatter_(dim=2, index=idx, src=valid) masked_logits = logits.clone() masked_logits[~valid_token_mask] = float("-inf") diff --git a/skyrl-train/skyrl_train/utils/trainer_utils.py b/skyrl-train/skyrl_train/utils/trainer_utils.py index d29972ad6c..e53aa20172 100644 --- a/skyrl-train/skyrl_train/utils/trainer_utils.py +++ b/skyrl-train/skyrl_train/utils/trainer_utils.py @@ -417,8 +417,6 @@ def handle_replace_sampling( if generator_output["rollout_logprobs"]: generator_output["rollout_logprobs"][bad_idx] = generator_output["rollout_logprobs"][replacement_idx] - if generator_output.get("sampling_masks"): - generator_output["sampling_masks"][bad_idx] = generator_output["sampling_masks"][replacement_idx] # Update UIDs accordingly replaced_uids = uids.copy() @@ -557,9 +555,6 @@ def filter_generator_output(output: GeneratorOutput, kept_indices: List[int]) -> "rollout_logprobs": ( [output["rollout_logprobs"][i] for i in kept_indices] if output["rollout_logprobs"] else None ), - "sampling_masks": ( - [output["sampling_masks"][i] for i in kept_indices] if output.get("sampling_masks") else None - ), } if output.get("stop_reasons"): @@ -618,7 +613,6 @@ def validate_generator_output(num_prompts: int, generator_output: GeneratorOutpu "loss_masks", "rewards", "rollout_logprobs", - "sampling_masks", ]: assert len(generator_output[key]) == len(generator_output["response_ids"]), ( f"Generator output {key} length must be equal to response_ids length, " @@ -645,11 +639,6 @@ def validate_generator_output(num_prompts: int, generator_output: GeneratorOutpu f"Response ids and rollout logprobs must have the same length, " f"for sample {i} got {len(response_ids)} and {len(generator_output['rollout_logprobs'][i])}" ) - if generator_output.get("sampling_masks"): - assert len(response_ids) == len(generator_output["sampling_masks"][i]), ( - f"Response ids and sampling masks must have the same length, " - f"for sample {i} got {len(response_ids)} and {len(generator_output['sampling_masks'][i])}" - ) # loss masks should be non-zero for at least one element for trainer if np.concatenate(generator_output["loss_masks"]).sum() == 0: diff --git a/skyrl-train/tests/cpu/dataset/test_preprocess.py b/skyrl-train/tests/cpu/dataset/test_preprocess.py index 4143ea9c0a..29b49d67a3 100644 --- a/skyrl-train/tests/cpu/dataset/test_preprocess.py +++ b/skyrl-train/tests/cpu/dataset/test_preprocess.py @@ -69,14 +69,16 @@ def test_convert_prompts_responses_to_batch_tensors_exact(tokenizer, cfg): loss_masks = [[1, 1, 0], [1, 1, 1, 0, 0]] rewards = [torch.tensor([0, 1, 0]), torch.tensor([1, 0, 0, 0, 0])] + sampling_masks = [[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]] - sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, ret_log_probs = ( + sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, ret_log_probs, ret_sampling_masks = ( convert_prompts_responses_to_batch_tensors( tokenizer, prompts, outputs, rewards, loss_masks, + sampling_masks, ) ) diff --git a/skyrl-train/tests/cpu/inf_engines/test_inference_engine_client.py b/skyrl-train/tests/cpu/inf_engines/test_inference_engine_client.py index 6a9853027f..b705051af0 100644 --- a/skyrl-train/tests/cpu/inf_engines/test_inference_engine_client.py +++ b/skyrl-train/tests/cpu/inf_engines/test_inference_engine_client.py @@ -1,5 +1,5 @@ """ -Test for `skyrl-train/skyrl_train/inference_engines/inference_engine_client.py` functinoalities +Test for `skyrl-train/skyrl_train/inference_engines/inference_engine_client.py` functionalities that can be mocked. Also tests for `skyrl-train/skyrl_train/inference_engines/utils.py`. Run with: From 0877ed1c72a4d0028db64c17d33a5130a72fcc0d Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Thu, 15 Jan 2026 16:58:02 +0000 Subject: [PATCH 3/5] done --- skyrl-train/skyrl_train/dataset/preprocess.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skyrl-train/skyrl_train/dataset/preprocess.py b/skyrl-train/skyrl_train/dataset/preprocess.py index 6a2efcc46c..4f08e59f6e 100644 --- a/skyrl-train/skyrl_train/dataset/preprocess.py +++ b/skyrl-train/skyrl_train/dataset/preprocess.py @@ -69,7 +69,6 @@ def convert_prompts_responses_to_batch_tensors( action_mask: Response mask for the model. Size: (batch, response_len) rewards: Rewards for each output. Size: (batch, response_len) loss_masks: Loss masks for each output. Size: (batch, response_len) - logprobs_tensor: Rollout log probs for each output. Size: (batch, response_len) sampling_masks_tensor: Sampling masks tensor. Size: (batch, response_len, max_k) with -1 padding """ _verify_inputs(prompts, responses, rewards, loss_masks) From 407463449ba7488da9924a869e7f6bcc2cfa086e Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Thu, 15 Jan 2026 17:08:53 +0000 Subject: [PATCH 4/5] fix tests --- skyrl-train/tests/cpu/dataset/test_preprocess.py | 5 +++-- skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/skyrl-train/tests/cpu/dataset/test_preprocess.py b/skyrl-train/tests/cpu/dataset/test_preprocess.py index 29b49d67a3..7843be8435 100644 --- a/skyrl-train/tests/cpu/dataset/test_preprocess.py +++ b/skyrl-train/tests/cpu/dataset/test_preprocess.py @@ -78,7 +78,8 @@ def test_convert_prompts_responses_to_batch_tensors_exact(tokenizer, cfg): outputs, rewards, loss_masks, - sampling_masks, + logprobs=None, + sampling_masks=sampling_masks, ) ) @@ -100,7 +101,7 @@ def test_convert_prompts_responses_to_batch_tensors_different_lengths(cfg, token rewards = [torch.tensor([1.0, 0.5, 0.3]), torch.tensor([0.8])] loss_masks = [[1, 1, 1], [1]] - sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, ret_log_probs = ( + sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, ret_log_probs, ret_sampling_masks = ( convert_prompts_responses_to_batch_tensors( tokenizer, prompts, diff --git a/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py b/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py index e807cdb6a9..274f78bae9 100644 --- a/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py +++ b/skyrl-train/tests/cpu/generators/test_skyrl_gym_generator.py @@ -368,6 +368,7 @@ def test_generator_output_concatenation(): # optional but present in the signature "trajectory_ids", "is_last_step", + "sampling_masks", ] assert set(GeneratorOutput.__annotations__.keys()) == set(expected_fields), ( "GeneratorOutput fields are not what we expect. " From 5ffee3a35d963736cd7e85ce3662f0e15e6d4bc9 Mon Sep 17 00:00:00 2001 From: Dev Patel Date: Thu, 15 Jan 2026 18:01:54 +0000 Subject: [PATCH 5/5] x --- .../tests/cpu/dataset/test_preprocess.py | 121 ++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/skyrl-train/tests/cpu/dataset/test_preprocess.py b/skyrl-train/tests/cpu/dataset/test_preprocess.py index 7843be8435..be872c841a 100644 --- a/skyrl-train/tests/cpu/dataset/test_preprocess.py +++ b/skyrl-train/tests/cpu/dataset/test_preprocess.py @@ -162,3 +162,124 @@ def test_convert_prompts_responses_to_batch_tensors_mismatched_lengths(cfg, toke rewards, loss_masks, ) + + +def test_convert_prompts_responses_to_batch_tensors_sampling_masks(tokenizer, cfg): + prompts = ["abc", "12"] + outputs = ["def", "3456"] # different response lengths: 3 and 4 + prompts = tokenizer(prompts)["input_ids"] + outputs = tokenizer(outputs)["input_ids"] + + loss_masks = [[1, 1, 1], [1, 1, 1, 1]] + rewards = [torch.tensor([1.0, 1.0, 1.0]), torch.tensor([1.0, 1.0, 1.0, 1.0])] + + sampling_masks = [ + [ + [10, 20], + [30, 40, 50], + [60], + ], + [ + [100, 200, 300], + [400, 500], + [600, 700, 800, 900], + [1000, 1100], + ], + ] + + sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, ret_log_probs, ret_sampling_masks = ( + convert_prompts_responses_to_batch_tensors( + tokenizer, + prompts, + outputs, + rewards, + loss_masks, + logprobs=None, + sampling_masks=sampling_masks, + ) + ) + + assert ret_sampling_masks is not None + + batch_size = len(prompts) + max_response_len = max(len(o) for o in outputs) + max_k = 4 + + assert ret_sampling_masks.shape == (batch_size, max_response_len, max_k) + assert ret_sampling_masks.dtype == torch.int64 + + assert torch.equal( + ret_sampling_masks, + torch.tensor( + [ + [[10, 20, -1, -1], [30, 40, 50, -1], [60, -1, -1, -1], [-1, -1, -1, -1]], + [[100, 200, 300, -1], [400, 500, -1, -1], [600, 700, 800, 900], [1000, 1100, -1, -1]], + ] + ), + ) + + +def test_convert_prompts_responses_to_batch_tensors_no_sampling_masks(tokenizer, cfg): + """Test that when sampling_masks is None, the return value is also None.""" + prompts = ["abc"] + outputs = ["def"] + prompts = tokenizer(prompts)["input_ids"] + outputs = tokenizer(outputs)["input_ids"] + + loss_masks = [[1, 1, 1]] + rewards = [torch.tensor([1.0, 1.0, 1.0])] + + sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, ret_log_probs, ret_sampling_masks = ( + convert_prompts_responses_to_batch_tensors( + tokenizer, + prompts, + outputs, + rewards, + loss_masks, + logprobs=None, + sampling_masks=None, + ) + ) + + # when sampling_masks is None, ret_sampling_masks should also be None + assert ret_sampling_masks is None + + +def test_convert_prompts_responses_to_batch_tensors_empty_sampling_masks(tokenizer, cfg): + """Test that when sampling_masks contains empty lists, it's handled correctly.""" + prompts = ["abc", "de"] + outputs = ["fgh", "ij"] + prompts = tokenizer(prompts)["input_ids"] + outputs = tokenizer(outputs)["input_ids"] + + loss_masks = [[1, 1, 1], [1, 1]] + rewards = [torch.tensor([1.0, 1.0, 1.0]), torch.tensor([1.0, 1.0])] + + # sampling masks with some empty lists (all tokens were filtered out at that step) + sampling_masks = [ + [ + [10, 20], + [], + [30], + ], + [ + [], + [40, 50], + ], + ] + + sequences, attention_mask, action_mask, ret_rewards, ret_loss_masks, ret_log_probs, ret_sampling_masks = ( + convert_prompts_responses_to_batch_tensors( + tokenizer, + prompts, + outputs, + rewards, + loss_masks, + logprobs=None, + sampling_masks=sampling_masks, + ) + ) + + assert torch.equal( + ret_sampling_masks, torch.tensor([[[10, 20], [-1, -1], [30, -1]], [[-1, -1], [40, 50], [-1, -1]]]) + )